PyTorch中torch.no_grad()用法舉例詳解
前言
torch.no_grad() 是 PyTorch 中的一個(gè)上下文管理器,用于在上下文中臨時(shí)禁用自動(dòng)梯度計(jì)算。它在模型評(píng)估或推理階段非常有用,因?yàn)樵谶@些階段,我們通常不需要計(jì)算梯度。禁用梯度計(jì)算可以減少內(nèi)存消耗,并加快計(jì)算速度。
基本概念
在 PyTorch 中,每次對(duì) requires_grad=True 的張量進(jìn)行操作時(shí),PyTorch 會(huì)構(gòu)建一個(gè)計(jì)算圖(computation graph),用于計(jì)算反向傳播的梯度。這對(duì)訓(xùn)練模型是必要的,但在評(píng)估或推理時(shí)不需要。因此,我們可以使用 torch.no_grad() 來(lái)臨時(shí)禁用這些計(jì)算圖的構(gòu)建和梯度計(jì)算。
用法
torch.no_grad() 的使用非常簡(jiǎn)單。只需要將不需要梯度計(jì)算的代碼塊放在 with torch.no_grad(): 下即可。
示例代碼
以下是一個(gè)使用 torch.no_grad() 的示例:
import torch # 創(chuàng)建一個(gè)張量,并設(shè)置 requires_grad=True 以便記錄梯度 x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) # 在 torch.no_grad() 上下文中禁用梯度計(jì)算 with torch.no_grad(): y = x + 2 print(y) # 此時(shí),x 的 requires_grad 屬性仍然為 True,但 y 的 requires_grad 屬性為 False print("x 的 requires_grad:", x.requires_grad) print("y 的 requires_grad:", y.requires_grad)
詳細(xì)解釋
創(chuàng)建張量并設(shè)置 requires_grad=True:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
創(chuàng)建一個(gè)包含三個(gè)元素的張量 x。
設(shè)置 requires_grad=True,告訴 PyTorch 需要為該張量記錄梯度。
禁用梯度計(jì)算:
with torch.no_grad(): y = x + 2 print(y)
進(jìn)入 torch.no_grad() 上下文,臨時(shí)禁用梯度計(jì)算。
在上下文中,對(duì) x 進(jìn)行加法操作,得到新的張量 y。
打印 y,此時(shí) y 的 requires_grad 屬性為 False。
查看 requires_grad 屬性:
print("x 的 requires_grad:", x.requires_grad) print("y 的 requires_grad:", y.requires_grad)
打印 x 的 requires_grad 屬性,仍然為 True。
打印 y 的 requires_grad 屬性,已被禁用為 False。
使用場(chǎng)景
模型評(píng)估
在評(píng)估模型性能時(shí),不需要計(jì)算梯度。使用 torch.no_grad() 可以提高評(píng)估速度和減少內(nèi)存消耗。
model.eval() # 切換到評(píng)估模式 with torch.no_grad(): for data in validation_loader: outputs = model(data) # 計(jì)算評(píng)估指標(biāo)
模型推理
在部署和推理階段,只需要前向傳播,不需要反向傳播,因此可以使用 torch.no_grad()。
with torch.no_grad(): outputs = model(inputs) predicted = torch.argmax(outputs, dim=1)
初始化權(quán)重或其他不需要梯度的操作
在某些初始化或操作中,不需要梯度計(jì)算。
with torch.no_grad(): model.weight.fill_(1.0) # 直接修改權(quán)重
小結(jié)
torch.no_grad() 是一個(gè)用于禁用梯度計(jì)算的上下文管理器,適用于模型評(píng)估、推理等不需要梯度計(jì)算的場(chǎng)景。使用 torch.no_grad() 可以顯著減少內(nèi)存使用和加速計(jì)算。通過(guò)理解和合理使用 torch.no_grad(),可以使得模型評(píng)估和推理更加高效和穩(wěn)定。
額外注意事項(xiàng)
訓(xùn)練模式與評(píng)估模式:
在使用 torch.no_grad() 時(shí),通常還會(huì)將模型設(shè)置為評(píng)估模式(model.eval()),以確保某些層(如 dropout 和 batch normalization)在推理時(shí)的行為與訓(xùn)練時(shí)不同。
嵌套使用:
torch.no_grad() 可以嵌套使用,內(nèi)層的 torch.no_grad() 仍然會(huì)禁用梯度計(jì)算。
with torch.no_grad(): with torch.no_grad(): y = x + 2 print(y)
恢復(fù)梯度計(jì)算:
在 torch.no_grad() 上下文管理器退出后,梯度計(jì)算會(huì)自動(dòng)恢復(fù),不需要額外操作。
with torch.no_grad(): y = x + 2 print(y) # 這里梯度計(jì)算恢復(fù) z = x * 2 print(z.requires_grad) # True
通過(guò)合理使用 torch.no_grad(),可以在不需要梯度計(jì)算的場(chǎng)景中提升性能并節(jié)省資源。
總結(jié)
到此這篇關(guān)于PyTorch中torch.no_grad()用法舉例詳解的文章就介紹到這了,更多相關(guān)PyTorch torch.no_grad()詳解內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中4種實(shí)現(xiàn)數(shù)值的交換方式
這篇文章主要介紹了Python中4種實(shí)現(xiàn)數(shù)值的交換方式,文章圍繞主題展開(kāi)詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-08-08PyTorch?之?強(qiáng)大的?hub?模塊和搭建神經(jīng)網(wǎng)絡(luò)進(jìn)行氣溫預(yù)測(cè)
hub 模塊是調(diào)用別人訓(xùn)練好的網(wǎng)絡(luò)架構(gòu)以及訓(xùn)練好的權(quán)重參數(shù),使得自己的一行代碼就可以解決問(wèn)題,方便大家進(jìn)行調(diào)用,這篇文章主要介紹了PyTorch?之?強(qiáng)大的?hub?模塊和搭建神經(jīng)網(wǎng)絡(luò)進(jìn)行氣溫預(yù)測(cè),需要的朋友可以參考下2023-03-03Python圖像處理之目標(biāo)物體輪廓提取的實(shí)現(xiàn)方法
目標(biāo)物體的輪廓實(shí)質(zhì)是指一系列像素點(diǎn)構(gòu)成,這些點(diǎn)構(gòu)成了一個(gè)有序的點(diǎn)集,這篇文章主要給大家介紹了關(guān)于Python圖像處理之目標(biāo)物體輪廓提取的實(shí)現(xiàn)方法,需要的朋友可以參考下2021-08-08Python編寫Windows Service服務(wù)程序
這篇文章主要為大家詳細(xì)介紹了Python編寫Windows Service服務(wù)程序,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-01-01