PyTorch中torch.no_grad()用法舉例詳解
前言
torch.no_grad() 是 PyTorch 中的一個上下文管理器,用于在上下文中臨時(shí)禁用自動梯度計(jì)算。它在模型評估或推理階段非常有用,因?yàn)樵谶@些階段,我們通常不需要計(jì)算梯度。禁用梯度計(jì)算可以減少內(nèi)存消耗,并加快計(jì)算速度。
基本概念
在 PyTorch 中,每次對 requires_grad=True 的張量進(jìn)行操作時(shí),PyTorch 會構(gòu)建一個計(jì)算圖(computation graph),用于計(jì)算反向傳播的梯度。這對訓(xùn)練模型是必要的,但在評估或推理時(shí)不需要。因此,我們可以使用 torch.no_grad() 來臨時(shí)禁用這些計(jì)算圖的構(gòu)建和梯度計(jì)算。
用法
torch.no_grad() 的使用非常簡單。只需要將不需要梯度計(jì)算的代碼塊放在 with torch.no_grad(): 下即可。
示例代碼
以下是一個使用 torch.no_grad() 的示例:
import torch
# 創(chuàng)建一個張量,并設(shè)置 requires_grad=True 以便記錄梯度
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 在 torch.no_grad() 上下文中禁用梯度計(jì)算
with torch.no_grad():
y = x + 2
print(y)
# 此時(shí),x 的 requires_grad 屬性仍然為 True,但 y 的 requires_grad 屬性為 False
print("x 的 requires_grad:", x.requires_grad)
print("y 的 requires_grad:", y.requires_grad)詳細(xì)解釋
創(chuàng)建張量并設(shè)置 requires_grad=True:
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
創(chuàng)建一個包含三個元素的張量 x。
設(shè)置 requires_grad=True,告訴 PyTorch 需要為該張量記錄梯度。
禁用梯度計(jì)算:
with torch.no_grad():
y = x + 2
print(y)進(jìn)入 torch.no_grad() 上下文,臨時(shí)禁用梯度計(jì)算。
在上下文中,對 x 進(jìn)行加法操作,得到新的張量 y。
打印 y,此時(shí) y 的 requires_grad 屬性為 False。
查看 requires_grad 屬性:
print("x 的 requires_grad:", x.requires_grad)
print("y 的 requires_grad:", y.requires_grad)打印 x 的 requires_grad 屬性,仍然為 True。
打印 y 的 requires_grad 屬性,已被禁用為 False。
使用場景
模型評估
在評估模型性能時(shí),不需要計(jì)算梯度。使用 torch.no_grad() 可以提高評估速度和減少內(nèi)存消耗。
model.eval() # 切換到評估模式
with torch.no_grad():
for data in validation_loader:
outputs = model(data)
# 計(jì)算評估指標(biāo)模型推理
在部署和推理階段,只需要前向傳播,不需要反向傳播,因此可以使用 torch.no_grad()。
with torch.no_grad():
outputs = model(inputs)
predicted = torch.argmax(outputs, dim=1)初始化權(quán)重或其他不需要梯度的操作
在某些初始化或操作中,不需要梯度計(jì)算。
with torch.no_grad():
model.weight.fill_(1.0) # 直接修改權(quán)重小結(jié)
torch.no_grad() 是一個用于禁用梯度計(jì)算的上下文管理器,適用于模型評估、推理等不需要梯度計(jì)算的場景。使用 torch.no_grad() 可以顯著減少內(nèi)存使用和加速計(jì)算。通過理解和合理使用 torch.no_grad(),可以使得模型評估和推理更加高效和穩(wěn)定。
額外注意事項(xiàng)
訓(xùn)練模式與評估模式:
在使用 torch.no_grad() 時(shí),通常還會將模型設(shè)置為評估模式(model.eval()),以確保某些層(如 dropout 和 batch normalization)在推理時(shí)的行為與訓(xùn)練時(shí)不同。
嵌套使用:
torch.no_grad() 可以嵌套使用,內(nèi)層的 torch.no_grad() 仍然會禁用梯度計(jì)算。
with torch.no_grad():
with torch.no_grad():
y = x + 2
print(y)恢復(fù)梯度計(jì)算:
在 torch.no_grad() 上下文管理器退出后,梯度計(jì)算會自動恢復(fù),不需要額外操作。
with torch.no_grad():
y = x + 2
print(y)
# 這里梯度計(jì)算恢復(fù)
z = x * 2
print(z.requires_grad) # True通過合理使用 torch.no_grad(),可以在不需要梯度計(jì)算的場景中提升性能并節(jié)省資源。
總結(jié)
到此這篇關(guān)于PyTorch中torch.no_grad()用法舉例詳解的文章就介紹到這了,更多相關(guān)PyTorch torch.no_grad()詳解內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中4種實(shí)現(xiàn)數(shù)值的交換方式
這篇文章主要介紹了Python中4種實(shí)現(xiàn)數(shù)值的交換方式,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-08-08
PyTorch?之?強(qiáng)大的?hub?模塊和搭建神經(jīng)網(wǎng)絡(luò)進(jìn)行氣溫預(yù)測
hub 模塊是調(diào)用別人訓(xùn)練好的網(wǎng)絡(luò)架構(gòu)以及訓(xùn)練好的權(quán)重參數(shù),使得自己的一行代碼就可以解決問題,方便大家進(jìn)行調(diào)用,這篇文章主要介紹了PyTorch?之?強(qiáng)大的?hub?模塊和搭建神經(jīng)網(wǎng)絡(luò)進(jìn)行氣溫預(yù)測,需要的朋友可以參考下2023-03-03
Python圖像處理之目標(biāo)物體輪廓提取的實(shí)現(xiàn)方法
目標(biāo)物體的輪廓實(shí)質(zhì)是指一系列像素點(diǎn)構(gòu)成,這些點(diǎn)構(gòu)成了一個有序的點(diǎn)集,這篇文章主要給大家介紹了關(guān)于Python圖像處理之目標(biāo)物體輪廓提取的實(shí)現(xiàn)方法,需要的朋友可以參考下2021-08-08
Python編寫Windows Service服務(wù)程序
這篇文章主要為大家詳細(xì)介紹了Python編寫Windows Service服務(wù)程序,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-01-01

