Pytorch中的modle.train,model.eval,with torch.no_grad解讀
modle.train,model.eval,with torch.no_grad解讀
1. 最近在學習pytorch過程中遇到了幾個問題
不理解為什么在訓練和測試函數(shù)中model.eval(),和model.train()的區(qū)別,經(jīng)查閱后做如下整理
一般情況下,我們訓練過程如下:
拿到數(shù)據(jù)后進行訓練,在訓練過程中,使用
model.train():告訴我們的網(wǎng)絡,這個階段是用來訓練的,可以更新參數(shù)。
訓練完成后進行預測,在預測過程中,使用
model.eval(): 告訴我們的網(wǎng)絡,這個階段是用來測試的,于是模型的參數(shù)在該階段不進行更新。
2. 但是為什么在eval()階段會使用with torch.no_grad()?
查閱相關(guān)資料:傳送門
with torch.no_grad - disables tracking of gradients in autograd.
model.eval() changes the forward() behaviour of the module it is called upon
eg, it disables dropout and has batch norm use the entire population statistics
總結(jié)一下就是說,在eval階段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都會進行預測,而使用no_grad則設置讓梯度Autograd設置為False(因為在訓練中我們默認是True),這樣保證了反向過程為純粹的測試,而不變參數(shù)。
另外,參考文檔說這樣避免每一個參數(shù)都要設置,解放了GPU底層的時間開銷,在測試階段統(tǒng)一梯度設置為False
model.eval()與torch.no_grad()的作用
model.eval()
經(jīng)常在模型推理代碼的前面, 都會添加model.eval(), 主要有3個作用:
- 1.不進行dropout
- 2.不更新batchnorm的mean 和var 參數(shù)
- 3.不進行梯度反向傳播, 但梯度仍然會計算
torch.no_grad()
torch.no_grad的一般使用方法是, 在代碼塊外面用with torch.no_grad()給包起來。 如下面這樣:
with torch.no_grad(): ?? ?# your code?
它的主要作用有2個:
- 1.不進行梯度的計算(當然也就沒辦法反向傳播了), 節(jié)約顯存和算力
- 2.dropout和batchnorn還是會正常更新
異同
從上面的介紹中可以非常明確的看出,它們的相同點是一般都用在推理階段, 但它們的作用是完全不同的, 也沒有重疊。 可以一起使用。
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
np.mean()和np.std()函數(shù)的具體使用
本文主要介紹了np.mean()和np.std()函數(shù)的具體使用,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-03-03
python中matplotlib實現(xiàn)最小二乘法擬合的過程詳解
這篇文章主要給大家介紹了關(guān)于python中matplotlib實現(xiàn)最小二乘法擬合的相關(guān)資料,文中通過示例代碼詳細介紹了關(guān)于最小二乘法擬合直線和最小二乘法擬合曲線的實現(xiàn)過程,需要的朋友可以參考借鑒,下面來一起看看吧。2017-07-07
Python函數(shù)isalnum用法示例小結(jié)
isalnum()函數(shù)是Python中的一個內(nèi)置函數(shù),用于判斷字符串是否只由數(shù)字和字母組成,其內(nèi)部實現(xiàn)原理比較簡單,只需遍歷字符串中的每一個字符即可,這篇文章主要介紹了Python函數(shù)isalnum用法介紹,需要的朋友可以參考下2024-01-01
Python內(nèi)置函數(shù)ord()的實現(xiàn)示例
ord()函數(shù)是用于返回字符的Unicode碼點,適用于處理文本和國際化應用,它只能處理單個字符,超過一字符或非字符串類型會引發(fā)TypeError,示例代碼展示了如何使用ord()進行字符轉(zhuǎn)換和比較2024-09-09
Python的mysql數(shù)據(jù)庫的更新如何實現(xiàn)
這篇文章主要介紹了Python的mysql數(shù)據(jù)庫的更新如何實現(xiàn)的相關(guān)資料,這里提供實例代碼,幫助大家理解應用這部分知識,需要的朋友可以參考下2017-07-07

