PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵?fù)p失
torch.nn.CrossEntropyLoss交叉熵?fù)p失
本文只考慮基本情況,未考慮加權(quán)。
torch.nnCrossEntropyLosss使用的公式
目標(biāo)類(lèi)別采用one-hot編碼
其中,class表示當(dāng)前樣本類(lèi)別在one-hot編碼中對(duì)應(yīng)的索引(從0開(kāi)始),
x[j]表示預(yù)測(cè)函數(shù)的第j個(gè)輸出
公式(1)表示先對(duì)預(yù)測(cè)函數(shù)使用softmax計(jì)算每個(gè)類(lèi)別的概率,再使用log(以e為底)計(jì)算后的相反數(shù)表示當(dāng)前類(lèi)別的損失,只表示其中一個(gè)樣本的損失計(jì)算方式,非全部樣本。
每個(gè)樣本使用one-hot編碼表示所屬類(lèi)別時(shí),只有一項(xiàng)為1,因此與基本的交叉熵?fù)p失函數(shù)相比,省略了其它值為0的項(xiàng),只剩(1)所表示的項(xiàng)。
sample
torch.nn.CrossEntropyLoss使用流程
torch.nn.CrossEntropyLoss為一個(gè)類(lèi),并非單獨(dú)一個(gè)函數(shù),使用到的相關(guān)簡(jiǎn)單參數(shù)會(huì)在使用中說(shuō)明,并非對(duì)所有參數(shù)進(jìn)行說(shuō)明。
首先創(chuàng)建類(lèi)對(duì)象
In [1]: import torch In [2]: import torch.nn as nn In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")
參數(shù)reduction默認(rèn)為"mean",表示對(duì)所有樣本的loss取均值,最終返回只有一個(gè)值
參數(shù)reduction取"none",表示保留每一個(gè)樣本的loss
計(jì)算損失
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32) In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64) In [6]: loss_value = loss_function(pred, class_index) In [7]: loss_value Out[7]: tensor([1.5210, 0.6247]) # 與上述【sample】計(jì)算一致
實(shí)際計(jì)算損失值調(diào)用函數(shù)時(shí),傳入pred預(yù)測(cè)值與class_index類(lèi)別索引
在傳入每個(gè)類(lèi)別時(shí),class_index應(yīng)為一維,長(zhǎng)度為樣本個(gè)數(shù),每個(gè)元素表示對(duì)應(yīng)樣本的類(lèi)別索引,非one-hot編碼方式傳入
測(cè)試torch.nn.CrossEntropyLoss的reduction參數(shù)為默認(rèn)值"mean"
In [1]: import torch In [2]: import torch.nn as nn In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean") In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32) In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64) In [6]: loss_value = loss_function(pred, class_index) In [7]: loss_value Out[7]: 1.073 # 與上述【sample】計(jì)算一致
交叉熵?fù)p失nn.CrossEntropyLoss()的真正計(jì)算過(guò)程
對(duì)于多分類(lèi)損失函數(shù)Cross Entropy Loss,就不過(guò)多的解釋?zhuān)W(wǎng)上的博客不計(jì)其數(shù)。在這里,講講對(duì)于CE Loss的一些真正的理解。
首先大部分博客給出的公式如下:
其中p為真實(shí)標(biāo)簽值,q為預(yù)測(cè)值。
在低維復(fù)現(xiàn)此公式,結(jié)果如下。在此強(qiáng)調(diào)一點(diǎn),pytorch中CE Loss并不會(huì)將輸入的target映射為one-hot編碼格式,而是直接取下標(biāo)進(jìn)行計(jì)算。
import torch import torch.nn as nn import math import numpy as np #官方的實(shí)現(xiàn) entroy=nn.CrossEntropyLoss() input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],]) target = torch.tensor([0,1,2]) output = entroy(input, target) print(output) #輸出 tensor(1.1142) #自己實(shí)現(xiàn) input=np.array(input) target = np.array(target) def cross_entorpy(input, target): output = 0 length = len(target) for i in range(length): hou = 0 for j in input[i]: hou += np.log(input[i][target[i]]) output += -hou return np.around(output / length, 4) print(cross_entorpy(input, target)) #輸出 3.8162
我們按照官方給的CE Loss和根據(jù)公式得到的答案并不相同,說(shuō)明公式是有問(wèn)題的。
正確公式
實(shí)現(xiàn)代碼如下
import torch import torch.nn as nn import math import numpy as np entroy=nn.CrossEntropyLoss() input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],]) target = torch.tensor([0,1,2]) output = entroy(input, target) print(output) #輸出 tensor(1.1142) #%% input=np.array(input) target = np.array(target) def cross_entorpy(input, target): output = 0 length = len(target) for i in range(length): hou = 0 for j in input[i]: hou += np.exp(j) output += -input[i][target[i]] + np.log(hou) return np.around(output / length, 4) print(cross_entorpy(input, target)) #輸出 1.1142
對(duì)比自己實(shí)現(xiàn)的公式和官方給出的結(jié)果,可以驗(yàn)證公式的正確性。
觀察公式可以發(fā)現(xiàn)其實(shí)nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。
nn.logSoftmax(),公式如下
nn.NLLLoss(),公式如下
將nn.logSoftmax()作為變量帶入nn.NLLLoss()可得
因?yàn)?/p>
可看做一個(gè)常量,故上式可化簡(jiǎn)為:
對(duì)比nn.Cross Entropy Loss公式,結(jié)果顯而易見(jiàn)。
驗(yàn)證代碼如下。
import torch import torch.nn as nn import math import numpy as np entroy=nn.CrossEntropyLoss() input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],]) target = torch.tensor([0,1,2]) output = entroy(input, target) print(output) # 輸出為tensor(1.1142) m = nn.LogSoftmax() loss = nn.NLLLoss() input=m(input) output = loss(input, target) print(output) # 輸出為tensor(1.1142)
綜上,可得兩個(gè)結(jié)論
1.nn.Cross Entropy Loss的公式。
2.nn.Cross Entropy Loss為nn.logSoftmax()和nn.NLLLoss()的整合版本。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)平行坐標(biāo)圖的繪制(plotly)方式
今天小編就為大家分享一篇Python實(shí)現(xiàn)平行坐標(biāo)圖的繪制(plotly)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-11-11Python數(shù)據(jù)結(jié)構(gòu)之列表與元組詳解
序列是Python中最基本的數(shù)據(jù)結(jié)構(gòu)。序列中的每個(gè)元素都分配一個(gè)數(shù)字 - 它的位置,或索引,第一個(gè)索引是0,第二個(gè)索引是1,依此類(lèi)推,元組與列表類(lèi)似,不同之處在于元組的元素不能修改。元組使用小括號(hào),列表使用方括號(hào)2021-10-10python基礎(chǔ)之函數(shù)的定義和調(diào)用
這篇文章主要介紹了python函數(shù)的定義和調(diào)用,實(shí)例分析了Python中返回一個(gè)返回值與多個(gè)返回值的方法,需要的朋友可以參考下2021-10-10Python 網(wǎng)絡(luò)編程之TCP客戶(hù)端/服務(wù)端功能示例【基于socket套接字】
這篇文章主要介紹了Python 網(wǎng)絡(luò)編程之TCP客戶(hù)端/服務(wù)端功能,結(jié)合實(shí)例形式分析了Python使用socket套接字實(shí)現(xiàn)TCP協(xié)議下的客戶(hù)端與服務(wù)器端數(shù)據(jù)傳輸操作技巧,需要的朋友可以參考下2019-10-10一文帶你玩轉(zhuǎn)python中的requests函數(shù)
在Python中,requests庫(kù)是用于發(fā)送HTTP請(qǐng)求的常用庫(kù),因?yàn)樗峁┝撕?jiǎn)潔易用的接口,本文將深入探討requests庫(kù)的使用方法,感興趣的可以學(xué)習(xí)下2023-08-08python 虛擬環(huán)境的創(chuàng)建與使用方法
本文先介紹虛擬環(huán)境的基礎(chǔ)知識(shí)以及使用方法,然后再深入介紹虛擬環(huán)境背后的工作原理,需要的朋友可以參考下2021-06-06