欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

聊聊PyTorch中eval和no_grad的關系

 更新時間:2021年05月12日 08:32:34   作者:yanxiangtianji  
這篇文章主要介紹了聊聊PyTorch中eval和no_grad的關系,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

首先這兩者有著本質上區(qū)別

model.eval()是用來告知model內的各個layer采取eval模式工作。這個操作主要是應對諸如dropout和batchnorm這些在訓練模式下需要采取不同操作的特殊layer。訓練和測試的時候都可以開啟。

torch.no_grad()則是告知自動求導引擎不要進行求導操作。這個操作的意義在于加速計算、節(jié)約內存。但是由于沒有gradient,也就沒有辦法進行backward。所以只能在測試的時候開啟。

所以在evaluate的時候,需要同時使用兩者。

model = ...
dataset = ...
loss_fun = ...

# training
lr=0.001
model.train()
for x,y in dataset:
 model.zero_grad()
 p = model(x)
 l = loss_fun(p, y)
 l.backward()
 for p in model.parameters():
  p.data -= lr*p.grad
 
# evaluating
sum_loss = 0.0
model.eval()
with torch.no_grad():
 for x,y in dataset:
  p = model(x)
  l = loss_fun(p, y)
  sum_loss += l
print('total loss:', sum_loss)

另外no_grad還可以作為函數是修飾符來用,從而簡化代碼。

def train(model, dataset, loss_fun, lr=0.001):
 model.train()
 for x,y in dataset:
  model.zero_grad()
  p = model(x)
  l = loss_fun(p, y)
  l.backward()
  for p in model.parameters():
   p.data -= lr*p.grad
 
@torch.no_grad()
def test(model, dataset, loss_fun):
 sum_loss = 0.0
 model.eval()
 for x,y in dataset:
  p = model(x)
  l = loss_fun(p, y)
  sum_loss += l
 return sum_loss

# main block:
model = ...
dataset = ...
loss_fun = ...

# training
train()
# test
sum_loss = test()
print('total loss:', sum_loss)

補充:pytorch中model.train、model.eval以及torch.no_grad的用法

1、model.train()

啟用 BatchNormalization 和 Dropout

model.train() 讓model變成訓練模式,此時 dropout和batch normalization的操作在訓練起到防止網絡過擬合的問題

2、model.eval()

不啟用 BatchNormalization 和 Dropout

model.eval(),pytorch會自動把BN和DropOut固定住,而用訓練好的值。不然的話,一旦test的batch_size過小,很容易就會被BN層導致所生成圖片顏色失真極大

訓練完train樣本后,生成的模型model要用來測試樣本。在model(test)之前,需要加上model.eval(),否則的話,有輸入數據,即使不訓練,它也會改變權值。這是model中含有batch normalization層所帶來的的性質。

對于在訓練和測試時為什么要這樣做,可以從下面兩段話理解:

在訓練的時候, 會計算一個batch內的mean 和var, 但是因為是小batch小batch的訓練的,所以會采用加權或者動量的形式來將每個batch的 mean和var來累加起來,也就是說再算當前的batch的時候,其實當前的權重只是占了0.1, 之前所有訓練過的占了0.9的權重,這樣做的好處是不至于因為某一個batch太過奇葩而導致的訓練不穩(wěn)定。

好,現在假設訓練完成了, 那么在整個訓練集上面也得到了一個最終的”mean 和var”, BN層里面的參數也學習完了(如果指定學習的話),而現在需要測試了,測試的時候往往會一張圖一張圖的去測,這時候沒有batch而言了,對單獨一個數據做 mean和var是沒有意義的, 那么怎么辦,實際上在測試的時候BN里面用的mean和var就是訓練結束后的mean_final 和 val_final. 也可說是在測試的時候BN就是一個變換。所以在用pytorch的時候要注意這一點,在訓練之前要有model.train() 來告訴網絡現在開啟了訓練模式,在eval的時候要用”model.eval()”, 用來告訴網絡現在要進入測試模式了.因為這兩種模式下BN的作用是不同的。

3、torch.no_grad()

