pytorch 實(shí)現(xiàn)二分類(lèi)交叉熵逆樣本頻率權(quán)重
通常,由于類(lèi)別不均衡,需要使用weighted cross entropy loss平衡。
def inverse_freq(label): """ 輸入label [N,1,H,W],1是channel數(shù)目 """ den = label.sum() # 0 _,_,h,w= label.shape num = h*w alpha = den/num # 0 return torch.tensor([alpha, 1-alpha]).cuda() # train ... loss1 = F.cross_entropy(out1, label.squeeze(1).long(), weight=inverse_freq(label))
補(bǔ)充:Pytorch踩坑記之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的區(qū)別和使用)
在Pytorch中的交叉熵函數(shù)的血淚史要從nn.CrossEntropyLoss()這個(gè)損失函數(shù)開(kāi)始講起。
從表面意義上看,這個(gè)函數(shù)好像是普通的交叉熵函數(shù),但是如果你看過(guò)一些Pytorch的資料,會(huì)告訴你這個(gè)函數(shù)其實(shí)是softmax()和交叉熵的結(jié)合體。
然而如果去官方看這個(gè)函數(shù)的定義你會(huì)發(fā)現(xiàn)是這樣子的:
哇,竟然是nn.LogSoftmax()和nn.NLLLoss()的結(jié)合體,這倆都是什么玩意兒啊。再看看你會(huì)發(fā)現(xiàn)甚至還有一個(gè)損失叫nn.Softmax()以及一個(gè)叫nn.nn.BCELoss()。
我們來(lái)探究下這幾個(gè)損失到底有何種關(guān)系。
nn.Softmax和nn.LogSoftmax
首先nn.Softmax()官網(wǎng)的定義是這樣的:
嗯...就是我們認(rèn)識(shí)的那個(gè)softmax。那nn.LogSoftmax()的定義也很直觀了:
果不其然就是Softmax取了個(gè)log??梢詫?xiě)個(gè)代碼測(cè)試一下:
import torch import torch.nn as nn a = torch.Tensor([1,2,3]) #定義Softmax softmax = nn.Softmax() sm_a = softmax=nn.Softmax() print(sm) #輸出:tensor([0.0900, 0.2447, 0.6652]) #定義LogSoftmax logsoftmax = nn.LogSoftmax() lsm_a = logsoftmax(a) print(lsm_a) #輸出tensor([-2.4076, -1.4076, -0.4076]),其中l(wèi)n(0.0900)=-2.4076
nn.NLLLoss
上面說(shuō)過(guò)nn.CrossEntropy()是nn.LogSoftmax()和nn.NLLLoss的結(jié)合,nn.NLLLoss官網(wǎng)給的定義是這樣的:
The negative log likelihood loss. It is useful to train a classification problem with C classes
負(fù)對(duì)數(shù)似然損失 ,看起來(lái)好像有點(diǎn)晦澀難懂,寫(xiě)個(gè)代碼測(cè)試一下:
import torch import torch.nn a = torch.Tensor([[1,2,3]]) nll = nn.NLLLoss() target1 = torch.Tensor([0]).long() target2 = torch.Tensor([1]).long() target3 = torch.Tensor([2]).long() #測(cè)試 n1 = nll(a,target1) #輸出:tensor(-1.) n2 = nll(a,target2) #輸出:tensor(-2.) n3 = nll(a,target3) #輸出:tensor(-3.)
看起來(lái)nn.NLLLoss做的事情是取出a中對(duì)應(yīng)target位置的值并取負(fù)號(hào),比如target1=0,就取a中index=0位置上的值再取負(fù)號(hào)為-1,那這樣做有什么意義呢,要結(jié)合nn.CrossEntropy往下看。
nn.CrossEntropy
看下官網(wǎng)給的nn.CrossEntropy()的表達(dá)式:
看起來(lái)應(yīng)該是softmax之后取了個(gè)對(duì)數(shù),寫(xiě)個(gè)簡(jiǎn)單代碼測(cè)試一下:
import torch import torch.nn as nn a = torch.Tensor([[1,2,3]]) target = torch.Tensor([2]).long() logsoftmax = nn.LogSoftmax() ce = nn.CrossEntropyLoss() nll = nn.NLLLoss() #測(cè)試CrossEntropyLoss cel = ce(a,target) print(cel) #輸出:tensor(0.4076) #測(cè)試LogSoftmax+NLLLoss lsm_a = logsoftmax(a) nll_lsm_a = nll(lsm_a,target) #輸出tensor(0.4076)
看來(lái)直接用nn.CrossEntropy和nn.LogSoftmax+nn.NLLLoss是一樣的結(jié)果。為什么這樣呢,回想下交叉熵的表達(dá)式:
其中y是label,x是prediction的結(jié)果,所以其實(shí)交叉熵?fù)p失就是負(fù)的target對(duì)應(yīng)位置的輸出結(jié)果x再取-log。這個(gè)計(jì)算過(guò)程剛好就是先LogSoftmax()再NLLLoss()。
所以我認(rèn)為nn.CrossEntropyLoss其實(shí)應(yīng)該叫做softmaxloss更為合理一些,這樣就不會(huì)誤解了。
nn.BCELoss
你以為這就完了嗎,其實(shí)并沒(méi)有。還有一類(lèi)損失叫做BCELoss,寫(xiě)全了的話就是Binary Cross Entropy Loss,就是交叉熵應(yīng)用于二分類(lèi)時(shí)候的特殊形式,一般都和sigmoid一起用,表達(dá)式就是二分類(lèi)交叉熵:
直覺(jué)上和多酚類(lèi)交叉熵的區(qū)別在于,不僅考慮了的樣本,也考慮了
的樣本的損失。
總結(jié)
nn.LogSoftmax是在softmax的基礎(chǔ)上取自然對(duì)數(shù)
nn.NLLLoss是負(fù)的似然對(duì)數(shù)損失,但Pytorch的實(shí)現(xiàn)就是把對(duì)應(yīng)target上的數(shù)取出來(lái)再加個(gè)負(fù)號(hào),要在CrossEntropy中結(jié)合LogSoftmax來(lái)用
BCELoss是二分類(lèi)的交叉熵?fù)p失,Pytorch實(shí)現(xiàn)中和多分類(lèi)有區(qū)別
Pytorch是個(gè)深坑,讓我們一起扎根使用手冊(cè),結(jié)合實(shí)踐踏平這些坑吧暴風(fēng)哭泣。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
pycharm專(zhuān)業(yè)版遠(yuǎn)程登錄服務(wù)器的詳細(xì)教程
這篇文章主要介紹了pycharm專(zhuān)業(yè)版遠(yuǎn)程登錄服務(wù)器的詳細(xì)教程,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-09-09Python 實(shí)現(xiàn)取矩陣的部分列,保存為一個(gè)新的矩陣方法
今天小編就為大家分享一篇Python 實(shí)現(xiàn)取矩陣的部分列,保存為一個(gè)新的矩陣方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-11-11Python3.8 + Tkinter: Button設(shè)置image屬性不顯示的問(wèn)題及解決方法
這篇文章主要介紹了Python3.8 + Tkinter: Button設(shè)置image屬性不顯示的問(wèn)題,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-08-08Python+Pygame實(shí)戰(zhàn)之英文版猜字游戲的實(shí)現(xiàn)
這篇文章主要為大家介紹了如何利用Python中的Pygame模塊實(shí)現(xiàn)英文版猜單詞游戲,文中的示例代碼講解詳細(xì),對(duì)我們學(xué)習(xí)Python游戲開(kāi)發(fā)有一定幫助,需要的可以參考一下2022-08-08python實(shí)現(xiàn)CSF地面點(diǎn)濾波算法原理解析
這篇文章主要介紹了python實(shí)現(xiàn)CSF地面點(diǎn)濾波算法原理,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-08-08Python mplfinance庫(kù)繪制金融圖表實(shí)現(xiàn)數(shù)據(jù)可視化實(shí)例探究
mplfinance(Matplotlib Finance),它是基于Matplotlib的庫(kù),專(zhuān)門(mén)用于創(chuàng)建金融圖表和交互式金融數(shù)據(jù)可視化,本文將深入介紹?mplfinance,包括其基本概念、功能特性以及如何使用示例代碼創(chuàng)建各種金融圖表2024-01-01Python小紅書(shū)旋轉(zhuǎn)驗(yàn)證碼識(shí)別實(shí)戰(zhàn)教程
這篇文章主要介紹了Python小紅書(shū)旋轉(zhuǎn)驗(yàn)證碼識(shí)別實(shí)戰(zhàn)教程,本文通過(guò)示例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2023-08-08PyQt5打開(kāi)文件對(duì)話框QFileDialog實(shí)例代碼
這篇文章主要介紹了PyQt5打開(kāi)文件對(duì)話框QFileDialog實(shí)例代碼,分享了相關(guān)代碼示例,小編覺(jué)得還是挺不錯(cuò)的,具有一定借鑒價(jià)值,需要的朋友可以參考下2018-02-02Python enumerate內(nèi)置庫(kù)用法解析
這篇文章主要介紹了Python enumerate內(nèi)置庫(kù)用法解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-02-02