欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵?fù)p失

 更新時(shí)間:2023年02月02日 09:00:19   作者:gy笨瓜  
這篇文章主要介紹了PyTorch基礎(chǔ)之torch.nn.CrossEntropyLoss交叉熵?fù)p失講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

torch.nn.CrossEntropyLoss交叉熵?fù)p失

本文只考慮基本情況,未考慮加權(quán)。

torch.nnCrossEntropyLosss使用的公式

目標(biāo)類(lèi)別采用one-hot編碼

其中,class表示當(dāng)前樣本類(lèi)別在one-hot編碼中對(duì)應(yīng)的索引(從0開(kāi)始),

x[j]表示預(yù)測(cè)函數(shù)的第j個(gè)輸出

公式(1)表示先對(duì)預(yù)測(cè)函數(shù)使用softmax計(jì)算每個(gè)類(lèi)別的概率,再使用log(以e為底)計(jì)算后的相反數(shù)表示當(dāng)前類(lèi)別的損失,只表示其中一個(gè)樣本的損失計(jì)算方式,非全部樣本。

每個(gè)樣本使用one-hot編碼表示所屬類(lèi)別時(shí),只有一項(xiàng)為1,因此與基本的交叉熵?fù)p失函數(shù)相比,省略了其它值為0的項(xiàng),只剩(1)所表示的項(xiàng)。

sample

torch.nn.CrossEntropyLoss使用流程

torch.nn.CrossEntropyLoss為一個(gè)類(lèi),并非單獨(dú)一個(gè)函數(shù),使用到的相關(guān)簡(jiǎn)單參數(shù)會(huì)在使用中說(shuō)明,并非對(duì)所有參數(shù)進(jìn)行說(shuō)明。

首先創(chuàng)建類(lèi)對(duì)象

In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")

參數(shù)reduction默認(rèn)為"mean",表示對(duì)所有樣本的loss取均值,最終返回只有一個(gè)值

參數(shù)reduction取"none",表示保留每一個(gè)樣本的loss

計(jì)算損失

In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 與上述【sample】計(jì)算一致

實(shí)際計(jì)算損失值調(diào)用函數(shù)時(shí),傳入pred預(yù)測(cè)值與class_index類(lèi)別索引

在傳入每個(gè)類(lèi)別時(shí),class_index應(yīng)為一維,長(zhǎng)度為樣本個(gè)數(shù),每個(gè)元素表示對(duì)應(yīng)樣本的類(lèi)別索引,非one-hot編碼方式傳入

測(cè)試torch.nn.CrossEntropyLoss的reduction參數(shù)為默認(rèn)值"mean"

In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 與上述【sample】計(jì)算一致

交叉熵?fù)p失nn.CrossEntropyLoss()的真正計(jì)算過(guò)程

對(duì)于多分類(lèi)損失函數(shù)Cross Entropy Loss,就不過(guò)多的解釋?zhuān)W(wǎng)上的博客不計(jì)其數(shù)。在這里,講講對(duì)于CE Loss的一些真正的理解。

首先大部分博客給出的公式如下:

其中p為真實(shí)標(biāo)簽值,q為預(yù)測(cè)值。

在低維復(fù)現(xiàn)此公式,結(jié)果如下。在此強(qiáng)調(diào)一點(diǎn),pytorch中CE Loss并不會(huì)將輸入的target映射為one-hot編碼格式,而是直接取下標(biāo)進(jìn)行計(jì)算。

import torch
import torch.nn as nn
import math
import numpy as np

#官方的實(shí)現(xiàn)
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#輸出 tensor(1.1142)

#自己實(shí)現(xiàn)
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
    output = 0
    length = len(target)
    for i in range(length):
        hou = 0
        for j in input[i]:
            hou += np.log(input[i][target[i]])
        output += -hou
    return np.around(output / length, 4)
print(cross_entorpy(input, target))
#輸出 3.8162

我們按照官方給的CE Loss和根據(jù)公式得到的答案并不相同,說(shuō)明公式是有問(wèn)題的。

正確公式

實(shí)現(xiàn)代碼如下

import torch
import torch.nn as nn
import math
import numpy as np

entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#輸出 tensor(1.1142)
#%%
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
    output = 0
    length = len(target)
    for i in range(length):
        hou = 0
        for j in input[i]:
            hou += np.exp(j)
        output += -input[i][target[i]] + np.log(hou)
    return np.around(output / length, 4)
print(cross_entorpy(input, target))
#輸出 1.1142

對(duì)比自己實(shí)現(xiàn)的公式和官方給出的結(jié)果,可以驗(yàn)證公式的正確性。

觀察公式可以發(fā)現(xiàn)其實(shí)nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。

nn.logSoftmax(),公式如下

nn.NLLLoss(),公式如下

將nn.logSoftmax()作為變量帶入nn.NLLLoss()可得

因?yàn)?/p>

可看做一個(gè)常量,故上式可化簡(jiǎn)為:

