一文詳解loss.item()用法和注意事項
loss.item()用法
.item()方法是,取一個元素張量里面的具體元素值并返回該值,可以將一個零維張量轉換成int型或者float型,在計算loss,accuracy時常用到。
作用:
1.item()取出張量具體位置的元素元素值
2.并且返回的是該位置元素值的高精度值
3.保持原元素類型不變;必須指定位置4.節(jié)省內存(不會計入計算圖)
import torch loss = torch.randn(2, 2) print(loss) print(loss[1,1]) print(loss[1,1].item())
輸出結果
tensor([[-2.0274, -1.5974],
[-1.4775, 1.9320]])
tensor(1.9320)
1.9319512844085693
其它:
loss = criterion(out, label) loss_sum += loss # <--- 這里
運行著就發(fā)現(xiàn)顯存炸了,觀察發(fā)現(xiàn)隨著每個batch顯存消耗在不斷增大…因為輸出的loss的數(shù)據(jù)類型是Variable。PyTorch的動態(tài)圖機制就是通過Variable來構建圖。主要是使用Variable計算的時候,會記錄下新產(chǎn)生的Variable的運算符號,在反向傳播求導的時候進行使用。如果這里直接將loss加起來,系統(tǒng)會認為這里也是計算圖的一部分,也就是說網(wǎng)絡會一直延伸變大,那么消耗的顯存也就越來越大。
正確的loss一般是這樣寫
loss_sum += loss.data[0]
其它注意事項:
使用loss += loss.detach()來獲取不需要梯度回傳的部分。
使用loss.item()直接獲得對應的python數(shù)據(jù)類型。
補充閱讀,pytorch 計算圖
Pytorch的計算圖由節(jié)點和邊組成,節(jié)點表示張量或者Function,邊表示張量和Function之間的依賴關系。
Pytorch中的計算圖是動態(tài)圖。這里的動態(tài)主要有兩重含義。
第一層含義是:計算圖的正向傳播是立即執(zhí)行的。無需等待完整的計算圖創(chuàng)建完畢,每條語句都會在計算圖中動態(tài)添加節(jié)點和邊,并立即執(zhí)行正向傳播得到計算結果。
第二層含義是:計算圖在反向傳播后立即銷毀。下次調用需要重新構建計算圖。如果在程序中使用了backward方法執(zhí)行了反向傳播,或者利用torch.autograd.grad方法計算了梯度,那么創(chuàng)建的計算圖會被立即銷毀,釋放存儲空間,下次調用需要重新創(chuàng)建。
1,計算圖的正向傳播是立即執(zhí)行的。
import torch w = torch.tensor([[3.0,1.0]],requires_grad=True) b = torch.tensor([[3.0]],requires_grad=True) X = torch.randn(10,2) Y = torch.randn(10,1) Y_hat = X@w.t() + b # Y_hat定義后其正向傳播被立即執(zhí)行,與其后面的loss創(chuàng)建語句無關 loss = torch.mean(torch.pow(Y_hat-Y,2)) print(loss.data) print(Y_hat.data)
tensor(17.8969)
tensor([[3.2613],
[4.7322],
[4.5037],
[7.5899],
[7.0973],
[1.3287],
[6.1473],
[1.3492],
[1.3911],
[1.2150]])
2,計算圖在反向傳播后立即銷毀。
import torch w = torch.tensor([[3.0,1.0]],requires_grad=True) b = torch.tensor([[3.0]],requires_grad=True) X = torch.randn(10,2) Y = torch.randn(10,1) Y_hat = X@w.t() + b # Y_hat定義后其正向傳播被立即執(zhí)行,與其后面的loss創(chuàng)建語句無關 loss = torch.mean(torch.pow(Y_hat-Y,2)) #計算圖在反向傳播后立即銷毀,如果需要保留計算圖, 需要設置retain_graph = True loss.backward() #loss.backward(retain_graph = True) #loss.backward() #如果再次執(zhí)行反向傳播將報錯
參考鏈接:
- https://www.zhihu.com/question/67209417/answer/344752405
- https://blog.csdn.net/cs111211/article/details/126221102
總結
到此這篇關于loss.item()用法和注意事項的文章就介紹到這了,更多相關loss.item()用法和注意事項內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
自動轉換Python代碼為HTML界面的GUI庫remi使用探究
這篇文章主要為大家介紹了自動轉換Python代碼為HTML界面的GUI庫remi使用探究,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2024-01-01Python+OpenCV實現(xiàn)圖片及視頻中選定區(qū)域顏色識別
這篇文章主要為大家詳細介紹了如何利用Python+OpenCV實現(xiàn)圖片及視頻中選定區(qū)域顏色識別功能,文中的示例代碼講解詳細,感興趣的可以了解一下2022-07-07詳解Python之數(shù)據(jù)序列化(json、pickle、shelve)
本篇文章主要介紹了Python之數(shù)據(jù)序列化,本節(jié)要介紹的就是Python內置的幾個用于進行數(shù)據(jù)序列化的模塊,有興趣的可以了解一下。2017-03-03Python PyTorch 如何獲取 MNIST 數(shù)據(jù)
這篇文章主要介紹了Python PyTorch 如何獲取 MNIST 數(shù)據(jù),通過示例代碼介紹了PyTorch 保存 MNIST 數(shù)據(jù),PyTorch 顯示 MNIST 數(shù)據(jù)的操作方法,感興趣的朋友跟隨小編一起看看吧2024-04-04詳解PyCharm使用pyQT5進行GUI開發(fā)的基本流程
本文主要介紹了PyCharm使用pyQT5進行GUI開發(fā)的基本流程,文中通過示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下2021-10-10