PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別解析
PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別
以 a.data, a.detach() 為例:
兩種方法均會(huì)返回和a相同的tensor,且與原tensor a 共享數(shù)據(jù),一方改變,則另一方也改變。
所起的作用均是將變量tensor從原有的計(jì)算圖中分離出來,分離所得tensor的requires_grad = False。
不同點(diǎn):
data是一個(gè)屬性,.detach()是一個(gè)方法;data是不安全的,.detach()是安全的;
>>> a = torch.tensor([1,2,3.], requires_grad =True) >>> out = a.sigmoid() >>> c = out.data >>> c.zero_() tensor([ 0., 0., 0.]) >>> out # out的數(shù)值被c.zero_()修改 tensor([ 0., 0., 0.]) >>> out.sum().backward() # 反向傳播 >>> a.grad # 這個(gè)結(jié)果很嚴(yán)重的錯(cuò)誤,因?yàn)閛ut已經(jīng)改變了 tensor([ 0., 0., 0.])
為什么.data是不安全的?
這是因?yàn)?,?dāng)我們修改分離后的tensor,從而導(dǎo)致原tensora發(fā)生改變。PyTorch的自動(dòng)求導(dǎo)Autograd是無法捕捉到這種變化的,會(huì)依然按照求導(dǎo)規(guī)則進(jìn)行求導(dǎo),導(dǎo)致計(jì)算出錯(cuò)誤的導(dǎo)數(shù)值。
其風(fēng)險(xiǎn)性在于,如果我在某一處修改了某一個(gè)變量,求導(dǎo)的時(shí)候也無法得知這一修改,可能會(huì)在不知情的情況下計(jì)算出錯(cuò)誤的導(dǎo)數(shù)值。
>>> a = torch.tensor([1,2,3.], requires_grad =True) >>> out = a.sigmoid() >>> c = out.detach() >>> c.zero_() tensor([ 0., 0., 0.]) >>> out # out的值被c.zero_()修改 !! tensor([ 0., 0., 0.]) >>> out.sum().backward() # 需要原來out得值,但是已經(jīng)被c.zero_()覆蓋了,結(jié)果報(bào)錯(cuò) RuntimeError: one of the variables needed for gradient computation has been modified by an
那么.detach()為什么是安全的?
使用.detach()的好處在于,若是出現(xiàn)上述情況,Autograd可以檢測(cè)出某一處變量已經(jīng)發(fā)生了改變,進(jìn)而以如下形式報(bào)錯(cuò),從而避免了錯(cuò)誤的求導(dǎo)。
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
從以上可以看出,是在前向傳播的過程中使用就地操作(In-place operation)導(dǎo)致了這一問題,那么就地操作是什么呢?
補(bǔ)充:pytorch中的detach()函數(shù)的作用
detach()
官方文檔中,對(duì)這個(gè)方法是這么介紹的。
- 返回一個(gè)新的從當(dāng)前圖中分離的 Variable。
- 返回的 Variable 永遠(yuǎn)不會(huì)需要梯度 如果 被 detach
- 的Variable volatile=True, 那么 detach 出來的 volatile 也為 True
- 還有一個(gè)注意事項(xiàng),即:返回的 Variable 和 被 detach 的Variable 指向同一個(gè) tensor
import torch from torch.nn import init from torch.autograd import Variable t1 = torch.FloatTensor([1., 2.]) v1 = Variable(t1) t2 = torch.FloatTensor([2., 3.]) v2 = Variable(t2) v3 = v1 + v2 v3_detached = v3.detach() v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值 print(v3, v3_detached) # v3 中tensor 的值也會(huì)改變
能用來干啥
可以對(duì)部分網(wǎng)絡(luò)求梯度。
如果我們有兩個(gè)網(wǎng)絡(luò) , 兩個(gè)關(guān)系是這樣的 現(xiàn)在我們想用 來為B網(wǎng)絡(luò)的參數(shù)來求梯度,但是又不想求A網(wǎng)絡(luò)參數(shù)的梯度。我們可以這樣:
# y=A(x), z=B(y) 求B中參數(shù)的梯度,不求A中參數(shù)的梯度 y = A(x) z = B(y.detach()) z.backward()
到此這篇關(guān)于PyTorch中 tensor.detach() 和 tensor.data 的區(qū)別的文章就介紹到這了,更多相關(guān)PyTorch tensor.detach() 和 tensor.data內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python 把序列轉(zhuǎn)換為元組的函數(shù)tuple方法
今天小編就為大家分享一篇Python 把序列轉(zhuǎn)換為元組的函數(shù)tuple方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-06-06Python實(shí)現(xiàn)克里金插值法的過程詳解
克里金算法提供的半變異函數(shù)模型有高斯、線形、球形、阻尼正弦和指數(shù)模型等,在對(duì)氣象要素場(chǎng)插值時(shí)球形模擬比較好。本文將用Python實(shí)現(xiàn)克里金插值法,感興趣的可以了解一下2022-11-11如何將python中的List轉(zhuǎn)化成dictionary
這篇文章主要介紹在python中如何將list轉(zhuǎn)化成dictionary,通過提出兩個(gè)問題來告訴大家如何解決,有需要的可以參考借鑒。2016-08-08python使用rstrip函數(shù)刪除字符串末位字符
rstrip函數(shù)用于刪除字符串末位指定字符,默認(rèn)為空白符,這篇文章主要介紹了python使用rstrip函數(shù)刪除字符串末位字符的方法,需要的朋友可以參考下2023-04-04利用numpy實(shí)現(xiàn)一、二維數(shù)組的拼接簡(jiǎn)單代碼示例
這篇文章主要介紹了利用numpy實(shí)現(xiàn)一、二維數(shù)組的拼接簡(jiǎn)單代碼示例,具有一定借鑒價(jià)值,需要的朋友可以參考下。2017-12-12python命令行參數(shù)解析OptionParser類用法實(shí)例
這篇文章主要介紹了python命令行參數(shù)解析OptionParser類用法實(shí)例,需要的朋友可以參考下2014-10-10