這條語句的作用是:在測試時不進行梯度的計算,這樣可以在測試時有效減小顯存的占用,以免發(fā)生顯存溢出(OOM)。

這條語句通常加在網絡預測的那條代碼上。

4、pytorch中model.eval()和“with torch.no_grad()區(qū)別

兩者區(qū)別

在PyTorch中進行validation時,會使用model.eval()切換到測試模式,在該模式下,

主要用于通知dropout層和batchnorm層在train和val模式間切換

在train模式下,dropout網絡層會按照設定的參數p設置保留激活單元的概率(保留概率=p); batchnorm層會繼續(xù)計算數據的mean和var等參數并更新。

在val模式下,dropout層會讓所有的激活單元都通過,而batchnorm層會停止計算和更新mean和var,直接使用在訓練階段已經學出的mean和var值。

該模式不會影響各層的gradient計算行為,即gradient計算和存儲與training模式一樣,只是不進行反傳(backprobagation)

而with torch.zero_grad()則主要是用于停止autograd模塊的工作,以起到加速和節(jié)省顯存的作用,具體行為就是停止gradient計算,從而節(jié)省了GPU算力和顯存,但是并不會影響dropout和batchnorm層的行為。

使用場景

如果不在意顯存大小和計算時間的話,僅僅使用model.eval()已足夠得到正確的validation的結果;而with torch.zero_grad()則是更進一步加速和節(jié)省gpu空間(因為不用計算和存儲gradient),從而可以更快計算,也可以跑更大的batch來測試。

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。

相關文章

  • Python是編譯運行的驗證方法

    Python是編譯運行的驗證方法

    這篇文章主要介紹了Python是編譯運行的驗證方法,本文講解了一個小方法來驗證Python是編譯運行還是解釋運行,需要的朋友可以參考下
    2015-01-01
  • 關于pytorch訓練分類器

    關于pytorch訓練分類器

    這篇文章主要介紹了關于pytorch訓練分類器問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2023-09-09
  • 簡單了解django索引的相關知識

    簡單了解django索引的相關知識

    這篇文章主要介紹了簡單了解django索引的相關知識,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2019-07-07
  • 在win和Linux系統(tǒng)中python命令行運行的不同

    在win和Linux系統(tǒng)中python命令行運行的不同

    本文給大家分享的是作者在在win和Linux系統(tǒng)中python命令行運行的不同的解決方法,有相同需求的小伙伴可以參考下
    2016-07-07
  • pytorch 實現在預訓練模型的 input上增減通道

    pytorch 實現在預訓練模型的 input上增減通道

    今天小編就為大家分享一篇pytorch 實現在預訓練模型的 input上增減通道,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-01-01
  • 詳解python OpenCV學習筆記之直方圖均衡化

    詳解python OpenCV學習筆記之直方圖均衡化

    本篇文章主要介紹了詳解python OpenCV學習筆記之直方圖均衡化,小編覺得挺不錯的,現在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-02-02
  • python 爬取微信文章

    python 爬取微信文章

    本文給大家分享的是使用python通過搜狗入口,爬取微信文章的小程序,非常的簡單實用,有需要的小伙伴可以參考下
    2016-01-01
  • opencv+pyQt5實現圖片閾值編輯器/尋色塊閾值利器

    opencv+pyQt5實現圖片閾值編輯器/尋色塊閾值利器

    這篇文章主要介紹了opencv+pyQt5實現圖片閾值編輯器/尋色塊閾值利器,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2020-11-11
  • Python利用Flask-Mail實現發(fā)送郵件詳解

    Python利用Flask-Mail實現發(fā)送郵件詳解

    Flask?的擴展包?Flask?-?Mail?通過包裝了?Python?內置的smtplib包,可以用在?Flask?程序中發(fā)送郵件。本文將利用這特性實現郵件發(fā)送功能,感興趣的可以了解一下
    2022-08-08
  • python實現的各種排序算法代碼

    python實現的各種排序算法代碼

    python實現的各種排序算法,包括選擇排序、冒泡排序、插入排序、歸并排序等,學習python的朋友可以參考下
    2013-03-03

最新評論