pytorch中model.train()和model.eval()用法及說明
model.train()和model.eval()用法
1.1 model.train()
model.train()的作用是啟用 Batch Normalization 和 Dropout。
如果模型中有BN層(Batch Normalization)和Dropout,需要在訓(xùn)練時添加model.train()。
model.train()是保證BN層能夠用到每一批數(shù)據(jù)的均值和方差。
對于Dropout,model.train()是隨機(jī)取一部分網(wǎng)絡(luò)連接來訓(xùn)練更新參數(shù)。
1.2 model.eval()
model.eval()的作用是不啟用 Batch Normalization 和 Dropout。
如果模型中有BN層(Batch Normalization)和Dropout,在測試時添加model.eval()。
model.eval()是保證BN層能夠用全部訓(xùn)練數(shù)據(jù)的均值和方差,即測試過程中要保證BN層的均值和方差不變。
對于Dropout,model.eval()是利用到了所有網(wǎng)絡(luò)連接,即不進(jìn)行隨機(jī)舍棄神經(jīng)元。
訓(xùn)練完train樣本后,生成的模型model要用來測試樣本。
在model(test)之前,需要加上model.eval(),否則的話,有輸入數(shù)據(jù),即使不訓(xùn)練,它也會改變權(quán)值。這是model中含有BN層和Dropout所帶來的的性質(zhì)。
在做one classification的時候,訓(xùn)練集和測試集的樣本分布是不一樣的,尤其需要注意這一點(diǎn)。
1.3 分析原因
使用PyTorch進(jìn)行訓(xùn)練和測試時一定注意要把實(shí)例化的model指定train/eval。
model.eval()時,框架會自動把BN和Dropout固定住,不會取平均,而是用訓(xùn)練好的值,
不然的話,一旦test的batch_size過小,很容易就會被BN層導(dǎo)致生成圖片顏色失真極大?。。。。?!
總結(jié)
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python全面解析json數(shù)據(jù)并保存為csv文件
這篇文章主要介紹了Python全面解析json數(shù)據(jù)并保存為csv文件,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-07-07Python pandas 的索引方式 data.loc[],data[][]示例詳解
這篇文章主要介紹了Python pandas 的索引方式 data.loc[], data[][]的相關(guān)資料,其中data.loc[index,column]使用.loc[ ]第一個參數(shù)是行索引,第二個參數(shù)是列索引,本文結(jié)合實(shí)例代碼講解的非常詳細(xì),需要的朋友可以參考下2023-02-02Python實(shí)現(xiàn)刪除list列表重復(fù)元素的方法總結(jié)
在Python編程中,我們經(jīng)常需要處理列表中的重復(fù)元素,這篇文章為大家介紹了五種高效的方法來刪除列表中的重復(fù)元素,希望對大家有所幫助2023-07-07Python列表排序方法reverse、sort、sorted詳解
這篇文章主要介紹了Python列表排序方法reverse、sort、sorted詳解,需要的朋友可以參考下2021-04-04