淺談Pytorch中的自動(dòng)求導(dǎo)函數(shù)backward()所需參數(shù)的含義
正常來說backward( )函數(shù)是要傳入?yún)?shù)的,一直沒弄明白backward需要傳入的參數(shù)具體含義,但是沒關(guān)系,生命在與折騰,咱們來折騰一下,嘿嘿。
對標(biāo)量自動(dòng)求導(dǎo)
首先,如果out.backward()中的out是一個(gè)標(biāo)量的話(相當(dāng)于一個(gè)神經(jīng)網(wǎng)絡(luò)有一個(gè)樣本,這個(gè)樣本有兩個(gè)屬性,神經(jīng)網(wǎng)絡(luò)有一個(gè)輸出)那么此時(shí)我的backward函數(shù)是不需要輸入任何參數(shù)的。
import torch from torch.autograd import Variable a = Variable(torch.Tensor([2,3]),requires_grad=True) b = a + 3 c = b * 3 out = c.mean() out.backward() print('input:') print(a.data) print('output:') print(out.data.item()) print('input gradients are:') print(a.grad)
運(yùn)行結(jié)果:
不難看出,我們構(gòu)建了這樣的一個(gè)函數(shù):
所以其求導(dǎo)也很容易看出:
這是對其進(jìn)行標(biāo)量自動(dòng)求導(dǎo)的結(jié)果.
對向量自動(dòng)求導(dǎo)
如果out.backward()中的out是一個(gè)向量(或者理解成1xN的矩陣)的話,我們對向量進(jìn)行自動(dòng)求導(dǎo),看看會(huì)發(fā)生什么?
先構(gòu)建這樣的一個(gè)模型(相當(dāng)于一個(gè)神經(jīng)網(wǎng)絡(luò)有一個(gè)樣本,這個(gè)樣本有兩個(gè)屬性,神經(jīng)網(wǎng)絡(luò)有兩個(gè)輸出):
import torch from torch.autograd import Variable a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True) b = torch.zeros(1,2) b[0,0] = a[0,0] ** 2 b[0,1] = a[0,1] ** 3 out = 2 * b #其參數(shù)要傳入和out維度一樣的矩陣 out.backward(torch.FloatTensor([[1.,1.]])) print('input:') print(a.data) print('output:') print(out.data) print('input gradients are:') print(a.grad)
模型也很簡單,不難看出out求導(dǎo)出來的雅克比應(yīng)該是:
因?yàn)閍1 = 2,a2 = 4,所以上面的矩陣應(yīng)該是:
運(yùn)行的結(jié)果:
嗯,的確是8和96,但是仔細(xì)想一想,和咱們想要的雅克比矩陣的形式也不一樣啊。難道是backward自動(dòng)把0給省略了?
咱們繼續(xù)試試,這次在上一個(gè)模型的基礎(chǔ)上進(jìn)行小修改,如下:
import torch from torch.autograd import Variable a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True) b = torch.zeros(1,2) b[0,0] = a[0,0] ** 2 + a[0,1] b[0,1] = a[0,1] ** 3 + a[0,0] out = 2 * b #其參數(shù)要傳入和out維度一樣的矩陣 out.backward(torch.FloatTensor([[1.,1.]])) print('input:') print(a.data) print('output:') print(out.data) print('input gradients are:') print(a.grad)
可以看出這個(gè)模型的雅克比應(yīng)該是:
運(yùn)行一下:
等等,什么鬼?正常來說不應(yīng)該是
么?我是誰?我再哪?為什么就給我2個(gè)數(shù),而且是 8 + 2 = 10 ,96 + 2 = 98 。難道都是加的 2 ?想一想,剛才咱們backward中傳的參數(shù)是 [ [ 1 , 1 ] ],難道安裝這個(gè)關(guān)系對應(yīng)求和了?咱們換個(gè)參數(shù)來試一試,程序中只更改傳入的參數(shù)為[ [ 1 , 2 ] ]:
import torch from torch.autograd import Variable a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True) b = torch.zeros(1,2) b[0,0] = a[0,0] ** 2 + a[0,1] b[0,1] = a[0,1] ** 3 + a[0,0] out = 2 * b #其參數(shù)要傳入和out維度一樣的矩陣 out.backward(torch.FloatTensor([[1.,2.]])) print('input:') print(a.data) print('output:') print(out.data) print('input gradients are:') print(a.grad)
嗯,這回可以理解了,我們傳入的參數(shù),是對原來模型正常求導(dǎo)出來的雅克比矩陣進(jìn)行線性操作,可以把我們傳進(jìn)的參數(shù)(設(shè)為arg)看成一個(gè)列向量,那么我們得到的結(jié)果就是:
在這個(gè)題目中,我們得到的實(shí)際是:
看起來一切完美的解釋了,但是就在我剛剛打字的一刻,我意識(shí)到官方文檔中說k.backward()傳入的參數(shù)應(yīng)該和k具有相同的維度,所以如果按上述去解釋是解釋不通的。哪里出問題了呢?
仔細(xì)看了一下,原來是這樣的:在對雅克比矩陣進(jìn)行線性操作的時(shí)候,應(yīng)該把我們傳進(jìn)的參數(shù)(設(shè)為arg)看成一個(gè)行向量(不是列向量),那么我們得到的結(jié)果就是:
也就是:
這回我們就解釋的通了。
現(xiàn)在我們來輸出一下雅克比矩陣吧,為了不引起歧義,我們讓雅克比矩陣的每個(gè)數(shù)值都不一樣(一開始分析錯(cuò)了就是因?yàn)檠趴吮染仃囍杏邢嗤臄?shù)據(jù)),所以模型小改動(dòng)如下:
import torch from torch.autograd import Variable a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True) b = torch.zeros(1,2) b[0,0] = a[0,0] ** 2 + a[0,1] b[0,1] = a[0,1] ** 3 + a[0,0] * 2 out = 2 * b #其參數(shù)要傳入和out維度一樣的矩陣 out.backward(torch.FloatTensor([[1,0]]),retain_graph=True) A_temp = copy.deepcopy(a.grad) a.grad.zero_() out.backward(torch.FloatTensor([[0,1]])) B_temp = a.grad print('jacobian matrix is:') print(torch.cat( (A_temp,B_temp),0 ))
如果沒問題的話咱們的雅克比矩陣應(yīng)該是 [ [ 8 , 2 ] , [ 4 , 96 ] ]
好了,下面是見證奇跡的時(shí)刻了,不要眨眼睛奧,千萬不要眨眼睛… 3 2 1 砰…
好了,現(xiàn)在總結(jié)一下:因?yàn)榻?jīng)過了復(fù)雜的神經(jīng)網(wǎng)絡(luò)之后,out中每個(gè)數(shù)值都是由很多輸入樣本的屬性(也就是輸入數(shù)據(jù))線性或者非線性組合而成的,那么out中的每個(gè)數(shù)值和輸入數(shù)據(jù)的每個(gè)數(shù)值都有關(guān)聯(lián),也就是說【out】中的每個(gè)數(shù)都可以對【a】中每個(gè)數(shù)求導(dǎo),那么我們backward()的參數(shù)[k1,k2,k3…kn]的含義就是:
也可以理解成每個(gè)out分量對an求導(dǎo)時(shí)的權(quán)重。
對矩陣自動(dòng)求導(dǎo)
現(xiàn)在,如果out是一個(gè)矩陣呢?
下面的例子也可以理解為:相當(dāng)于一個(gè)神經(jīng)網(wǎng)絡(luò)有兩個(gè)樣本,每個(gè)樣本有兩個(gè)屬性,神經(jīng)網(wǎng)絡(luò)有兩個(gè)輸出。
import torch from torch.autograd import Variable from torch import nn a = Variable(torch.FloatTensor([[2,3],[1,2]]),requires_grad=True) w = Variable( torch.zeros(2,1),requires_grad=True ) out = torch.mm(a,w) out.backward(torch.FloatTensor([[1.],[1.]]),retain_graph=True) print("gradients are:{}".format(w.grad.data))
如果前面的例子理解了,那么這個(gè)也很好理解,backward輸入的參數(shù)k是一個(gè)2x1的矩陣,2代表的就是樣本數(shù)量,就是在前面的基礎(chǔ)上,再對每個(gè)樣本進(jìn)行加權(quán)求和。結(jié)果是:
如果有興趣,也可以拓展一下多個(gè)樣本的多分類問題,猜一下k的維度應(yīng)該是【輸入樣本的個(gè)數(shù) * 分類的個(gè)數(shù)】
好啦,糾結(jié)我好久的pytorch自動(dòng)求導(dǎo)原理算是徹底搞懂啦~~~
以上這篇淺談Pytorch中的自動(dòng)求導(dǎo)函數(shù)backward()所需參數(shù)的含義就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
pytorch神經(jīng)網(wǎng)絡(luò)從零開始實(shí)現(xiàn)多層感知機(jī)
這篇文章主要為大家介紹了pytorch神經(jīng)網(wǎng)絡(luò)從零開始實(shí)現(xiàn)多層感知機(jī)的示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-10-10Spyder中如何設(shè)置默認(rèn)python解釋器
Spyder作為一款流行的Python IDE,支持用戶自定義Python解釋器,包括虛擬環(huán)境的設(shè)置,通過打開Spyder,選擇“Tools”->“Preferences”,在彈出窗口中選擇“Use the following Python interpreter”后,瀏覽并選擇相應(yīng)的解釋器或虛擬環(huán)境路徑2024-09-09使用Tensorflow實(shí)現(xiàn)可視化中間層和卷積層
今天小編就為大家分享一篇使用Tensorflow實(shí)現(xiàn)可視化中間層和卷積層,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01Python實(shí)現(xiàn)的簡單讀寫csv文件操作示例
這篇文章主要介紹了Python實(shí)現(xiàn)的簡單讀寫csv文件操作,結(jié)合實(shí)例形式分析了Python使用csv模塊針對csv文件進(jìn)行讀寫操作的相關(guān)實(shí)現(xiàn)技巧與注意事項(xiàng),需要的朋友可以參考下2018-07-07Python中Parsel的兩種數(shù)據(jù)提取方式詳解
在網(wǎng)絡(luò)爬蟲的世界中,數(shù)據(jù)提取是至關(guān)重要的一環(huán),Python 提供了許多強(qiáng)大的工具,其中之一就是 parsel 庫,下面我們就來深入學(xué)習(xí)一下Parsel的兩種數(shù)據(jù)提取方式吧2023-12-12解決Django數(shù)據(jù)庫makemigrations有變化但是migrate時(shí)未變動(dòng)問題
今天小編就為大家分享一篇解決Django數(shù)據(jù)庫makemigrations有變化但是migrate時(shí)未變動(dòng)的問題,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05python從sqlite讀取并顯示數(shù)據(jù)的方法
這篇文章主要介紹了python從sqlite讀取并顯示數(shù)據(jù)的方法,涉及Python操作SQLite數(shù)據(jù)庫的讀取及顯示相關(guān)技巧,需要的朋友可以參考下2015-05-05python3.5實(shí)現(xiàn)socket通訊示例(TCP)
本篇文章主要介紹了python3.5實(shí)現(xiàn)socket通訊示例(TCP),小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-02-02Python?pickle模塊實(shí)現(xiàn)Python對象持久化存儲(chǔ)
這篇文章主要介紹了Python?pickle模塊實(shí)現(xiàn)Python對象持久化存儲(chǔ),pickle?是?python?語言的一個(gè)標(biāo)準(zhǔn)模塊,和python安裝時(shí)共同安裝好的一個(gè)模塊。下文基于pickle模塊展開實(shí)現(xiàn)Python對象持久化存儲(chǔ)的詳細(xì)內(nèi)容,需要的朋友可以參考一下2022-05-05