pytorch如何自定義forward和backward函數(shù)
pytorch自定義forward和backward函數(shù)
pytorch會自動求導(dǎo),但是當(dāng)遇到無法自動求導(dǎo)的時候,需要自己認(rèn)為定義求導(dǎo)過程,這個時候就涉及到要定義自己的forward和backward函數(shù)。
舉例如下:
看到這里,大家應(yīng)該會有很多疑問
比如:
- 1:ctx.save_for_backward和ctx.saved_tensors的含義
- 2:backward中各個計算函數(shù)的意義,以及backward的輸入?yún)?shù)grad_out是什么,以及grad_out包含哪些數(shù)據(jù)。
針對以上問題,我們一個個解答
- 第一個問題:百度吧,答案很多?。。?!
- 第二個問題:拿上面這個例子來看,我們定義了一個類似于線性層的東西,但注意這不是線性層,因為我們是直接把輸入和weight用*來做點對點的乘法的,所以這不是我們通常情況下的線性層。
但是這么看也費勁,我們寫一個網(wǎng)絡(luò),把這個函數(shù)加到網(wǎng)絡(luò)中去,再完整的跑一遍看吧!
測試代碼
結(jié)果如下:
來進(jìn)行解答
首先,backward函數(shù)的返回值,就是對應(yīng)著forward里面的參數(shù)的梯度,也就是說,forward函數(shù)里面有幾個輸入?yún)?shù),那么backward函數(shù)的輸出就要有幾個!為什么是這樣?
我們首先要理解backward的輸入grad_out,為什么backward的參數(shù)就是一個,因為這是根據(jù)鏈?zhǔn)椒▌t來的
比如,我們定義三個函數(shù)H(對應(yīng)上面網(wǎng)絡(luò)中l(wèi)inear1),F(自定義函數(shù)xjm_inter),D(對應(yīng)上面網(wǎng)絡(luò)中l(wèi)inear2),定義一個輸入x(對應(yīng)上面輸入a),定義一個輸出y(對應(yīng)上面輸出b):
y = D(F(H(X)))
現(xiàn)在,我們求y對x的偏導(dǎo),那么:
dy/dx = dy/dD * dD/dF * dF/dH * dH/dx
好吧看到這里你可能還是不懂,為什么backward的參數(shù)就是一個grad_out??!
我們韓式以上面則個函數(shù)為例子,但是,我們現(xiàn)在不求y對x的導(dǎo)數(shù),我們假設(shè)F函數(shù)有一個葉子節(jié)點(或者說requires_grad=True)的參數(shù)w1,現(xiàn)在我們要求y對w1的導(dǎo)數(shù):
所以
dy/dw1 = dy/dD *dD/dF * dF/dw1
那么此時,F(xiàn)就是我們上面代碼中自定義的xjm_inter函數(shù),則 grad_out = dy/dD *dD/dF。
怎么理解呢,根據(jù)鏈?zhǔn)椒▌t,我們呢所定義的網(wǎng)絡(luò)中的每一層都是一個單獨的函數(shù),所以函數(shù)中的變量的最終求導(dǎo)其實只取決于該函數(shù)本身,鏈?zhǔn)椒▌t求導(dǎo)傳遞過來的其實永遠(yuǎn)都知識一個值,這就是為什么backward函數(shù)的輸出只有一個。
擴(kuò)展
當(dāng)forward的輸出有多個的時候,那么就有多個鏈?zhǔn)椒▌t,因為可以同時對x或者對w求導(dǎo),此時backward的輸入可以是一個,也可以是對應(yīng)forward輸出的個數(shù),如果是一個則是一個元組,包含對應(yīng)的梯度?。。?/p>
那么我們的backward要實現(xiàn)什么樣的功能呢?說到這里,大家應(yīng)該大概能明白了,就是實現(xiàn)當(dāng)前層那的梯度計算,并進(jìn)行返回,所以,這也是為什么backward的返回值要和forward的輸入值一一對應(yīng),否則會報錯。
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python中有哪些關(guān)鍵字及關(guān)鍵字的用法
這篇文章主要介紹了Python中有哪些關(guān)鍵字及關(guān)鍵字的用法,分享python中常用的關(guān)鍵字,本文結(jié)合示例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-02-02Python實現(xiàn)的FTP通信客戶端與服務(wù)器端功能示例
這篇文章主要介紹了Python實現(xiàn)的FTP通信客戶端與服務(wù)器端功能,涉及Python基于socket的端口監(jiān)聽、文件傳輸?shù)认嚓P(guān)操作技巧,需要的朋友可以參考下2018-03-03基于Python實現(xiàn)簡易學(xué)生信息管理系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了python實現(xiàn)簡易學(xué)生信息管理系統(tǒng),文中示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2022-07-07Python 模擬動態(tài)產(chǎn)生字母驗證碼圖片功能
這篇文章主要介紹了Python 模擬動態(tài)產(chǎn)生字母驗證碼圖片,這里給大家介紹了pillow模塊的使用,需要的朋友可以參考下2019-12-12pytho matplotlib工具欄源碼探析一之禁用工具欄、默認(rèn)工具欄和工具欄管理器三種模式的差異
這篇文章主要介紹了pytho matplotlib工具欄源碼探析一之禁用工具欄、默認(rèn)工具欄和工具欄管理器三種模式的差異,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-02-02python多進(jìn)程下的生產(chǎn)者和消費者模型
這篇文章主要介紹了python多進(jìn)程下的生產(chǎn)者和消費者模型,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-05-05Python?NumPy教程之?dāng)?shù)組的基本操作詳解
Numpy?中的數(shù)組是一個元素表(通常是數(shù)字),所有元素類型相同,由正整數(shù)元組索引。本文將通過一些示例詳細(xì)講一下NumPy中數(shù)組的一些基本操作,需要的可以參考一下2022-08-08