欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

一文詳解loss.item()用法和注意事項

 更新時間:2023年06月14日 09:43:05   作者:德國Viviane  
loss.item()是PyTorch中的一種方法,用于計算損失函數(shù)的值,下面這篇文章主要給大家介紹了關于loss.item()用法和注意事項的相關資料,需要的朋友可以參考下

loss.item()用法

.item()方法是,取一個元素張量里面的具體元素值并返回該值,可以將一個零維張量轉換成int型或者float型,在計算loss,accuracy時常用到。

作用:

1.item()取出張量具體位置的元素元素值
2.并且返回的是該位置元素值的高精度值
3.保持原元素類型不變;必須指定位置

4.節(jié)省內存(不會計入計算圖)

import torch
 
loss = torch.randn(2, 2)
 
print(loss)
print(loss[1,1])
print(loss[1,1].item())

輸出結果

tensor([[-2.0274, -1.5974],
        [-1.4775,  1.9320]])
tensor(1.9320)
1.9319512844085693

其它:

loss = criterion(out, label)
    loss_sum += loss     # <--- 這里

運行著就發(fā)現(xiàn)顯存炸了,觀察發(fā)現(xiàn)隨著每個batch顯存消耗在不斷增大…因為輸出的loss的數(shù)據(jù)類型是Variable。PyTorch的動態(tài)圖機制就是通過Variable來構建圖。主要是使用Variable計算的時候,會記錄下新產(chǎn)生的Variable的運算符號,在反向傳播求導的時候進行使用。如果這里直接將loss加起來,系統(tǒng)會認為這里也是計算圖的一部分,也就是說網(wǎng)絡會一直延伸變大,那么消耗的顯存也就越來越大。

正確的loss一般是這樣寫 

loss_sum += loss.data[0]

其它注意事項:

使用loss += loss.detach()來獲取不需要梯度回傳的部分。

使用loss.item()直接獲得對應的python數(shù)據(jù)類型

補充閱讀,pytorch 計算圖

Pytorch的計算圖由節(jié)點和邊組成,節(jié)點表示張量或者Function,邊表示張量和Function之間的依賴關系。

Pytorch中的計算圖是動態(tài)圖。這里的動態(tài)主要有兩重含義。

第一層含義是:計算圖的正向傳播是立即執(zhí)行的。無需等待完整的計算圖創(chuàng)建完畢,每條語句都會在計算圖中動態(tài)添加節(jié)點和邊,并立即執(zhí)行正向傳播得到計算結果。

第二層含義是:計算圖在反向傳播后立即銷毀。下次調用需要重新構建計算圖。如果在程序中使用了backward方法執(zhí)行了反向傳播,或者利用torch.autograd.grad方法計算了梯度,那么創(chuàng)建的計算圖會被立即銷毀,釋放存儲空間,下次調用需要重新創(chuàng)建。

1,計算圖的正向傳播是立即執(zhí)行的。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定義后其正向傳播被立即執(zhí)行,與其后面的loss創(chuàng)建語句無關
loss = torch.mean(torch.pow(Y_hat-Y,2))
 
print(loss.data)
print(Y_hat.data)

tensor(17.8969)
tensor([[3.2613],
        [4.7322],
        [4.5037],
        [7.5899],
        [7.0973],
        [1.3287],
        [6.1473],
        [1.3492],
        [1.3911],
        [1.2150]])

2,計算圖在反向傳播后立即銷毀。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定義后其正向傳播被立即執(zhí)行,與其后面的loss創(chuàng)建語句無關
loss = torch.mean(torch.pow(Y_hat-Y,2))
#計算圖在反向傳播后立即銷毀,如果需要保留計算圖, 需要設置retain_graph = True
loss.backward()  #loss.backward(retain_graph = True) 
#loss.backward() #如果再次執(zhí)行反向傳播將報錯

參考鏈接:

  • https://www.zhihu.com/question/67209417/answer/344752405
  • https://blog.csdn.net/cs111211/article/details/126221102

總結 

到此這篇關于loss.item()用法和注意事項的文章就介紹到這了,更多相關loss.item()用法和注意事項內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

最新評論