欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別解析

 更新時(shí)間:2023年04月07日 08:29:00   作者:小瓶蓋的豬豬俠  
這篇文章主要介紹了PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別解析,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下

PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別

以 a.data, a.detach() 為例:
兩種方法均會(huì)返回和a相同的tensor,且與原tensor a 共享數(shù)據(jù),一方改變,則另一方也改變。

所起的作用均是將變量tensor從原有的計(jì)算圖中分離出來,分離所得tensor的requires_grad = False。

不同點(diǎn):

data是一個(gè)屬性,.detach()是一個(gè)方法;data是不安全的,.detach()是安全的;

>>> a = torch.tensor([1,2,3.], requires_grad =True)
>>> out = a.sigmoid()
>>> c = out.data
>>> c.zero_()
tensor([ 0., 0., 0.])

>>> out                   #  out的數(shù)值被c.zero_()修改
tensor([ 0., 0., 0.])

>>> out.sum().backward()  #  反向傳播
>>> a.grad                #  這個(gè)結(jié)果很嚴(yán)重的錯(cuò)誤,因?yàn)閛ut已經(jīng)改變了
tensor([ 0., 0., 0.])

為什么.data是不安全的?

這是因?yàn)?,?dāng)我們修改分離后的tensor,從而導(dǎo)致原tensora發(fā)生改變。PyTorch的自動(dòng)求導(dǎo)Autograd是無法捕捉到這種變化的,會(huì)依然按照求導(dǎo)規(guī)則進(jìn)行求導(dǎo),導(dǎo)致計(jì)算出錯(cuò)誤的導(dǎo)數(shù)值。

其風(fēng)險(xiǎn)性在于,如果我在某一處修改了某一個(gè)變量,求導(dǎo)的時(shí)候也無法得知這一修改,可能會(huì)在不知情的情況下計(jì)算出錯(cuò)誤的導(dǎo)數(shù)值。

>>> a = torch.tensor([1,2,3.], requires_grad =True)
>>> out = a.sigmoid()
>>> c = out.detach()
>>> c.zero_()
tensor([ 0., 0., 0.])

>>> out                   #  out的值被c.zero_()修改 !!
tensor([ 0., 0., 0.])

>>> out.sum().backward()  #  需要原來out得值,但是已經(jīng)被c.zero_()覆蓋了,結(jié)果報(bào)錯(cuò)
RuntimeError: one of the variables needed for gradient
computation has been modified by an

那么.detach()為什么是安全的?

使用.detach()的好處在于,若是出現(xiàn)上述情況,Autograd可以檢測(cè)出某一處變量已經(jīng)發(fā)生了改變,進(jìn)而以如下形式報(bào)錯(cuò),從而避免了錯(cuò)誤的求導(dǎo)。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

從以上可以看出,是在前向傳播的過程中使用就地操作(In-place operation)導(dǎo)致了這一問題,那么就地操作是什么呢?

補(bǔ)充:pytorch中的detach()函數(shù)的作用

detach()

官方文檔中,對(duì)這個(gè)方法是這么介紹的。

  • 返回一個(gè)新的從當(dāng)前圖中分離的 Variable。
  • 返回的 Variable 永遠(yuǎn)不會(huì)需要梯度 如果 被 detach
  • 的Variable volatile=True, 那么 detach 出來的 volatile 也為 True
  • 還有一個(gè)注意事項(xiàng),即:返回的 Variable 和 被 detach 的Variable 指向同一個(gè) tensor
import torch
from torch.nn import init
from torch.autograd import Variable
t1 = torch.FloatTensor([1., 2.])
v1 = Variable(t1)
t2 = torch.FloatTensor([2., 3.])
v2 = Variable(t2)
v3 = v1 + v2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached)    # v3 中tensor 的值也會(huì)改變

能用來干啥

可以對(duì)部分網(wǎng)絡(luò)求梯度。

如果我們有兩個(gè)網(wǎng)絡(luò) , 兩個(gè)關(guān)系是這樣的 現(xiàn)在我們想用 來為B網(wǎng)絡(luò)的參數(shù)來求梯度,但是又不想求A網(wǎng)絡(luò)參數(shù)的梯度。我們可以這樣:

# y=A(x), z=B(y) 求B中參數(shù)的梯度,不求A中參數(shù)的梯度
y = A(x)
z = B(y.detach())
z.backward()

到此這篇關(guān)于PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別的文章就介紹到這了,更多相關(guān)PyTorch tensor.detach() 和 tensor.data內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評(píng)論