欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch如何自定義forward和backward函數(shù)

 更新時間:2024年10月12日 16:08:13   作者:xx_xjm  
PyTorch自動求導功能強大,但在特定情況下需要用戶自行定義backward函數(shù),通過實例解釋了保存變量、計算梯度、鏈式法則等核心概念,并展示了如何通過自定義函數(shù)集成到網(wǎng)絡中以及如何正確返回梯度,此外,還討論了多輸出情況下的梯度傳遞

pytorch自定義forward和backward函數(shù)

pytorch會自動求導,但是當遇到無法自動求導的時候,需要自己認為定義求導過程,這個時候就涉及到要定義自己的forward和backward函數(shù)。

舉例如下:

看到這里,大家應該會有很多疑問

比如:

  • 1:ctx.save_for_backward和ctx.saved_tensors的含義
  • 2:backward中各個計算函數(shù)的意義,以及backward的輸入?yún)?shù)grad_out是什么,以及grad_out包含哪些數(shù)據(jù)。

針對以上問題,我們一個個解答

  • 第一個問題:百度吧,答案很多!?。?!
  • 第二個問題:拿上面這個例子來看,我們定義了一個類似于線性層的東西,但注意這不是線性層,因為我們是直接把輸入和weight用*來做點對點的乘法的,所以這不是我們通常情況下的線性層。

但是這么看也費勁,我們寫一個網(wǎng)絡,把這個函數(shù)加到網(wǎng)絡中去,再完整的跑一遍看吧!

測試代碼

結果如下:

來進行解答

首先,backward函數(shù)的返回值,就是對應著forward里面的參數(shù)的梯度,也就是說,forward函數(shù)里面有幾個輸入?yún)?shù),那么backward函數(shù)的輸出就要有幾個!為什么是這樣?

我們首先要理解backward的輸入grad_out,為什么backward的參數(shù)就是一個,因為這是根據(jù)鏈式法則來的

比如,我們定義三個函數(shù)H(對應上面網(wǎng)絡中l(wèi)inear1),F(自定義函數(shù)xjm_inter),D(對應上面網(wǎng)絡中l(wèi)inear2),定義一個輸入x(對應上面輸入a),定義一個輸出y(對應上面輸出b):

y = D(F(H(X)))

現(xiàn)在,我們求y對x的偏導,那么:

dy/dx = dy/dD * dD/dF * dF/dH * dH/dx

好吧看到這里你可能還是不懂,為什么backward的參數(shù)就是一個grad_out??!

我們韓式以上面則個函數(shù)為例子,但是,我們現(xiàn)在不求y對x的導數(shù),我們假設F函數(shù)有一個葉子節(jié)點(或者說requires_grad=True)的參數(shù)w1,現(xiàn)在我們要求y對w1的導數(shù):

所以

dy/dw1 = dy/dD *dD/dF * dF/dw1

那么此時,F(xiàn)就是我們上面代碼中自定義的xjm_inter函數(shù),則 grad_out = dy/dD *dD/dF。

怎么理解呢,根據(jù)鏈式法則,我們呢所定義的網(wǎng)絡中的每一層都是一個單獨的函數(shù),所以函數(shù)中的變量的最終求導其實只取決于該函數(shù)本身,鏈式法則求導傳遞過來的其實永遠都知識一個值,這就是為什么backward函數(shù)的輸出只有一個。

擴展

當forward的輸出有多個的時候,那么就有多個鏈式法則,因為可以同時對x或者對w求導,此時backward的輸入可以是一個,也可以是對應forward輸出的個數(shù),如果是一個則是一個元組,包含對應的梯度!?。?/p>

那么我們的backward要實現(xiàn)什么樣的功能呢?說到這里,大家應該大概能明白了,就是實現(xiàn)當前層那的梯度計算,并進行返回,所以,這也是為什么backward的返回值要和forward的輸入值一一對應,否則會報錯。

總結

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

最新評論