Pytorch損失函數(shù)nn.NLLLoss2d()用法說明
最近做顯著星檢測用到了NLL損失函數(shù)
對于NLL函數(shù),需要自己計算log和softmax的概率值,然后從才能作為輸入
輸入 [batch_size, channel , h, w]
目標(biāo) [batch_size, h, w]
輸入的目標(biāo)矩陣,每個像素必須是類型.舉個例子。第一個像素是0,代表著類別屬于輸入的第1個通道;第二個像素是0,代表著類別屬于輸入的第0個通道,以此類推。
x = Variable(torch.Tensor([[[1, 2, 1], [2, 2, 1], [0, 1, 1]], [[0, 1, 3], [2, 3, 1], [0, 0, 1]]])) x = x.view([1, 2, 3, 3]) print("x輸入", x)
這里輸入x,并改成[batch_size, channel , h, w]的格式。
soft = nn.Softmax(dim=1)
log_soft = nn.LogSoftmax(dim=1)
然后使用softmax函數(shù)計算每個類別的概率,這里dim=1表示從在1維度
上計算,也就是channel維度。logsoftmax是計算完softmax后在計算log值
手動計算舉個栗子:第一個元素
y = Variable(torch.LongTensor([[1, 0, 1], [0, 0, 1], [1, 1, 1]])) y = y.view([1, 3, 3])
輸入label y,改變成[batch_size, h, w]格式
loss = nn.NLLLoss2d() out = loss(x, y) print(out)
輸入函數(shù),得到loss=0.7947
來手動計算
第一個label=1,則 loss=-1.3133
第二個label=0, 則loss=-0.3133
. … … loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223
是一致的
注意:這個函數(shù)會對每個像素做平均,每個batch也會做平均,這里有9個像素,1個batch_size。
補充知識:PyTorch:NLLLoss2d
我就廢話不多說了,大家還是直接看代碼吧~
import torch import torch.nn as nn from torch import autograd import torch.nn.functional as F inputs_tensor = torch.FloatTensor([ [[2, 4], [1, 2]], [[5, 3], [3, 0]], [[5, 3], [5, 2]], [[4, 2], [3, 2]], ]) inputs_tensor = torch.unsqueeze(inputs_tensor,0) # inputs_tensor = torch.unsqueeze(inputs_tensor,1) print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape targets_tensor = torch.LongTensor([ [0, 2], [2, 3] ]) targets_tensor = torch.unsqueeze(targets_tensor,0) print '--target size(nBatch x height x width): ', targets_tensor.shape inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True) inputs_variable = F.log_softmax(inputs_variable) targets_variable = autograd.Variable(targets_tensor) loss = nn.NLLLoss2d() output = loss(inputs_variable, targets_variable) print '--NLLLoss2d: {}'.format(output)
以上這篇Pytorch損失函數(shù)nn.NLLLoss2d()用法說明就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python列表切片操作實例探究(提取復(fù)制反轉(zhuǎn))
在Python中,列表切片是處理列表數(shù)據(jù)非常強大且靈活的方法,本文將全面探討Python中列表切片的多種用法,包括提取子列表、復(fù)制列表、反轉(zhuǎn)列表等操作,結(jié)合豐富的示例代碼進(jìn)行詳細(xì)講解2024-01-01Python中表達(dá)式x += y和x = x+y 的區(qū)別詳解
這篇文章主要跟大家介紹了關(guān)于Python中x += y和x = x+y 的區(qū)別的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家具有一定的參考學(xué)習(xí)價值,需要的朋友們下面來一起看看吧。2017-06-06Python參數(shù)解析器configparser簡介
configparser是python自帶的配置參數(shù)解析器,可以用于解析.config文件中的配置參數(shù),ini文件中由sections(節(jié)點)-key-value組成,這篇文章主要介紹了Python參數(shù)解析器configparser,需要的朋友可以參考下2022-12-12pandas series序列轉(zhuǎn)化為星期幾的實例
下面小編就為大家分享一篇pandas series序列轉(zhuǎn)化為星期幾的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04Python數(shù)學(xué)建模StatsModels統(tǒng)計回歸之線性回歸示例詳解
這篇文章主要為大家介紹了Python數(shù)學(xué)建模中StatsModels統(tǒng)計回歸之線性回歸的示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助2021-10-10Python學(xué)習(xí)之模塊化程序設(shè)計示例詳解
程序設(shè)計的模塊化指的是在進(jìn)行程序設(shè)計時,把一個大的程序功能劃分為若干個小的程序模塊。每一個小程序模塊實現(xiàn)一個確定的功能,并且在這些小程序模塊實現(xiàn)的功能之間建立必要的聯(lián)系。本文將利用示例詳細(xì)介紹一下Python的模塊化程序設(shè)計,需要的可以參考一下2022-03-03將pandas.dataframe的數(shù)據(jù)寫入到文件中的方法
今天小編就為大家分享一篇將pandas.dataframe的數(shù)據(jù)寫入到文件中的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12對pandas replace函數(shù)的使用方法小結(jié)
今天小編就為大家分享一篇對pandas replace函數(shù)的使用方法小結(jié),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05