PyTorch中常見損失函數(shù)的使用詳解
損失函數(shù)
損失函數(shù),又叫目標(biāo)函數(shù)。在編譯神經(jīng)網(wǎng)絡(luò)模型必須的兩個參數(shù)之一。另一個必不可少的就是優(yōu)化器,我將在后面詳解到。
重點(diǎn)
損失函數(shù)是指計(jì)算機(jī)標(biāo)簽值和預(yù)測值直接差異的函數(shù)。
這里我們會結(jié)束幾種常見的損失函數(shù)的計(jì)算方法,pytorch中也是以及定義了很多類型的預(yù)定義函數(shù),具體的公式不需要去深究(學(xué)了也不一定remember),這里暫時(shí)能做就是了解。
我們先來定義兩個二維的數(shù)組,然后用不同的損失函數(shù)計(jì)算其損失值。
import torch from torch.autograd import Variable import torch.nn as nn sample=Variable(torch.ones(2,2)) a=torch.Tensor(2,2) a[0,0]=0 a[0,1]=1 a[1,0]=2 a[1,1]=3 target=Variable(a) print(sample,target)
這里:
sample的值為tensor([[1., 1.],[1., 1.]])
target的值為tensor([[0., 1.],[2., 3.]])
nn.L1Loss
L1Loss計(jì)算方法很簡單,取預(yù)測值和真實(shí)值的絕對誤差的平均數(shù)。
loss=FunLoss(sample,target)['L1Loss'] print(loss)
在控制臺中打印出來是
tensor(1.)
它的計(jì)算過程是這樣的:(∣0−1∣+∣1−1∣+∣2−1∣+∣3−1∣)/4=1,先計(jì)算的是絕對值求和,然后再平均。
nn.SmoothL1Loss
SmoothL1Loss的誤差在(-1,1)上是平方損失,其他情況是L1損失。
loss=FunLoss(sample,target)['SmoothL1Loss'] print(loss)
在控制臺中打印出來是
tensor(0.6250)
nn.MSELoss
平方損失函數(shù)。其計(jì)算公式是預(yù)測值和真實(shí)值之間的平方和的平均數(shù)。
loss=FunLoss(sample,target)['MSELoss'] print(loss)
在控制臺中打印出來是
tensor(1.5000)
nn.CrossEntropyLoss
交叉熵?fù)p失公式
此公式常在圖像分類神經(jīng)網(wǎng)絡(luò)模型中會常常用到。
loss=FunLoss(sample,target)['CrossEntropyLoss'] print(loss)
在控制臺中打印出來是
tensor(2.0794)
nn.NLLLoss
負(fù)對數(shù)似然損失函數(shù)
需要注意的是,這里的xlabel和上面的交叉熵?fù)p失里的是不一樣的,這里是經(jīng)過log運(yùn)算后的數(shù)值。這個損失函數(shù)一般用在圖像識別的模型上。
loss=FunLoss(sample,target)['NLLLoss'] print(loss)
這里,控制臺報(bào)錯,需要0D或1D目標(biāo)張量,不支持多目標(biāo)。可能需要其他的一些條件,這里我們?nèi)绻龅搅嗽僬f。
損失函數(shù)模塊化設(shè)計(jì)
class FunLoss(): def __init__(self, sample, target): self.sample = sample self.target = target self.loss = { 'L1Loss': nn.L1Loss(), 'SmoothL1Loss': nn.SmoothL1Loss(), 'MSELoss': nn.MSELoss(), 'CrossEntropyLoss': nn.CrossEntropyLoss(), 'NLLLoss': nn.NLLLoss() } def __getitem__(self, loss_type): if loss_type in self.loss: loss_func = self.loss[loss_type] return loss_func(self.sample, self.target) else: raise KeyError(f"Invalid loss type '{loss_type}'") if __name__=="__main__": loss=FunLoss(sample,target)['NLLLoss'] print(loss)
總結(jié)
這篇博客適合那些希望了解在PyTorch中常見損失函數(shù)的讀者。通過FunLoss我們自己也能簡單的去調(diào)用。
到此這篇關(guān)于PyTorch中常見損失函數(shù)的使用詳解的文章就介紹到這了,更多相關(guān)PyTorch損失函數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
pytorch 轉(zhuǎn)換矩陣的維數(shù)位置方法
今天小編就為大家分享一篇pytorch 轉(zhuǎn)換矩陣的維數(shù)位置方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12python spilt()分隔字符串的實(shí)現(xiàn)示例
split() 方法可以實(shí)現(xiàn)將一個字符串按照指定的分隔符切分成多個子串,本文介紹了spilt的具體使用,感興趣的可以了解一下2021-05-05Python issubclass和isinstance函數(shù)的具體使用
本文主要介紹了Python issubclass和isinstance函數(shù)的具體使用,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02Python使用Selenium模塊實(shí)現(xiàn)模擬瀏覽器抓取淘寶商品美食信息功能示例
這篇文章主要介紹了Python使用Selenium模塊實(shí)現(xiàn)模擬瀏覽器抓取淘寶商品美食信息功能,涉及Python基于re模塊的正則匹配及selenium模塊的頁面抓取等相關(guān)操作技巧,需要的朋友可以參考下2018-07-07python開啟多個子進(jìn)程并行運(yùn)行的方法
這篇文章主要介紹了python開啟多個子進(jìn)程并行運(yùn)行的方法,涉及Python進(jìn)程操作的相關(guān)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-04-04通過conda把已有虛擬環(huán)境的python版本進(jìn)行降級操作指南
當(dāng)使用conda創(chuàng)建虛擬環(huán)境時(shí),有時(shí)候可能會遇到python版本不對的問題,下面這篇文章主要給大家介紹了關(guān)于如何通過conda把已有虛擬環(huán)境的python版本進(jìn)行降級操作的相關(guān)資料,需要的朋友可以參考下2024-05-05