pytorch交叉熵?fù)p失函數(shù)的weight參數(shù)的使用
首先
必須將權(quán)重也轉(zhuǎn)為Tensor的cuda格式;
然后
將該class_weight作為交叉熵函數(shù)對(duì)應(yīng)參數(shù)的輸入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
補(bǔ)充:關(guān)于pytorch的CrossEntropyLoss的weight參數(shù)
首先這個(gè)weight參數(shù)比想象中的要考慮的多
你可以試試下面代碼
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,1,1]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.4803)
這里的手動(dòng)計(jì)算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加權(quán)呢?
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.6075)
手算發(fā)現(xiàn),并不是單純的那權(quán)重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
發(fā)現(xiàn)了么,加權(quán)后,除以的是權(quán)重的和,不是數(shù)目的和。
我們?cè)衮?yàn)證一遍:
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5]) outputs = torch.LongTensor([0,1,2,2]) inputs = inputs.view((1,3,4)) outputs = outputs.view((1,4)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(weight=weight_CE) # ce = nn.CrossEntropyLoss(ignore_index=255) loss = ce(inputs,outputs) print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人對(duì)loss的CE計(jì)算過程有疑問,我這里細(xì)致寫寫交叉熵的計(jì)算過程,就拿最后一個(gè)例子的loss4的計(jì)算說明

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
解析python調(diào)用函數(shù)加括號(hào)和不加括號(hào)的區(qū)別
這篇文章主要介紹了python調(diào)用函數(shù)加括號(hào)和不加括號(hào)的區(qū)別,不帶括號(hào)時(shí),調(diào)用的是這個(gè)函數(shù)本身 ,是整個(gè)函數(shù)體,是一個(gè)函數(shù)對(duì)象,不須等該函數(shù)執(zhí)行完成,具體實(shí)例代碼跟隨小編一起看看吧2021-10-10
Python簡(jiǎn)單獲取網(wǎng)卡名稱及其IP地址的方法【基于psutil模塊】
這篇文章主要介紹了Python簡(jiǎn)單獲取網(wǎng)卡名稱及其IP地址的方法,結(jié)合實(shí)例形式分析了Python基于psutil模塊針對(duì)本機(jī)網(wǎng)卡硬件信息的讀取操作簡(jiǎn)單使用技巧,需要的朋友可以參考下2018-05-05
python使用socket向客戶端發(fā)送數(shù)據(jù)的方法
這篇文章主要介紹了python使用socket向客戶端發(fā)送數(shù)據(jù)的方法,涉及Python使用socket實(shí)現(xiàn)數(shù)據(jù)通信的技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-04-04
python中的繼承機(jī)制super()函數(shù)詳解
這篇文章主要介紹了python中的繼承機(jī)制super()函數(shù)詳解,super 是用來解決多重繼承問題的,直接用類名調(diào)用父類方法在使用單繼承的時(shí)候沒問題,但是如果使用多繼承,會(huì)涉及到查找順序、重復(fù)調(diào)用等問題,需要的朋友可以參考下2023-08-08
python實(shí)現(xiàn)簡(jiǎn)易通訊錄修改版
這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)簡(jiǎn)易通訊錄的修改版,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03

