pytorch如何自定義forward和backward函數(shù)
pytorch自定義forward和backward函數(shù)
pytorch會自動求導,但是當遇到無法自動求導的時候,需要自己認為定義求導過程,這個時候就涉及到要定義自己的forward和backward函數(shù)。
舉例如下:
看到這里,大家應該會有很多疑問
比如:
- 1:ctx.save_for_backward和ctx.saved_tensors的含義
- 2:backward中各個計算函數(shù)的意義,以及backward的輸入?yún)?shù)grad_out是什么,以及grad_out包含哪些數(shù)據(jù)。
針對以上問題,我們一個個解答
- 第一個問題:百度吧,答案很多!?。?!
- 第二個問題:拿上面這個例子來看,我們定義了一個類似于線性層的東西,但注意這不是線性層,因為我們是直接把輸入和weight用*來做點對點的乘法的,所以這不是我們通常情況下的線性層。
但是這么看也費勁,我們寫一個網(wǎng)絡,把這個函數(shù)加到網(wǎng)絡中去,再完整的跑一遍看吧!
測試代碼
結果如下:
來進行解答
首先,backward函數(shù)的返回值,就是對應著forward里面的參數(shù)的梯度,也就是說,forward函數(shù)里面有幾個輸入?yún)?shù),那么backward函數(shù)的輸出就要有幾個!為什么是這樣?
我們首先要理解backward的輸入grad_out,為什么backward的參數(shù)就是一個,因為這是根據(jù)鏈式法則來的
比如,我們定義三個函數(shù)H(對應上面網(wǎng)絡中l(wèi)inear1),F(自定義函數(shù)xjm_inter),D(對應上面網(wǎng)絡中l(wèi)inear2),定義一個輸入x(對應上面輸入a),定義一個輸出y(對應上面輸出b):
y = D(F(H(X)))
現(xiàn)在,我們求y對x的偏導,那么:
dy/dx = dy/dD * dD/dF * dF/dH * dH/dx
好吧看到這里你可能還是不懂,為什么backward的參數(shù)就是一個grad_out??!
我們韓式以上面則個函數(shù)為例子,但是,我們現(xiàn)在不求y對x的導數(shù),我們假設F函數(shù)有一個葉子節(jié)點(或者說requires_grad=True)的參數(shù)w1,現(xiàn)在我們要求y對w1的導數(shù):
所以
dy/dw1 = dy/dD *dD/dF * dF/dw1
那么此時,F(xiàn)就是我們上面代碼中自定義的xjm_inter函數(shù),則 grad_out = dy/dD *dD/dF。
怎么理解呢,根據(jù)鏈式法則,我們呢所定義的網(wǎng)絡中的每一層都是一個單獨的函數(shù),所以函數(shù)中的變量的最終求導其實只取決于該函數(shù)本身,鏈式法則求導傳遞過來的其實永遠都知識一個值,這就是為什么backward函數(shù)的輸出只有一個。
擴展
當forward的輸出有多個的時候,那么就有多個鏈式法則,因為可以同時對x或者對w求導,此時backward的輸入可以是一個,也可以是對應forward輸出的個數(shù),如果是一個則是一個元組,包含對應的梯度!?。?/p>
那么我們的backward要實現(xiàn)什么樣的功能呢?說到這里,大家應該大概能明白了,就是實現(xiàn)當前層那的梯度計算,并進行返回,所以,這也是為什么backward的返回值要和forward的輸入值一一對應,否則會報錯。
總結
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python實現(xiàn)的FTP通信客戶端與服務器端功能示例
這篇文章主要介紹了Python實現(xiàn)的FTP通信客戶端與服務器端功能,涉及Python基于socket的端口監(jiān)聽、文件傳輸?shù)认嚓P操作技巧,需要的朋友可以參考下2018-03-03基于Python實現(xiàn)簡易學生信息管理系統(tǒng)
這篇文章主要為大家詳細介紹了python實現(xiàn)簡易學生信息管理系統(tǒng),文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下2022-07-07Python 模擬動態(tài)產(chǎn)生字母驗證碼圖片功能
這篇文章主要介紹了Python 模擬動態(tài)產(chǎn)生字母驗證碼圖片,這里給大家介紹了pillow模塊的使用,需要的朋友可以參考下2019-12-12pytho matplotlib工具欄源碼探析一之禁用工具欄、默認工具欄和工具欄管理器三種模式的差異
這篇文章主要介紹了pytho matplotlib工具欄源碼探析一之禁用工具欄、默認工具欄和工具欄管理器三種模式的差異,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-02-02