關(guān)于pytorch求導(dǎo)總結(jié)(torch.autograd)
1、Autograd 求導(dǎo)機(jī)制
我們?cè)谟蒙窠?jīng)網(wǎng)絡(luò)求解PDE時(shí), 經(jīng)常要用到輸出值對(duì)輸入變量(不是Weights和Biases)求導(dǎo);
例如在訓(xùn)練WGAN-GP 時(shí), 也會(huì)用到網(wǎng)絡(luò)對(duì)輸入變量的求導(dǎo),pytorch中通過(guò) Autograd 方法進(jìn)行求導(dǎo)
其求導(dǎo)規(guī)則如下:
1.1當(dāng)x為向量,y為一標(biāo)量時(shí)
通過(guò)autograd 計(jì)算的梯度為:
1.2先假設(shè)x,y為一維向量
其對(duì)應(yīng)的jacobi(雅可比)矩陣為
grad_outputs 是一個(gè)與因變量 y 的shape 一致的向量
在給定grad_outputs 后,通過(guò)Autograd 方法 計(jì)算的梯度如下:
1.3當(dāng) x 為1維向量,Y為2維向量
給出grad_outputs與Y的shape一致
Y 與x的jacobi矩陣
則 Y 對(duì) x 的梯度:
1.4當(dāng) X ,Y均為2維向量時(shí)
???????
1.5當(dāng)XY為更高維度時(shí)
可以按照上述辦法轉(zhuǎn)化為低維度的求導(dǎo)
值得注意的是:
求導(dǎo)后的梯度shape總與自變量X保持一致對(duì)自變量求導(dǎo)的順序并不會(huì)影響結(jié)果,某自變量的梯度值會(huì)放到該自變量原來(lái)相同位置梯度是由每個(gè)自變量的導(dǎo)數(shù)值組成的向量,既有大小又有方向grad_outputs 與 因變量Y的shape一致,每一個(gè)參數(shù)相當(dāng)于對(duì)因變量中相同位置的y進(jìn)行一個(gè)加權(quán)。
2、pytorch求導(dǎo)方法
2.1在求導(dǎo)前需要對(duì)需要求導(dǎo)的自變量進(jìn)行聲明
2.2torch.autograd.gard()
grad =? autograd.grad( outputs, inputs, grad_outputs=None, retain_graph=None, ???create_graph=False, only_inputs=True, allow_unused=False )
參數(shù)解釋:
- outputs:求導(dǎo)的因變量 Y
- inputs : 求導(dǎo)自變量 X
- grad_outputs:
當(dāng)outputs為標(biāo)量時(shí),grad_outputs=None , 不需要寫,當(dāng)outputs 為向量,需要為其聲明一個(gè)與outputs相同shape的參數(shù)矩陣,該矩陣中的每個(gè)參數(shù)的作用是,對(duì)outputs中相同位置的y進(jìn)行一個(gè)加權(quán)。
不然會(huì)報(bào)錯(cuò)
autograd.grad()返回的是元組數(shù)據(jù)類型,元組的長(zhǎng)度與inputs長(zhǎng)度相同,元組中每個(gè)單位的shape與相同位置的inputs相同
retain_graph:
1、當(dāng)求完一次梯度后默認(rèn)會(huì)把圖信息釋放掉,都會(huì)free掉計(jì)算圖中所有緩存的buffers,當(dāng)要連續(xù)進(jìn)行幾次求導(dǎo)時(shí),可能會(huì)因?yàn)榍懊鎎uffers不存在而報(bào)錯(cuò)。
因?yàn)榈诙€(gè)梯度計(jì)算z對(duì)x的導(dǎo)數(shù),需要y對(duì)x的計(jì)算導(dǎo)數(shù)的緩存信息,但是在計(jì)算grad1后,保存信息被釋放,找不到了,因此報(bào)錯(cuò)。
修改如下:
2、一般計(jì)算的最后一個(gè)梯度可以不需要保存計(jì)算圖信息,這樣在計(jì)算后可以及時(shí)釋放掉占用的內(nèi)存。
3、在pytorch中連續(xù)多次調(diào)用backward()也會(huì)出現(xiàn)這樣的問(wèn)題,對(duì)中間的backwad(),需要保持計(jì)算圖信息
create_graph: 若要計(jì)算高階導(dǎo)數(shù),則必須選為True
求二階導(dǎo)方法如下:
allow_unused: 允許輸入變量不進(jìn)入計(jì)算
2.3torch.backward()
def backward( ??????????? ????????????gradient: Optional[Tensor] = None, ??????????? ????????????retain_graph: Any = None, ??????????? ????????????create_graph: Any = False, ??????????? ????????????inputs: Any = None) -> Any ??????????????????????????????????????????????? )
如果需要計(jì)算導(dǎo)數(shù),可以在tensor上直接調(diào)用.backward(),會(huì)返回該tensor所有自變量的導(dǎo)數(shù)。
通過(guò)name(自變量名).grad 可以獲得該自變量的梯度如果tensor是標(biāo)量,則backward()不需要指定任何參數(shù)如果tensor是向量,則backward()需要指定gradient一個(gè)與tensorshape相同的參數(shù)矩陣,即這里的gradient 同 grad_outputs 作用和形式完全一樣。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python利用xmltodict實(shí)現(xiàn)字典和xml互相轉(zhuǎn)換的示例代碼
xmltodict是一個(gè)Python第三方庫(kù),用于處理XML數(shù)據(jù),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2024-12-12python使用tkinter包實(shí)現(xiàn)進(jìn)度條
python中的tkinter包是一種常見(jiàn)的設(shè)計(jì)程序的GUI界面用的包,本文將使用tkinter包實(shí)現(xiàn)不同風(fēng)格的進(jìn)度條,感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-11-11用Python的SimPy庫(kù)簡(jiǎn)化復(fù)雜的編程模型的介紹
這篇文章主要介紹了用Python的SimPy庫(kù)簡(jiǎn)化復(fù)雜的編程模型的介紹,本文來(lái)自于官方的開(kāi)發(fā)者技術(shù)文檔,需要的朋友可以參考下2015-04-04keras tensorflow 實(shí)現(xiàn)在python下多進(jìn)程運(yùn)行
今天小編就為大家分享一篇keras tensorflow 實(shí)現(xiàn)在python下多進(jìn)程運(yùn)行,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02python實(shí)現(xiàn)mp3文件播放的具體實(shí)現(xiàn)代碼
前段時(shí)間在搞一個(gè)基于python的語(yǔ)音助手,其中需要用到python播放音頻的功能,下面這篇文章主要給大家介紹了關(guān)于python實(shí)現(xiàn)mp3文件播放的具體實(shí)現(xiàn)代碼,需要的朋友可以參考下2023-05-05Python 流媒體播放器的實(shí)現(xiàn)(基于VLC)
這篇文章主要介紹了Python 流媒體播放器的實(shí)現(xiàn)(基于VLC),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Python聊天室?guī)Ы缑鎸?shí)現(xiàn)的示例代碼(tkinter,Mysql,Treading,socket)
這篇文章主要介紹了Python聊天室?guī)Ы缑鎸?shí)現(xiàn)的示例代碼(tkinter,Mysql,Treading,socket),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Python騷操作完美實(shí)現(xiàn)短視頻偽原創(chuàng)
剪輯的視頻上傳到某平臺(tái)碰到降權(quán)怎么辦?視頻平臺(tái)都有一套自己的鑒別算法,專門用于處理視頻的二次剪輯,本篇我們來(lái)用python做一些特殊處理2022-02-02python中threading.Semaphore和threading.Lock的具體使用
python中的多線程是一個(gè)非常重要的知識(shí)點(diǎn),本文主要介紹了python中threading.Semaphore和threading.Lock的具體使用,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2023-08-08