對(duì)比nn.Cross Entropy Loss公式,結(jié)果顯而易見(jiàn)。

驗(yàn)證代碼如下。

import torch
import torch.nn as nn
import math
import numpy as np

entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
# 輸出為tensor(1.1142)
m = nn.LogSoftmax()
loss = nn.NLLLoss()
input=m(input)
output = loss(input, target)
print(output)
# 輸出為tensor(1.1142)

綜上,可得兩個(gè)結(jié)論

1.nn.Cross Entropy Loss的公式。

2.nn.Cross Entropy Loss為nn.logSoftmax()和nn.NLLLoss()的整合版本。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python實(shí)現(xiàn)平行坐標(biāo)圖的繪制(plotly)方式

    Python實(shí)現(xiàn)平行坐標(biāo)圖的繪制(plotly)方式

    今天小編就為大家分享一篇Python實(shí)現(xiàn)平行坐標(biāo)圖的繪制(plotly)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-11-11
  • PyQt5界面無(wú)響應(yīng)的解決方案

    PyQt5界面無(wú)響應(yīng)的解決方案

    如果在主線(xiàn)程執(zhí)行耗時(shí)操作,比如 循環(huán)、sleep、wait 異步線(xiàn)程執(zhí)行 會(huì)導(dǎo)致 UI 界面進(jìn)入無(wú)響應(yīng)狀態(tài),我們可以采用以下兩種方式異步處理:使用QThread 或 QTimer,本文給大家介紹了PyQt5界面無(wú)響應(yīng)的解決方案,需要的朋友可以參考下
    2024-05-05
  • Python數(shù)據(jù)結(jié)構(gòu)之列表與元組詳解

    Python數(shù)據(jù)結(jié)構(gòu)之列表與元組詳解

    序列是Python中最基本的數(shù)據(jù)結(jié)構(gòu)。序列中的每個(gè)元素都分配一個(gè)數(shù)字 - 它的位置,或索引,第一個(gè)索引是0,第二個(gè)索引是1,依此類(lèi)推,元組與列表類(lèi)似,不同之處在于元組的元素不能修改。元組使用小括號(hào),列表使用方括號(hào)
    2021-10-10
  • python基礎(chǔ)之函數(shù)的定義和調(diào)用

    python基礎(chǔ)之函數(shù)的定義和調(diào)用

    這篇文章主要介紹了python函數(shù)的定義和調(diào)用,實(shí)例分析了Python中返回一個(gè)返回值與多個(gè)返回值的方法,需要的朋友可以參考下
    2021-10-10
  • Windows64x下VScode下載過(guò)程

    Windows64x下VScode下載過(guò)程

    這篇文章主要介紹了Windows64x下VScode下載,本文通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-09-09
  • Python 網(wǎng)絡(luò)編程之TCP客戶(hù)端/服務(wù)端功能示例【基于socket套接字】

    Python 網(wǎng)絡(luò)編程之TCP客戶(hù)端/服務(wù)端功能示例【基于socket套接字】

    這篇文章主要介紹了Python 網(wǎng)絡(luò)編程之TCP客戶(hù)端/服務(wù)端功能,結(jié)合實(shí)例形式分析了Python使用socket套接字實(shí)現(xiàn)TCP協(xié)議下的客戶(hù)端與服務(wù)器端數(shù)據(jù)傳輸操作技巧,需要的朋友可以參考下
    2019-10-10
  • 一文帶你玩轉(zhuǎn)python中的requests函數(shù)

    一文帶你玩轉(zhuǎn)python中的requests函數(shù)

    在Python中,requests庫(kù)是用于發(fā)送HTTP請(qǐng)求的常用庫(kù),因?yàn)樗峁┝撕?jiǎn)潔易用的接口,本文將深入探討requests庫(kù)的使用方法,感興趣的可以學(xué)習(xí)下
    2023-08-08
  • Python 的 Socket 編程

    Python 的 Socket 編程

    這篇文章最初發(fā)布的時(shí)候標(biāo)題是“Python的WebSocket編程”,坦白來(lái)說(shuō)有點(diǎn)文不對(duì)題。我們?cè)谶@里打算討論的僅僅是常規(guī)的socket編程。盡管 Web Socket 和常規(guī)sockets有點(diǎn)很相似,但又不是同一個(gè)東西。那我還是希望這篇文章對(duì)你們有點(diǎn)幫助。
    2015-03-03
  • python logging模塊的使用詳解

    python logging模塊的使用詳解

    這篇文章主要介紹了python logging模塊的使用,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2020-10-10
  • python 虛擬環(huán)境的創(chuàng)建與使用方法

    python 虛擬環(huán)境的創(chuàng)建與使用方法

    本文先介紹虛擬環(huán)境的基礎(chǔ)知識(shí)以及使用方法,然后再深入介紹虛擬環(huán)境背后的工作原理,需要的朋友可以參考下
    2021-06-06

最新評(píng)論