Pytorch中的modle.train,model.eval,with torch.no_grad解讀
modle.train,model.eval,with torch.no_grad解讀
1. 最近在學(xué)習(xí)pytorch過程中遇到了幾個(gè)問題
不理解為什么在訓(xùn)練和測試函數(shù)中model.eval(),和model.train()的區(qū)別,經(jīng)查閱后做如下整理
一般情況下,我們訓(xùn)練過程如下:
拿到數(shù)據(jù)后進(jìn)行訓(xùn)練,在訓(xùn)練過程中,使用
model.train()
:告訴我們的網(wǎng)絡(luò),這個(gè)階段是用來訓(xùn)練的,可以更新參數(shù)。
訓(xùn)練完成后進(jìn)行預(yù)測,在預(yù)測過程中,使用
model.eval()
: 告訴我們的網(wǎng)絡(luò),這個(gè)階段是用來測試的,于是模型的參數(shù)在該階段不進(jìn)行更新。
2. 但是為什么在eval()階段會(huì)使用with torch.no_grad()?
查閱相關(guān)資料:傳送門
with torch.no_grad - disables tracking of gradients in autograd.
model.eval() changes the forward() behaviour of the module it is called upon
eg, it disables dropout and has batch norm use the entire population statistics
總結(jié)一下就是說,在eval階段了,即使不更新,但是在模型中所使用的dropout或者batch norm也就失效了,直接都會(huì)進(jìn)行預(yù)測,而使用no_grad則設(shè)置讓梯度Autograd設(shè)置為False(因?yàn)樵谟?xùn)練中我們默認(rèn)是True),這樣保證了反向過程為純粹的測試,而不變參數(shù)。
另外,參考文檔說這樣避免每一個(gè)參數(shù)都要設(shè)置,解放了GPU底層的時(shí)間開銷,在測試階段統(tǒng)一梯度設(shè)置為False
model.eval()與torch.no_grad()的作用
model.eval()
經(jīng)常在模型推理代碼的前面, 都會(huì)添加model.eval(), 主要有3個(gè)作用:
- 1.不進(jìn)行dropout
- 2.不更新batchnorm的mean 和var 參數(shù)
- 3.不進(jìn)行梯度反向傳播, 但梯度仍然會(huì)計(jì)算
torch.no_grad()
torch.no_grad的一般使用方法是, 在代碼塊外面用with torch.no_grad()給包起來。 如下面這樣:
with torch.no_grad(): ?? ?# your code?
它的主要作用有2個(gè):
- 1.不進(jìn)行梯度的計(jì)算(當(dāng)然也就沒辦法反向傳播了), 節(jié)約顯存和算力
- 2.dropout和batchnorn還是會(huì)正常更新
異同
從上面的介紹中可以非常明確的看出,它們的相同點(diǎn)是一般都用在推理階段, 但它們的作用是完全不同的, 也沒有重疊。 可以一起使用。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
深度解讀Python如何實(shí)現(xiàn)dbscan算法
DBScan?是密度基于空間聚類,它是一種基于密度的聚類算法,其與其他聚類算法(如K-Means)不同的是,它不需要事先知道簇的數(shù)量。本文就來帶大家了解一下Python是如何實(shí)現(xiàn)dbscan算法,感興趣的可以了解一下2023-02-02np.mean()和np.std()函數(shù)的具體使用
本文主要介紹了np.mean()和np.std()函數(shù)的具體使用,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03python中matplotlib實(shí)現(xiàn)最小二乘法擬合的過程詳解
這篇文章主要給大家介紹了關(guān)于python中matplotlib實(shí)現(xiàn)最小二乘法擬合的相關(guān)資料,文中通過示例代碼詳細(xì)介紹了關(guān)于最小二乘法擬合直線和最小二乘法擬合曲線的實(shí)現(xiàn)過程,需要的朋友可以參考借鑒,下面來一起看看吧。2017-07-07Python函數(shù)isalnum用法示例小結(jié)
isalnum()函數(shù)是Python中的一個(gè)內(nèi)置函數(shù),用于判斷字符串是否只由數(shù)字和字母組成,其內(nèi)部實(shí)現(xiàn)原理比較簡單,只需遍歷字符串中的每一個(gè)字符即可,這篇文章主要介紹了Python函數(shù)isalnum用法介紹,需要的朋友可以參考下2024-01-01Python內(nèi)置函數(shù)ord()的實(shí)現(xiàn)示例
ord()函數(shù)是用于返回字符的Unicode碼點(diǎn),適用于處理文本和國際化應(yīng)用,它只能處理單個(gè)字符,超過一字符或非字符串類型會(huì)引發(fā)TypeError,示例代碼展示了如何使用ord()進(jìn)行字符轉(zhuǎn)換和比較2024-09-09python實(shí)現(xiàn)飛機(jī)大戰(zhàn)小游戲
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)飛機(jī)大戰(zhàn)游戲,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-11-11Python的mysql數(shù)據(jù)庫的更新如何實(shí)現(xiàn)
這篇文章主要介紹了Python的mysql數(shù)據(jù)庫的更新如何實(shí)現(xiàn)的相關(guān)資料,這里提供實(shí)例代碼,幫助大家理解應(yīng)用這部分知識(shí),需要的朋友可以參考下2017-07-07