使用with torch.no_grad():顯著減少測(cè)試時(shí)顯存占用
with torch.no_grad():顯著減少測(cè)試時(shí)顯存占用
問題描述
將訓(xùn)練好的模型拿來做inference,發(fā)現(xiàn)顯存被占滿,無法進(jìn)行后續(xù)操作,但按理說不應(yīng)該出現(xiàn)這種情況。
RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 7.93 GiB total capacity; 6.94 GiB already allocated; 10.56 MiB free; 7.28 GiB reserved in total by PyTorch)
解決方案
經(jīng)過排查代碼,發(fā)現(xiàn)做inference時(shí),各模型雖然已經(jīng)設(shè)置為eval()模式,但是并沒有取消網(wǎng)絡(luò)生成計(jì)算圖這一操作,這就導(dǎo)致網(wǎng)絡(luò)在單純做前向傳播時(shí)也生成了計(jì)算圖,從而消耗了大量顯存。
所以,將模型前向傳播的代碼放到with torch.no_grad()下,就能使pytorch不生成計(jì)算圖,從而節(jié)省不少顯存
with torch.no_grad(): # 代碼塊 outputs = model(inputs) # 代碼塊
經(jīng)過修改,再進(jìn)行inference就沒有遇到顯存不夠的情況了。
此時(shí)顯存占用顯著降低,只占用5600MB左右(3卡)。
model.eval()和torch.no_grad()
model.eval()
- 使用model.eval()切換到測(cè)試模式,不會(huì)更新模型的k,b參數(shù)
- 通知dropout層和batchnorm層在train和val中間進(jìn)行切換在。train模式,dropout層會(huì)按照設(shè)定的參數(shù)p設(shè)置保留激活單元的概率(保留概率=p,比如keep_prob=0.8),batchnorm層會(huì)繼續(xù)計(jì)算數(shù)據(jù)的mean和var并進(jìn)行更新。在val模式下,dropout層會(huì)讓所有的激活單元都通過,而batchnorm層會(huì)停止計(jì)算和更新mean和var,直接使用在訓(xùn)練階段已經(jīng)學(xué)出的mean和var值
- model.eval()不會(huì)影響各層的gradient計(jì)算行為,即gradient計(jì)算和存儲(chǔ)與training模式一樣,只是不進(jìn)行反向傳播(backprobagation),即只設(shè)置了model.eval()pytorch依舊會(huì)生成計(jì)算圖,占用顯存,只是不使用計(jì)算圖來進(jìn)行反向傳播。
torch.no_grad()
首先從requires_grad講起:
requires_grad
- 在pytorch中,tensor有一個(gè)requires_grad參數(shù),如果設(shè)置為True,則反向傳播時(shí),該tensor就會(huì)自動(dòng)求導(dǎo),并且保存在計(jì)算圖中。tensor的requires_grad的屬性默認(rèn)為False,若一個(gè)節(jié)點(diǎn)(葉子變量:自己創(chuàng)建的tensor)requires_grad被設(shè)置為True,那么所有依賴它的節(jié)點(diǎn)requires_grad都為True(即使其他相依賴的tensor的requires_grad = False)
- 當(dāng)requires_grad設(shè)置為False時(shí),反向傳播時(shí)就不會(huì)自動(dòng)求導(dǎo)了,也就不會(huì)生成計(jì)算圖,而GPU也不用再保存計(jì)算圖,因此大大節(jié)約了顯存或者說內(nèi)存。
with torch.no_grad
- 在該模塊下,所有計(jì)算得出的tensor的requires_grad都自動(dòng)設(shè)置為False。
- 即使一個(gè)tensor(命名為x)的requires_grad = True,在with torch.no_grad計(jì)算,由x得到的新tensor(命名為w-標(biāo)量)requires_grad也為False,且grad_fn也為None,即不會(huì)對(duì)w求導(dǎo)。
例子如下所示:
x = torch.randn(10, 5, requires_grad = True) y = torch.randn(10, 5, requires_grad = True) z = torch.randn(10, 5, requires_grad = True) with torch.no_grad(): w = x + y + z print(w.requires_grad) print(w.grad_fn) print(w.requires_grad) False None False
也就是說,在with torch.no_grad結(jié)構(gòu)中的所有tensor的requires_grad屬性會(huì)被強(qiáng)行設(shè)置為false,如果前向傳播過程在該結(jié)構(gòu)中,那么inference過程中都不會(huì)產(chǎn)生計(jì)算圖,從而節(jié)省不少顯存。
版本問題
問題描述
volatile was removed and now has no effect. Use with torch.no_grad(): instead
源代碼
captions = Variable(torch.from_numpy(captions), volatile=True)
原因
1.在torch版本中volatile已經(jīng)被移除。在pytorch 0.4.0之前 input= Variable(input, volatile=True) 設(shè)置volatile為True ,只要是一個(gè)輸入為volatile,則輸出也是volatile的,它能夠保證不存在中間狀態(tài);但是在pytorch 0.4.0之后取消了volatile的機(jī)制,被替換成torch.no_grad()函數(shù)
2.torch.no_grad() 是一個(gè)上下文管理器。在使用pytorch時(shí),并不是所有的操作都需要進(jìn)行計(jì)算圖的生成(計(jì)算過程的構(gòu)建,以便梯度反向傳播等操作)。而對(duì)于tensor的計(jì)算操作,默認(rèn)是要進(jìn)行計(jì)算圖的構(gòu)建的,在這種情況下,可以使用 with torch.no_grad():,強(qiáng)制之后的內(nèi)容不進(jìn)行計(jì)算圖構(gòu)建。在torch.no_grad() 會(huì)影響pytorch的反向傳播機(jī)制,在測(cè)試時(shí)因?yàn)榇_定不會(huì)使用到反向傳播因此 這種模式可以幫助節(jié)省內(nèi)存空間。同理對(duì)于 torch.set_grad_enable(grad_mode)也是這樣
with torch.no_grad()解答
with torch.no_grad()簡(jiǎn)述及例子
torch.no_grad()是PyTorch中的一個(gè)上下文管理器(context manager),用于指定在其內(nèi)部的代碼塊中不進(jìn)行梯度計(jì)算。當(dāng)你不需要計(jì)算梯度時(shí),可以使用該上下文管理器來提高代碼的執(zhí)行效率,尤其是在推斷(inference)階段和梯度裁剪(grad clip)階段的時(shí)候。
在使用torch.autograd進(jìn)行自動(dòng)求導(dǎo)時(shí),PyTorch會(huì)默認(rèn)跟蹤并計(jì)算張量的梯度。然而,有時(shí)我們只關(guān)心前向傳播的結(jié)果,而不需要計(jì)算梯度,這時(shí)就可以使用torch.no_grad()來關(guān)閉自動(dòng)求導(dǎo)功能。
在torch.no_grad()的上下文中執(zhí)行的張量運(yùn)算不會(huì)被跟蹤,也不會(huì)產(chǎn)生梯度信息,從而提高計(jì)算效率并節(jié)省內(nèi)存。
下面舉例一個(gè)在關(guān)閉梯度跟蹤torch.no_grad()后仍然要更新梯度矩陣y.backward()的錯(cuò)誤例子:
import torch # 創(chuàng)建兩個(gè)張量 x = torch.tensor([2.0], requires_grad=True) w = torch.tensor([3.0], requires_grad=True) # 在計(jì)算階段使用 torch.no_grad() with torch.no_grad(): ? ? y = x * w # 輸出結(jié)果,不會(huì)計(jì)算梯度 print(y) ?# tensor([6.]) # 嘗試對(duì) y 進(jìn)行反向傳播(會(huì)報(bào)錯(cuò)) y.backward() ?# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
在上面的例子中,我們通過將x和w張量的requires_grad屬性設(shè)置為True,表示我們希望計(jì)算它們的梯度。然而,在torch.no_grad()的上下文中,對(duì)于y的計(jì)算不會(huì)被跟蹤,也不會(huì)生成梯度信息。因此,在執(zhí)行y.backward()時(shí)會(huì)報(bào)錯(cuò)。
with torch.no_grad()在訓(xùn)練階段使用
with torch.no_grad()常見于eval()驗(yàn)證集和測(cè)試集中,但是有時(shí)候我們?nèi)匀粫?huì)在train()訓(xùn)練集中看到,如下:
@d2l.add_to_class(d2l.Trainer) ?#@save def prepare_batch(self, batch): ? ? return batch @d2l.add_to_class(d2l.Trainer) ?#@save def fit_epoch(self): ? ? self.model.train() ? ? for batch in self.train_dataloader: ? ? ? ? loss = self.model.training_step(self.prepare_batch(batch)) ? ? ? ? self.optim.zero_grad() ? ? ? ? with torch.no_grad(): ? ? ? ? ? ? loss.backward() ? ? ? ? ? ? if self.gradient_clip_val > 0: ?# To be discussed later ? ? ? ? ? ? ? ? self.clip_gradients(self.gradient_clip_val, self.model) ? ? ? ? ? ? self.optim.step() ? ? ? ? self.train_batch_idx += 1 ? ? if self.val_dataloader is None: ? ? ? ? return ? ? self.model.eval() ? ? for batch in self.val_dataloader: ? ? ? ? with torch.no_grad(): ? ? ? ? ? ? self.model.validation_step(self.prepare_batch(batch)) ? ? ? ? self.val_batch_idx += 1
這是因?yàn)槲覀冞M(jìn)行了梯度裁剪,在上述代碼中,torch.no_grad()的作用是在計(jì)算梯度之前執(zhí)行梯度裁剪操作。loss.backward()會(huì)計(jì)算損失的梯度,但在這個(gè)特定的上下文中,我們不希望梯度裁剪的操作被跟蹤和計(jì)算梯度。因此,我們使用torch.no_grad()將裁剪操作放在一個(gè)沒有梯度跟蹤的上下文中,以避免計(jì)算和存儲(chǔ)與梯度裁剪無關(guān)的梯度信息。
而梯度的記錄和跟蹤實(shí)際上已經(jīng)在loss = self.model.training_step(self.prepare_batch(batch))中完成了(類似output = model(input)),而loss.backward()只是計(jì)算梯度并更新了model的梯度矩陣。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- PyTorch中torch.load()的用法和應(yīng)用
- python中torch.load中的map_location參數(shù)使用
- Pytorch中的torch.nn.Linear()方法用法解讀
- Pytorch中的torch.where函數(shù)使用
- python中的List sort()與torch.sort()
- 關(guān)于torch.scatter與torch_scatter庫的使用整理
- PyTorch函數(shù)torch.cat與torch.stac的區(qū)別小結(jié)
- pytorch.range()和pytorch.arange()的區(qū)別及說明
- PyTorch中torch.save()的用法和應(yīng)用小結(jié)
相關(guān)文章
python數(shù)據(jù)清洗中的時(shí)間格式化實(shí)現(xiàn)
本文主要介紹了python數(shù)據(jù)清洗中的時(shí)間格式化實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-05-05Python實(shí)現(xiàn)購物評(píng)論文本情感分析操作【基于中文文本挖掘庫snownlp】
這篇文章主要介紹了Python實(shí)現(xiàn)購物評(píng)論文本情感分析操作,結(jié)合實(shí)例形式分析了Python使用中文文本挖掘庫snownlp操作中文文本進(jìn)行感情分析的相關(guān)實(shí)現(xiàn)技巧與注意事項(xiàng),需要的朋友可以參考下2018-08-08python實(shí)現(xiàn)定時(shí)同步本機(jī)與北京時(shí)間的方法
這篇文章主要介紹了python實(shí)現(xiàn)定時(shí)同步本機(jī)與北京時(shí)間的方法,涉及Python針對(duì)時(shí)間的操作技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-03-03教你用Python為二年級(jí)的學(xué)生批量生成數(shù)學(xué)題
這兩天在學(xué)習(xí)pthon,正好遇到老師布置的暑假作業(yè),需要家長(zhǎng)給還在出試卷,下面這篇文章主要給大家介紹了關(guān)于如何用Python為二年級(jí)的學(xué)生批量生成數(shù)學(xué)題的相關(guān)資料,需要的朋友可以參考下2023-02-02Django中日期處理注意事項(xiàng)與自定義時(shí)間格式轉(zhuǎn)換詳解
這篇文章主要給大家介紹了關(guān)于Django中日期處理注意事項(xiàng)與自定義時(shí)間格式轉(zhuǎn)換的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2018-08-08使用Python模擬操作windows應(yīng)用窗口詳解
在日常工作中,我們經(jīng)常遇到需要進(jìn)行大量重復(fù)性任務(wù)的情況,這篇文章將介紹如何使用 Python 模擬操作記事本,感興趣的小伙伴可以了解下2025-02-02