踩坑:pytorch中eval模式下結(jié)果遠(yuǎn)差于train模式介紹
首先,eval模式和train模式得到不同的結(jié)果是正常的。我的模型中,eval模式和train模式不同之處在于Batch Normalization和Dropout。Dropout比較簡單,在train時(shí)會丟棄一部分連接,在eval時(shí)則不會。Batch Normalization,在train時(shí)不僅使用了當(dāng)前batch的均值和方差,也使用了歷史batch統(tǒng)計(jì)上的均值和方差,并做一個(gè)加權(quán)平均(momentum參數(shù))。在test時(shí),由于此時(shí)batchsize不一定一致,因此不再使用當(dāng)前batch的均值和方差,僅使用歷史訓(xùn)練時(shí)的統(tǒng)計(jì)值。
我出bug的現(xiàn)象是,train模式下可以收斂,但一旦在測試中切換到了eval模式,結(jié)果就很差。如果在測試中仍沿用train模式,反而可以得到不錯(cuò)的結(jié)果。為了確保是程序bug而不是算法本身就不適合于預(yù)測,我在測試時(shí)再次使用了訓(xùn)練集,正常情況下此時(shí)應(yīng)發(fā)生過擬合,正確率一定會很高,然而eval模式下正確率仍然很低。參照網(wǎng)上的一些說法(Performance highly degraded when eval() is activated in the test phase
),我調(diào)大了batchsize,降低了BN層的momentum,檢查了是否存在不同層使用相同BN層的bug,均不見效。有一種方法說應(yīng)在BN層設(shè)置track_running_stats為False,它雖然帶來了好的效果,但實(shí)際上它只不過是不用eval模式,切回train模式罷了,所以也不對。
學(xué)習(xí)了在訓(xùn)練過程中,如何將BN層中統(tǒng)計(jì)的均值和方差輸出。即在forward()中,
# bn是一個(gè)BN層,torch.nn.batch_normalization(...) print(bn.running_mean) print(bn.running_var)
同時(shí)學(xué)習(xí)了如何輸出一個(gè)Tensor自身的均值和方差,即
# x是一個(gè)Tensor,dims是需要計(jì)算的維度 print(x.cpu().detach().numpy().mean(dims) print(x.cpu().detach().numpy().var(dims)
觀察每一層的輸出結(jié)果,發(fā)現(xiàn)出現(xiàn)了很大的方差,才猛然意識到自己的輸入數(shù)據(jù)沒有做歸一化(事后想想也確實(shí)如此,畢竟模型和訓(xùn)練方法都是github上參考別人的,出錯(cuò)概率很??;反而是自己寫的DataSet部分,其實(shí)是最容易出錯(cuò)的)。給模型加上歸一化后,eval和train的結(jié)果就沒有問題了。
再次驗(yàn)證了我的觀點(diǎn):越是玄學(xué)的問題,越是傻逼的bug。
補(bǔ)充知識:Pytorch中的train和eval用法注意點(diǎn)
1.介紹
一般情況,model.train()是在訓(xùn)練的時(shí)候用到,model.eval()是在測試的時(shí)候用到
2.用法
如果模型中沒有類似于BN這樣的歸一化或者Dropout,model.train()和model.eval()可以不要(建議寫一下,比較安全),并且model.train()和model.eval()得到的效果是一樣
如果模型中有類似于BN這樣的歸一化或者Dropout,并且程序需要邊訓(xùn)練和邊測試,最好就是用model.eval()測試完之后,后面補(bǔ)一個(gè)model.train()。
其中model.train()是保證BN用每一批數(shù)據(jù)的均值和方差,而model.eval()是保證BN用全部訓(xùn)練數(shù)據(jù)的均值和方差;而對于Dropout,model.train()是隨機(jī)取一部分網(wǎng)絡(luò)連接來訓(xùn)練更新參數(shù),而model.eval()是利用到了所有網(wǎng)絡(luò)連接(結(jié)果是取了平均)
以上這篇踩坑:pytorch中eval模式下結(jié)果遠(yuǎn)差于train模式介紹就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)批量識別圖片文字并存為Excel
批量文字識別是Python辦公自動化的基本操作,應(yīng)用在我們工作生活中的方方面面。本文主要以開源免費(fèi)的easyocr來實(shí)現(xiàn)批量識別圖片文字并存為Excel,感興趣的可以學(xué)習(xí)一下2022-06-06Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作
今天小編就為大家分享一篇Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01python反反爬蟲技術(shù)限制連續(xù)請求時(shí)間處理
這篇文章主要為大家介紹了python反反爬蟲技術(shù)限制連續(xù)請求時(shí)間處理,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-06-06Python Json數(shù)據(jù)文件操作原理解析
這篇文章主要介紹了Python Json數(shù)據(jù)文件操作原理解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-05-05如何用python獲取EXCEL文件內(nèi)容并保存到DBC
很多時(shí)候,使用python進(jìn)行數(shù)據(jù)分析的第一步就是讀取excel文件,下面這篇文章主要給大家介紹了關(guān)于如何用python獲取EXCEL文件內(nèi)容并保存到DBC的相關(guān)資料,需要的朋友可以參考2023-12-12如何使用PyCharm將代碼上傳到GitHub上(圖文詳解)
這篇文章主要介紹了如何使用PyCharm將代碼上傳到GitHub上(圖文詳解),文中通過圖文介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04