欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

一文詳解loss.item()用法和注意事項(xiàng)

 更新時(shí)間:2023年06月14日 09:43:05   作者:德國(guó)Viviane  
loss.item()是PyTorch中的一種方法,用于計(jì)算損失函數(shù)的值,下面這篇文章主要給大家介紹了關(guān)于loss.item()用法和注意事項(xiàng)的相關(guān)資料,需要的朋友可以參考下

loss.item()用法

.item()方法是,取一個(gè)元素張量里面的具體元素值并返回該值,可以將一個(gè)零維張量轉(zhuǎn)換成int型或者float型,在計(jì)算loss,accuracy時(shí)常用到。

作用:

1.item()取出張量具體位置的元素元素值
2.并且返回的是該位置元素值的高精度值
3.保持原元素類型不變;必須指定位置

4.節(jié)省內(nèi)存(不會(huì)計(jì)入計(jì)算圖)

import torch
 
loss = torch.randn(2, 2)
 
print(loss)
print(loss[1,1])
print(loss[1,1].item())

輸出結(jié)果

tensor([[-2.0274, -1.5974],
        [-1.4775,  1.9320]])
tensor(1.9320)
1.9319512844085693

其它:

loss = criterion(out, label)
    loss_sum += loss     # <--- 這里

運(yùn)行著就發(fā)現(xiàn)顯存炸了,觀察發(fā)現(xiàn)隨著每個(gè)batch顯存消耗在不斷增大…因?yàn)檩敵龅膌oss的數(shù)據(jù)類型是Variable。PyTorch的動(dòng)態(tài)圖機(jī)制就是通過(guò)Variable來(lái)構(gòu)建圖。主要是使用Variable計(jì)算的時(shí)候,會(huì)記錄下新產(chǎn)生的Variable的運(yùn)算符號(hào),在反向傳播求導(dǎo)的時(shí)候進(jìn)行使用。如果這里直接將loss加起來(lái),系統(tǒng)會(huì)認(rèn)為這里也是計(jì)算圖的一部分,也就是說(shuō)網(wǎng)絡(luò)會(huì)一直延伸變大,那么消耗的顯存也就越來(lái)越大。

正確的loss一般是這樣寫(xiě) 

loss_sum += loss.data[0]

其它注意事項(xiàng):

使用loss += loss.detach()來(lái)獲取不需要梯度回傳的部分。

使用loss.item()直接獲得對(duì)應(yīng)的python數(shù)據(jù)類型

補(bǔ)充閱讀,pytorch 計(jì)算圖

Pytorch的計(jì)算圖由節(jié)點(diǎn)和邊組成,節(jié)點(diǎn)表示張量或者Function,邊表示張量和Function之間的依賴關(guān)系。

Pytorch中的計(jì)算圖是動(dòng)態(tài)圖。這里的動(dòng)態(tài)主要有兩重含義。

第一層含義是:計(jì)算圖的正向傳播是立即執(zhí)行的。無(wú)需等待完整的計(jì)算圖創(chuàng)建完畢,每條語(yǔ)句都會(huì)在計(jì)算圖中動(dòng)態(tài)添加節(jié)點(diǎn)和邊,并立即執(zhí)行正向傳播得到計(jì)算結(jié)果。

第二層含義是:計(jì)算圖在反向傳播后立即銷(xiāo)毀。下次調(diào)用需要重新構(gòu)建計(jì)算圖。如果在程序中使用了backward方法執(zhí)行了反向傳播,或者利用torch.autograd.grad方法計(jì)算了梯度,那么創(chuàng)建的計(jì)算圖會(huì)被立即銷(xiāo)毀,釋放存儲(chǔ)空間,下次調(diào)用需要重新創(chuàng)建。

1,計(jì)算圖的正向傳播是立即執(zhí)行的。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定義后其正向傳播被立即執(zhí)行,與其后面的loss創(chuàng)建語(yǔ)句無(wú)關(guān)
loss = torch.mean(torch.pow(Y_hat-Y,2))
 
print(loss.data)
print(Y_hat.data)

tensor(17.8969)
tensor([[3.2613],
        [4.7322],
        [4.5037],
        [7.5899],
        [7.0973],
        [1.3287],
        [6.1473],
        [1.3492],
        [1.3911],
        [1.2150]])

2,計(jì)算圖在反向傳播后立即銷(xiāo)毀。

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定義后其正向傳播被立即執(zhí)行,與其后面的loss創(chuàng)建語(yǔ)句無(wú)關(guān)
loss = torch.mean(torch.pow(Y_hat-Y,2))
#計(jì)算圖在反向傳播后立即銷(xiāo)毀,如果需要保留計(jì)算圖, 需要設(shè)置retain_graph = True
loss.backward()  #loss.backward(retain_graph = True) 
#loss.backward() #如果再次執(zhí)行反向傳播將報(bào)錯(cuò)

參考鏈接:

  • https://www.zhihu.com/question/67209417/answer/344752405
  • https://blog.csdn.net/cs111211/article/details/126221102

總結(jié) 

到此這篇關(guān)于loss.item()用法和注意事項(xiàng)的文章就介紹到這了,更多相關(guān)loss.item()用法和注意事項(xiàng)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評(píng)論