Pytorch中的model.train()?和?model.eval()?原理與用法解析
Pytorch中的model.train() 和 model.eval() 原理與用法
一、兩種模式
pytorch可以給我們提供兩種方式來(lái)切換訓(xùn)練和評(píng)估(推斷)的模式,分別是:model.train()
和 model.eval()
。
一般用法是:在訓(xùn)練開(kāi)始之前寫(xiě)上 model.trian() ,在測(cè)試時(shí)寫(xiě)上 model.eval() 。
二、功能
1. model.train()
在使用 pytorch 構(gòu)建神經(jīng)網(wǎng)絡(luò)的時(shí)候,訓(xùn)練過(guò)程中會(huì)在程序上方添加一句model.train(),作用是 啟用 batch normalization 和 dropout 。
如果模型中有BN層(Batch Normalization)和 Dropout ,需要在 訓(xùn)練時(shí) 添加 model.train()。
model.train() 是保證 BN 層能夠用到 每一批數(shù)據(jù) 的均值和方差。對(duì)于 Dropout,model.train() 是 隨機(jī)取一部分 網(wǎng)絡(luò)連接來(lái)訓(xùn)練更新參數(shù)。
2. model.eval()
model.eval()的作用是 不啟用 Batch Normalization 和 Dropout。
如果模型中有 BN 層(Batch Normalization)和 Dropout,在 測(cè)試時(shí) 添加 model.eval()。
model.eval() 是保證 BN 層能夠用 全部訓(xùn)練數(shù)據(jù) 的均值和方差,即測(cè)試過(guò)程中要保證 BN 層的均值和方差不變。對(duì)于 Dropout,model.eval() 是利用到了 所有 網(wǎng)絡(luò)連接,即不進(jìn)行隨機(jī)舍棄神經(jīng)元。
為什么測(cè)試時(shí)要用 model.eval() ?
訓(xùn)練完 train 樣本后,生成的模型 model 要用來(lái)測(cè)試樣本了。在 model(test) 之前,需要加上model.eval(),否則的話,有輸入數(shù)據(jù),即使不訓(xùn)練,它也會(huì)改變權(quán)值。這是 model 中含有 BN 層和 Dropout 所帶來(lái)的的性質(zhì)。
eval() 時(shí),pytorch 會(huì)自動(dòng)把 BN 和 DropOut 固定住,不會(huì)取平均,而是用訓(xùn)練好的值。
不然的話,一旦 test 的 batch_size 過(guò)小,很容易就會(huì)被 BN 層導(dǎo)致生成圖片顏色失真極大。
eval() 在非訓(xùn)練的時(shí)候是需要加的,沒(méi)有這句代碼,一些網(wǎng)絡(luò)層的值會(huì)發(fā)生變動(dòng),不會(huì)固定,你神經(jīng)網(wǎng)絡(luò)每一次生成的結(jié)果也是不固定的,生成質(zhì)量可能好也可能不好。
也就是說(shuō),測(cè)試過(guò)程中使用model.eval(),這時(shí)神經(jīng)網(wǎng)絡(luò)會(huì) 沿用 batch normalization 的值,而并 不使用 dropout。
3. 總結(jié)與對(duì)比
如果模型中有 BN 層(Batch Normalization)和 Dropout,需要在訓(xùn)練時(shí)添加 model.train(),在測(cè)試時(shí)添加 model.eval()。
其中 model.train() 是保證 BN 層用每一批數(shù)據(jù)的均值和方差,而 model.eval() 是保證 BN 用全部訓(xùn)練數(shù)據(jù)的均值和方差;
而對(duì)于 Dropout,model.train() 是隨機(jī)取一部分網(wǎng)絡(luò)連接來(lái)訓(xùn)練更新參數(shù),而 model.eval() 是利用到了所有網(wǎng)絡(luò)連接。
三、Dropout 簡(jiǎn)介
dropout 常常用于抑制過(guò)擬合。
設(shè)置Dropout時(shí),torch.nn.Dropout(0.5),這里的 0.5 是指該層(layer)的神經(jīng)元在每次迭代訓(xùn)練時(shí)會(huì)隨機(jī)有 50% 的可能性被丟棄(失活),不參與訓(xùn)練。也就是將上一層數(shù)據(jù)減少一半傳播。
參考鏈接
- PyTorch中train()方法的作用是什么
- 【pytorch】model.train()和model.evel()的用法
- pytorch中net.eval() 和net.train()的使用
- Pytorch學(xué)習(xí)筆記11----model.train()與model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函數(shù)、nn.Linear淺析、輸出整個(gè)tensor的方法
- 好文:Pytorch:model.train()和model.eval()用法和區(qū)別,以及model.eval()和torch.no_grad()的區(qū)別
補(bǔ)充:pytroch:model.train()、model.eval()的使用
前言:最近在把兩個(gè)模型的代碼整合到一起,發(fā)現(xiàn)有一個(gè)模型的代碼整合后性能大不如前,但基本上是源碼遷移,找了一天原因才發(fā)現(xiàn)是因?yàn)閙odel.eval()和model.train()放錯(cuò)了位置?。?!故在此介紹一下pytroch框架下model.train()、model.eval()的作用和不同點(diǎn)。
一、model.train、model.eval
1.model.train和model.eval放在代碼什么位置
簡(jiǎn)單的說(shuō):
model.train
放在網(wǎng)絡(luò)訓(xùn)練前,model.eval
放在網(wǎng)絡(luò)測(cè)試前。
常見(jiàn)的位置擺放錯(cuò)誤(也是我犯的錯(cuò)誤)有把model.train()
放在for epoch in range(epoch):
前面,同時(shí)在test或者val(測(cè)試或者評(píng)估函數(shù))中只放置model.eval
,這就導(dǎo)致了只有第一個(gè)epoch模型訓(xùn)練是使用了model.train()
,之后的epoch模型訓(xùn)練時(shí)都采用model.eval()
.可能會(huì)影響訓(xùn)練好模型的性能。
修改方式:可以在test函數(shù)里return前面添加model.train()
或者把model.train()
放到for epoch in range(epoch):
語(yǔ)句下面。
model.train() for epoch in range(epoch): for train_batch in train_loader: ... zhibiao = test(epoch, test_loader, model) def test(epoch, test_loader, model): model.eval() for test_batch in test_loader: ... return zhibiao
2.model.train和model.eval有什么作用
model.train()和model.eval()的區(qū)別主要在于Batch Normalization和Dropout兩層。
如果模型中有BN層(Batch Normalization)和Dropout,在測(cè)試時(shí)添加model.eval()。model.eval()是保證BN層能夠用全部訓(xùn)練數(shù)據(jù)的均值和方差,即測(cè)試過(guò)程中要保證BN層的均值和方差不變。對(duì)于Dropout,model.eval()是利用到了所有網(wǎng)絡(luò)連接,即不進(jìn)行隨機(jī)舍棄神經(jīng)元。
下面是model.train 和model.eval的源碼,可以看到是利用self.training = mode
來(lái)判斷是使用train還是eval。這個(gè)參數(shù)將傳遞到一些常用層,比如dropout、BN層等。
def train(self: T, mode: bool = True) -> T: r"""Sets the module in training mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. Args: mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``. Returns: Module: self """ self.training = mode for module in self.children(): module.train(mode) return self def eval(self: T) -> T: r"""Sets the module in evaluation mode. This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc. This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`. Returns: Module: self """ return self.train(False)
拿dropout層的源碼舉例,可以看到傳遞了self.training這個(gè)參數(shù)。
class Dropout(_DropoutNd): r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p` using samples from a Bernoulli distribution. Each channel will be zeroed out independently on every forward call. This has proven to be an effective technique for regularization and preventing the co-adaptation of neurons as described in the paper `Improving neural networks by preventing co-adaptation of feature detectors`_ . Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during training. This means that during evaluation the module simply computes an identity function. Args: p: probability of an element to be zeroed. Default: 0.5 inplace: If set to ``True``, will do this operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`. Input can be of any shape - Output: :math:`(*)`. Output is of the same shape as input Examples:: >>> m = nn.Dropout(p=0.2) >>> input = torch.randn(20, 16) >>> output = m(input) .. _Improving neural networks by preventing co-adaptation of feature detectors: https://arxiv.org/abs/1207.0580 """ def forward(self, input: Tensor) -> Tensor: return F.dropout(input, self.p, self.training, self.inplace)
3.為什么主要區(qū)別在于BN層和dropout層
在BN層中,主要涉及到四個(gè)需要更新的參數(shù),分別是running_mean,running_var,weight,bias。這里的weight,bias是Pytorch官方實(shí)現(xiàn)中的叫法,有點(diǎn)誤導(dǎo)人,其實(shí)weight就是gamma,bias就是beta。當(dāng)然它這樣的叫法也符合實(shí)際的應(yīng)用場(chǎng)景。其實(shí)gamma,beta就是對(duì)規(guī)范化后的值進(jìn)行一個(gè)加權(quán)求和操作running_mean,running_var是當(dāng)前所求得的所有batch_size下的均值和方差,每經(jīng)過(guò)一個(gè)mini_batch我們都會(huì)更新running_mean,running_var.為什么要更新它?因?yàn)闇y(cè)試的時(shí)候,往往是一個(gè)一個(gè)的圖像feed至網(wǎng)絡(luò)的,如果你在這里對(duì)其進(jìn)行計(jì)算均值方差顯然是不合理的,所以model.eval()這個(gè)語(yǔ)句就是控制BN層中的running_mean,running_std不更新。采用訓(xùn)練結(jié)束后的running_mean,running_std來(lái)規(guī)范化該張圖像。
dropout層在訓(xùn)練過(guò)程中會(huì)隨機(jī)舍棄一些神經(jīng)元用來(lái)提高性能,但測(cè)試過(guò)程中如果還是測(cè)試的模型還是和訓(xùn)練時(shí)一樣隨機(jī)舍棄了一些神經(jīng)元(不是原模型)這就和測(cè)試的本意相違背。因?yàn)闇y(cè)試的模型應(yīng)該是我們最終得到的模型,而這個(gè)模型應(yīng)該是一個(gè)完整的模型。
4.BN層和dropout層的作用
既然都講到這了,不了解一些BN層和dropout層的作用就說(shuō)不過(guò)去了。
BN層的原理和作用建議讀一下這篇博客:神經(jīng)網(wǎng)絡(luò)中BN層的原理與作用
dropout是指在深度學(xué)習(xí)網(wǎng)絡(luò)的訓(xùn)練過(guò)程中,對(duì)于神經(jīng)網(wǎng)絡(luò)單元,按照一定的概率將其暫時(shí)從網(wǎng)絡(luò)中丟棄。注意是暫時(shí),對(duì)于隨機(jī)梯度下降來(lái)說(shuō),由于是隨機(jī)丟棄,故而每一個(gè)mini-batch都在訓(xùn)練不同的網(wǎng)絡(luò)。
大規(guī)模的神經(jīng)網(wǎng)絡(luò)有兩個(gè)缺點(diǎn):費(fèi)時(shí)、容易過(guò)擬合
Dropout的出現(xiàn)很好的可以解決這個(gè)問(wèn)題,每次做完dropout,相當(dāng)于從原始的網(wǎng)絡(luò)中找到一個(gè)更瘦的網(wǎng)絡(luò)。因而,對(duì)于一個(gè)有N個(gè)節(jié)點(diǎn)的神經(jīng)網(wǎng)絡(luò),有了dropout后,就可以看做是2^n個(gè)模型的集合了,但此時(shí)要訓(xùn)練的參數(shù)數(shù)目卻是不變的,這就解決了費(fèi)時(shí)的問(wèn)題。
將dropout比作是有性繁殖,將基因隨機(jī)進(jìn)行拆分,可以將優(yōu)秀的基因傳下來(lái),并且降低基因之間的聯(lián)合適應(yīng)性,使得復(fù)雜的大段大段基因聯(lián)合適應(yīng)性變成比較小的一個(gè)一個(gè)小段基因的聯(lián)合適應(yīng)性。
dropout也能達(dá)到同樣的效果,它強(qiáng)迫一個(gè)神經(jīng)單元,和隨機(jī)挑選出來(lái)的其他神經(jīng)單元共同工作,達(dá)到好的效果。消除減弱了神經(jīng)元節(jié)點(diǎn)間的聯(lián)合適應(yīng)性,增強(qiáng)了泛化能力。
參考鏈接
pytorch中model.train()和model.eval()的區(qū)別
BN層(Pytorch)
神經(jīng)網(wǎng)絡(luò)中BN層的原理與作用————這篇博客寫(xiě)的賊棒
深度學(xué)習(xí)中Dropout的作用和原理
pytorch之model.train()和model.eval()
概要
使用PyTorch進(jìn)行訓(xùn)練和測(cè)試時(shí)一定注意要把實(shí)例化的model指定train/eval
eval()
時(shí),框架會(huì)自動(dòng)把 BN
和 DropOut
固定住,不會(huì)取平均,而是用訓(xùn)練好的值,不然的話,一旦test的batch_size過(guò)小,很容易就會(huì)被BN層導(dǎo)致生成圖片顏色失真極大!
model.train()
啟用 BatchNormalization
和 Dropout
model.eval()
不啟用 BatchNormalization
和 Dropout
訓(xùn)練完train樣本后,生成的模型model要用來(lái)測(cè)試樣本。在model(test)
之前,需要加上model.eval()
,否則的話,有輸入數(shù)據(jù),即使不訓(xùn)練,它也會(huì)改變權(quán)值。這是model中含有batch normalization
層所帶來(lái)的的性質(zhì)。
Batch Normalization
BN的作用主要是對(duì)網(wǎng)絡(luò)中間的每層進(jìn)行歸一化處理,保證每層提取的特征分布不會(huì)被破壞。
訓(xùn)練時(shí)是針對(duì)每個(gè)mini-batch的,但是測(cè)試是針對(duì)單張圖片的,即不存在batch的概念。由于網(wǎng)絡(luò)訓(xùn)練完成后參數(shù)是固定的,因此每個(gè)batch的均值和方差是不變的.
Dropout
Dropout能夠克服Overfitting,在每個(gè)訓(xùn)練批次中,通過(guò)忽略一半的特征檢測(cè)器,可以明顯的減少過(guò)擬合現(xiàn)象。詳細(xì)見(jiàn)文章:《Dropout: A Simple Way to Prevent Neural Networks from Overtting》
總結(jié)
如果模型中有BN層(Batch Normalization)和Dropout,需要在訓(xùn)練時(shí)添加model.train(),在測(cè)試時(shí)添加model.eval()。
其中model.train()是保證BN層用每一批數(shù)據(jù)的均值和方差,而model.eval()是保證BN用全部訓(xùn)練數(shù)據(jù)的均值和方差;
而對(duì)于Dropout,model.train()是隨機(jī)取一部分網(wǎng)絡(luò)連接來(lái)訓(xùn)練更新參數(shù),而model.eval()是利用到了所有網(wǎng)絡(luò)連接。
到此這篇關(guān)于Pytorch中的model.train() 和 model.eval() 原理與用法的文章就介紹到這了,更多相關(guān)Pytorch model.train() 和 model.eval()內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python網(wǎng)絡(luò)編程之HTTP協(xié)議的python應(yīng)用
HTTP是在網(wǎng)絡(luò)上傳輸HTML的協(xié)議,用于瀏覽器和服務(wù)器的通信,這篇文章主要介紹了Python網(wǎng)絡(luò)編程之HTTP協(xié)議的python應(yīng)用,需要的朋友可以參考下2022-11-11python調(diào)用機(jī)器喇叭發(fā)出蜂鳴聲(Beep)的方法
這篇文章主要介紹了python調(diào)用機(jī)器喇叭發(fā)出蜂鳴聲(Beep)的方法,實(shí)例分析了Python調(diào)用winsound模塊的使用技巧,需要的朋友可以參考下2015-03-03python打包exe文件并隱藏執(zhí)行CMD命令窗口問(wèn)題
這篇文章主要介紹了python打包exe文件并隱藏執(zhí)行CMD命令窗口問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-01-01使用OpenCV實(shí)現(xiàn)仿射變換—旋轉(zhuǎn)功能
這篇文章主要介紹了在OpenCV里實(shí)現(xiàn)仿射變換——旋轉(zhuǎn)功能,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-08-08詳解Python如何利用Pandas與NumPy進(jìn)行數(shù)據(jù)清洗
許多數(shù)據(jù)科學(xué)家認(rèn)為獲取和清理數(shù)據(jù)的初始步驟占工作的 80%,花費(fèi)大量時(shí)間來(lái)清理數(shù)據(jù)集并將它們歸結(jié)為可以使用的形式。本文將利用 Python 的 Pandas和 NumPy 庫(kù)來(lái)清理數(shù)據(jù),需要的可以參考一下2022-04-04NumPy數(shù)組排序、過(guò)濾與隨機(jī)數(shù)生成詳解
這篇文章主要詳細(xì)給大家介紹了NumPy數(shù)組排序、過(guò)濾與隨機(jī)數(shù)生成,文中通過(guò)代碼示例給大家講解的非常詳細(xì),對(duì)大家學(xué)習(xí)NumPy有一定的幫助,需要的朋友可以參考下2024-05-05python字符串string的內(nèi)置方法實(shí)例詳解
這篇文章主要介紹了python字符串string的內(nèi)置方法,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2018-05-05