pytorch的backward()的底層實現(xiàn)邏輯詳解

自動微分是一種計算張量(tensors)的梯度(gradients)的技術,它在深度學習中非常有用。自動微分的基本思想是:
- 自動微分會記錄數(shù)據(jù)(張量)和所有執(zhí)行的操作(以及產(chǎn)生的新張量)在一個由函數(shù)(Function)對象組成的有向無環(huán)圖(DAG)中。在這個圖中,葉子節(jié)點是輸入張量,根節(jié)點是輸出張量。通過從根節(jié)點到葉子節(jié)點追蹤這個圖,可以使用鏈式法則(chain rule)自動地計算梯度。
- 在前向傳播(forward pass)中,自動微分同時做兩件事:
- 運行請求的操作來計算一個結果張量,以及
- 在 DAG 中保留操作的梯度函數(shù)。
- 在 DAG 中保留操作的梯度函數(shù),這就是說,當你給自動微分一個張量和一個操作,它不僅會計算出結果張量,還會記住這個操作的梯度函數(shù),也就是這個操作對輸入張量的導數(shù)。例如,如果你給自動微分一個張量 x = [1, 2, 3] 和一個操作 y = x + 1,它不僅會計算出 y = [2, 3, 4],還會記住這個操作的梯度函數(shù)是 dy/dx = 1,也就是說,y 對 x 的導數(shù)是 1。這樣,當你需要計算梯度時,自動微分就可以根據(jù)這個梯度函數(shù)來計算出結果張量對輸入張量的梯度。
- 在PyTorch中,DAG是動態(tài)的。需要注意的一點是,圖是從頭開始重新創(chuàng)建的;在每個
.backward()調用之后,autograd開始填充一個新的圖。 - 后向傳播開始于當在 DAG 的根節(jié)點上調用 .backward() 方法。這個方法會觸發(fā)自動微分開始計算梯度。
- 自動微分會從每個 .grad_fn 中計算梯度,這個 .grad_fn 是一個函數(shù)對象,它保存了操作的梯度函數(shù)。例如,如果一個操作是 y = x + 1,那么它的 .grad_fn 就是 dy/dx = 1。
- 自動微分會將計算出的梯度累加到相應張量的 .grad 屬性中,這個 .grad 屬性是一個張量,它保存了結果張量對輸入張量的梯度。例如,如果一個結果張量是 y = [2, 3, 4],那么它的 .grad 屬性就是 [1, 1, 1],表示 y 對 x 的梯度是 1。
- 使用鏈式法則(chain rule),自動微分會一直向后傳播,直到到達葉子張量。鏈式法則是一種數(shù)學公式,它可以將復合函數(shù)的梯度分解為簡單函數(shù)的梯度的乘積。例如,如果一個復合函數(shù)是 z = f(g(x)),那么它的梯度是 dz/dx = dz/dg * dg/dx。
import torch
import torch.nn as nn
M = nn.Linear(2, 2) # neural network module
M.eval() # set M to evaluation mode
with torch.no_grad(): # disable gradient computation
for param in M.parameters(): # loop over all parameters
param.fill_(1) # fill the parameter with 1
M.requires_grad_(False)
a = torch.tensor([1., 2.], requires_grad=True) # leaf node
b = torch.tensor([13., 32.], requires_grad=True) # leaf node
c = M(a) # non-leaf node
c2 = M(b) # non-leaf node
d = c * 2 # non-leaf node
d.sum().backward() # compute gradients
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(M.weight.grad) # None構建計算圖:當我們調用backward()方法時,PyTorch會自動構建從葉子節(jié)點a到損失值d.sum()的計算圖,這是一個有向無環(huán)圖,表示了各個張量之間的運算關系。計算圖中還包含了兩個中間變量c和d,它們是由a經(jīng)過M模型的前向傳播得到的。計算圖的作用是記錄反向傳播的路徑,以便于計算梯度。 計算梯度:在計算圖中,每個張量都有一個屬性grad,用于存儲它的梯度值。當我們調用backward()方法時,PyTorch會沿著計算圖按照鏈式法則計算并填充每個張量的grad屬性。由于我們只對葉子節(jié)點a的梯度感興趣,所以只有a的grad屬性會被計算出來,而中間變量c和d的grad屬性會被忽略。a的grad屬性的值是損失值d.sum()對a的偏導數(shù),表示了a的變化對損失值的影響。
到此這篇關于pytorch的backward()的底層實現(xiàn)邏輯的文章就介紹到這了,更多相關pytorch backward()內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
使用python如何提取JSON數(shù)據(jù)指定內容
這篇文章主要介紹了使用python如何提取JSON數(shù)據(jù)指定內容,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-07-07
Python+Selenium+phantomjs實現(xiàn)網(wǎng)頁模擬登錄和截圖功能(windows環(huán)境)
Python是一種跨平臺的計算機程序設計語言,它可以運行在Windows、Mac和各種Linux/Unix系統(tǒng)上。這篇文章主要介紹了Python+Selenium+phantomjs實現(xiàn)網(wǎng)頁模擬登錄和截圖功能,需要的朋友可以參考下2019-12-12
Python中json.loads和json.dumps方法中英雙語詳解
在Python中json.loads和json.dumps是處理JSON數(shù)據(jù)的重要方法,json.loads用于將JSON字符串解析為Python對象,而json.dumps用于將Python對象序列化為JSON字符串,文中通過代碼介紹的非常詳細,需要的朋友可以參考下2025-01-01
Python OpenCV 調用攝像頭并截圖保存功能的實現(xiàn)代碼
這篇文章主要介紹了Python OpenCV 調用攝像頭并截圖保存功能,本文通過兩段實例代碼給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下2019-07-07
Python獲取網(wǎng)頁數(shù)據(jù)詳解流程
讀萬卷書不如行萬里路,只學書上的理論是遠遠不夠的,只有在實戰(zhàn)中才能獲得能力的提升,本篇文章手把手帶你用Python來獲取網(wǎng)頁的數(shù)據(jù),主要應用了Requests庫,大家可以在過程中查缺補漏,提升水平2021-10-10

