欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中交叉熵?fù)p失函數(shù)的使用小細(xì)節(jié)

 更新時(shí)間:2023年02月02日 09:14:41   作者:Mr_health  
這篇文章主要介紹了pytorch中交叉熵?fù)p失函數(shù)的使用細(xì)節(jié),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

目前pytorch中的交叉熵?fù)p失函數(shù)主要分為以下三類(lèi),我們將其使用的要點(diǎn)以及場(chǎng)景做一下總結(jié)。

類(lèi)型一:F.cross_entropy()與torch.nn.CrossEntropyLoss()

  • 輸入:非onehot label + logit。函數(shù)會(huì)自動(dòng)將logit通過(guò)softmax映射為概率。
  • 使用場(chǎng)景:都是應(yīng)用于互斥的分類(lèi)任務(wù),如典型的二分類(lèi)以及互斥的多分類(lèi)。
  • 網(wǎng)絡(luò):分類(lèi)個(gè)數(shù)即為網(wǎng)絡(luò)的輸出節(jié)點(diǎn)數(shù)

類(lèi)型二:F.binary_cross_entropy_with_logits()與torch.nn.BCEWithLogitsLoss()

  • 輸入:logit。函數(shù)會(huì)自動(dòng)將logit通過(guò)sidmoid映射為概率。
  • 使用場(chǎng)景:① 二分類(lèi) ② 非互斥多分類(lèi)
  • 網(wǎng)絡(luò):使用這類(lèi)損失函數(shù)需要將網(wǎng)絡(luò)輸出的每一個(gè)節(jié)點(diǎn)當(dāng)作一個(gè)二分類(lèi)的節(jié)點(diǎn)                  

①當(dāng)為標(biāo)準(zhǔn)的二分類(lèi)時(shí),網(wǎng)絡(luò)的輸出節(jié)點(diǎn)為1

②當(dāng)為非互斥的多分類(lèi)時(shí),分類(lèi)個(gè)數(shù)即為網(wǎng)絡(luò)的輸出節(jié)點(diǎn)數(shù)

類(lèi)型三:F.binary_cross_entropy()與torch.nn.BCELoss()

  • 輸入:prob(概率)。這個(gè)概率可以由softmax計(jì)算而來(lái),也可以由sigmoid計(jì)算而來(lái)。兩種不同的概率映射方式對(duì)應(yīng)不同的分類(lèi)任務(wù)。
  • 使用場(chǎng)景:① 二分類(lèi) ② 非互斥多分類(lèi)
  • 網(wǎng)絡(luò):①標(biāo)準(zhǔn)的二分類(lèi)任務(wù):網(wǎng)絡(luò)的輸出節(jié)點(diǎn)可以為1,此時(shí)概率必須由sigmoid進(jìn)行映射;                      

網(wǎng)絡(luò)的輸出節(jié)點(diǎn)可以為2,此時(shí)概率必須由softmax進(jìn)行映射。

②當(dāng)為非互斥的多分類(lèi)時(shí),分類(lèi)個(gè)數(shù)即為網(wǎng)絡(luò)的輸出節(jié)點(diǎn)數(shù),此時(shí)概率必須由sigmoid進(jìn)行映射

1.二分類(lèi)

類(lèi)型一:F.cross_entropy()與torch.nn.CrossEntropyLoss()

  • 網(wǎng)絡(luò)的輸出節(jié)點(diǎn)為2,表示real和fake(類(lèi)別1和類(lèi)別2)

類(lèi)型二:F.binary_cross_entropy_with_logits()與torch.nn.BCEWithLogitsLoss()

  • 由于這兩個(gè)函數(shù)自帶sigmoid函數(shù),要想完成二分類(lèi),網(wǎng)絡(luò)的輸出節(jié)點(diǎn)個(gè)數(shù)必須設(shè)置為1

類(lèi)型三:F.binary_cross_entropy()與torch.nn.BCELoss(),以下兩種情況都可以使用:

  • 當(dāng)網(wǎng)絡(luò)輸出的節(jié)點(diǎn)為2時(shí),一個(gè)節(jié)點(diǎn)為real另一個(gè)節(jié)點(diǎn)為fake,那么必然要采用softmax將logits映射為概率(兩個(gè)節(jié)點(diǎn)的概率和為1),此時(shí)該函數(shù)輸入為onehot label + softmax prob,計(jì)算出的交叉熵?fù)p失與類(lèi)型一結(jié)算結(jié)果相同。
  • 當(dāng)網(wǎng)絡(luò)的輸出節(jié)點(diǎn)為1時(shí),也就是后面我們要講的GAN的交叉熵?fù)p失的實(shí)現(xiàn),那么則需要使用sigmoid函數(shù)來(lái)進(jìn)行映射。

這里我們以網(wǎng)絡(luò)輸出節(jié)點(diǎn)為2為例,由于類(lèi)型二要求網(wǎng)絡(luò)的輸出節(jié)點(diǎn)為1,因此暫時(shí)不納入討論,主要討論類(lèi)型和類(lèi)型三。

測(cè)試代碼如下:

(網(wǎng)絡(luò)輸出節(jié)點(diǎn)為1的二分類(lèi)就是目前GAN的實(shí)現(xiàn)方式,該方式下類(lèi)型一的函數(shù)不可用,只能采用類(lèi)型二和類(lèi)型三,后面將會(huì)詳細(xì)討論)

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1],
? ? ? ? ? ? ? ? ? ? [-1.587, ?-0.5907]])
classes = 2
label = torch.tensor([1, 1])
logits = torch.from_numpy(logits).float()
?
#F.cross_entropy
loss1 = F.cross_entropy(logits, label) ?
print(loss1)
?
#nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
?
#可以看到,loss1是等于loss2的
?
prob = softmax(logits) ?#計(jì)算概率
one_hot_label = one_hot(label, classes)
?
#F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob, one_hot_label) #輸入概率和one-hot
print(loss3)
?
#torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob, one_hot_label)
print(loss4)
?
#同理,loss3是等于loss4的
?
#手動(dòng)實(shí)現(xiàn)二分類(lèi)的交叉熵?fù)p失
shixian = -torch.mean(torch.sum(one_hot_label * torch.log(prob), axis = 1)) ?#手動(dòng)實(shí)現(xiàn)
print(shixian)

2.多分類(lèi)

此時(shí)網(wǎng)絡(luò)輸出時(shí)多節(jié)點(diǎn),每一個(gè)節(jié)點(diǎn)代表一個(gè)類(lèi)別。

類(lèi)型一:F.cross_entropy()與torch.nn.CrossEntropyLoss()

  • 可以用于多分類(lèi)的互斥任務(wù),輸入非onehot label + logit。但是不能用于多分類(lèi)多標(biāo)簽任務(wù)。因?yàn)檫@兩個(gè)函數(shù)中自帶的softmax將網(wǎng)絡(luò)的每一個(gè)節(jié)點(diǎn)都當(dāng)作時(shí)互斥的獨(dú)立節(jié)點(diǎn),每個(gè)節(jié)點(diǎn)的概率和為1,因?yàn)楦怕首畲蟮哪莻€(gè)節(jié)點(diǎn)的類(lèi)別會(huì)被當(dāng)為最終的預(yù)測(cè)類(lèi)別

類(lèi)型二:F.binary_cross_entropy_with_logits()與torch.nn.BCEWithLogitsLoss()

  • 不能用于多分類(lèi)的互斥任務(wù),只能用于多分類(lèi)的非互斥任務(wù)

類(lèi)型三:F.binary_cross_entropy()與torch.nn.BCELoss()

  • 與類(lèi)型二一樣,不能用于多分類(lèi)的互斥任務(wù),只能用于多分類(lèi)的非互斥任務(wù)。

這里我們首先討論下類(lèi)型一和類(lèi)型三,為什么類(lèi)型三不能用于多分類(lèi)的互斥任務(wù),只能用于多分類(lèi)多標(biāo)簽的分類(lèi)任務(wù)?我們來(lái)看一段代碼,這里有三個(gè)類(lèi)別,兩個(gè)樣本。

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1, 0.2],
? ? ? ? ? ? ? ? ? ? [-1.587, ?-0.5907, 0.3]])
classes = 3
label = torch.tensor([1, 2])
logits = torch.from_numpy(logits).float()
?
### F.cross_entropy
loss1 = F.cross_entropy(logits, label) ?
print(loss1)
?
### nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
##loss1 = loss2

上面是采用類(lèi)型一的兩個(gè)函數(shù)計(jì)算而來(lái),loss1 = loss2 = 0.9833

然后我們用類(lèi)型三的函數(shù)來(lái)實(shí)現(xiàn),同樣將logit通過(guò)softmax映射為概率,運(yùn)行后的結(jié)果可以看loss3 =loss4 = 0.5649,不等于類(lèi)型一的函數(shù)的結(jié)果的。

prob_softmax = softmax(logits) ?#計(jì)算概率
one_hot_label = one_hot(label, classes)
?
## F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob_softmax, one_hot_label) #輸入概率和one-hot
print(loss3)
?
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob_softmax, one_hot_label)
print(loss4)

最后我們?cè)偈謩?dòng)實(shí)現(xiàn)類(lèi)型三的損失究竟是怎么得到的:

#手動(dòng)實(shí)現(xiàn)
shixian = -torch.mean(one_hot_label * torch.log(prob_softmax) + (1-one_hot_label) * torch.log(1-prob_softmax))
print(shixian)

可以看出來(lái),F(xiàn).binary_cross_entropy()與torch.nn.BCELoss()是將網(wǎng)絡(luò)的每個(gè)節(jié)點(diǎn)看作是一個(gè)二分類(lèi)的節(jié)點(diǎn)來(lái)計(jì)算交叉熵?fù)p失的。

進(jìn)一步來(lái)討論下類(lèi)型二和類(lèi)型三的一致性,代碼如下。由于類(lèi)型二中函數(shù)自動(dòng)將logit通過(guò)sigloid函數(shù)映射為概率,為了檢驗(yàn)一致性性,我門(mén)也需要通過(guò)sigmoid計(jì)算類(lèi)型三所需要的概率。

最后可以看到下面的輸出均為0.6378

sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits) ?#計(jì)算概率
?
##類(lèi)型二
##F.binary_cross_entropy_with_logits
loss5 = F.binary_cross_entropy_with_logits(logits, one_hot_label)
print(loss5)
?
##torch.nn.BCEWithLogitsLoss()
BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
loss6 = BCEWithLogitsLoss(logits, one_hot_label)
print(loss6)
?
##類(lèi)型三
##F.binary_cross_entropy
loss7 = F.binary_cross_entropy(prob_sig, one_hot_label) #輸入概率和one-hot
print(loss7)
?
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss8 = adversarial_loss(prob_sig, one_hot_label)
print(loss8)
?
#手動(dòng)實(shí)現(xiàn)
shixian = -torch.mean(one_hot_label * torch.log(prob_sig) + (1-one_hot_label) * torch.log(1-prob_sig))
print(shixian)

3. GAN中的實(shí)現(xiàn):二分類(lèi)

GAN中的判別器出的損失就是典型的最小化二分類(lèi)的交叉熵?fù)p失。但是在實(shí)現(xiàn)上,與二分類(lèi)網(wǎng)絡(luò)不同。

  • 一般的二分類(lèi)網(wǎng)絡(luò),輸出有兩個(gè)節(jié)點(diǎn),分別表示real和fake的logit(或者概率)。
  • GAN的判別器,輸出只有一個(gè)節(jié)點(diǎn),表示的是樣本屬于real的logit(或者概率)。

正因?yàn)榕袆e器的輸出是一維,類(lèi)型一的兩個(gè)函數(shù)F.cross_entropy()與torch.nn.CrossEntropyLoss()是沒(méi)有辦法使用的,因?yàn)檫@兩個(gè)函數(shù)要求輸入是二維的,即分別在real和fake的logit。因此只能采用類(lèi)型二或者類(lèi)型三的函數(shù)。

很多GAN網(wǎng)絡(luò)采用的二分類(lèi)交叉熵?fù)p失函數(shù)如下:

#類(lèi)型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss(logit,y)
#類(lèi)型三:
adversarial_loss_3 = torch.nn.BCELoss(p,y)

前面我們講到,類(lèi)型二和類(lèi)型三的函數(shù)都是將每一個(gè)節(jié)點(diǎn)視為一個(gè)二分類(lèi)的節(jié)點(diǎn),因此對(duì)于每一個(gè)給節(jié)點(diǎn),其具體的表達(dá)式可以寫(xiě)為:

#類(lèi)型二:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
# 其中l(wèi)ogit表示判斷為real的logit
# y=1表示real
# y=0表示fake
?
#類(lèi)型三:
torch.nn.BCELoss(p, y) = - (ylog(p) + (1-y)log(1-p))
# 其中p表示判斷為real的概率
# y=1表示real
# y=0表示fake

3.1 判別器損失計(jì)算

判別器輸出維度為1,輸出logit,有兩個(gè)樣本,都為fake圖像

logits = np.array([1.2, -0.5])
logits = torch.from_numpy(logits).float()
sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits) ?#計(jì)算概率
?
label = torch.tensor([1, 1]).float()
?
#類(lèi)型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss()
loss_2 = adversarial_loss_2(logits, 1-label) ?#因?yàn)槭莊ake,需要將y設(shè)置為0
print(loss_2)
?
#類(lèi)型三:
adversarial_loss_3 = torch.nn.BCELoss()
loss_3 = adversarial_loss_3(prob_sig, 1-label) #因?yàn)槭莊ake,需要將y設(shè)置為0
print(loss_3)
#輸出均為0.9687

 通過(guò)上述代碼可以分析如下:

(1)當(dāng)樣本為fake時(shí),網(wǎng)絡(luò)輸出其為real的logit:

  • 對(duì)于類(lèi)型二:torch.nn.BCEWithLogitsLoss(logit,0),即直接輸入logit。由于樣本的實(shí)際類(lèi)別為fake,根據(jù)交叉熵?fù)p失公式,要將為y設(shè)置為0,相當(dāng)于告訴函數(shù)我輸入的樣本是fake。
  • 對(duì)于類(lèi)型三:torch.nn.BCELoss(prob, 0),此時(shí)prob等于公式中的p,由于樣本的實(shí)際類(lèi)別為fake,與類(lèi)型二一致,要將為y設(shè)置為0。

(2)樣本為real,網(wǎng)絡(luò)輸出其為real的logit:

  • 對(duì)于類(lèi)型二:torch.nn.BCEWithLogitsLoss(logit,1),即直接輸入logit。由于樣本的實(shí)際類(lèi)別也為real,根據(jù)交叉熵?fù)p失公式,要將為y設(shè)置為1,這樣就計(jì)算了 ylog(sigmoid(logit))
  • 對(duì)于類(lèi)型三:torch.nn.BCELoss(prob, 1),此時(shí)prob等于公式中的p,樣本的實(shí)際類(lèi)別也為real,與類(lèi)型二一致,要將為y設(shè)置為1,這樣就計(jì)算了 ylog(p)

GAN網(wǎng)絡(luò)在更新判別器時(shí),代碼一般如下:

criterion = torch.nn.BCELoss()
real_out = D(real_img) ?# 將真實(shí)圖片放入判別器中
d_loss_real = criterion(real_out, 1) ?# 真實(shí)樣本的損失
?
fake_img = G(z) ?# 隨機(jī)噪聲放入生成網(wǎng)絡(luò)中,生成一張假的圖片
fake_out = D(fake_img) ?# 判別器判斷假的圖片,
d_loss_fake = criterion(fake_out, 0) ?# 生成樣本的損失
?
d_loss = d_loss_real + d_loss_fake ?# ?兩個(gè)相加 就是標(biāo)準(zhǔn)的交叉熵?fù)p失
?
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()

3.2 生成器的損失計(jì)算

前面判別器處的損失是最小化交叉熵?fù)p失:

min - (ylog(p) + (1-y)log(1-p))

那么生成器與之相反就是最大化交叉熵?fù)p失:

max - (ylog(p) + (1-y)log(1-p))

因?yàn)檎鎸?shí)樣本于與生成器無(wú)關(guān),因此可以轉(zhuǎn)變?yōu)閙in log(1-p)

max - ((1-y)log(1-p)) = min (1-y)log(1-p) = min log(1-p)

上述形式為飽和形式,轉(zhuǎn)變?yōu)榉秋柡腿缦隆?/p>

min -log(p)

可以看到上式子在形式上就是將fake圖像當(dāng)作real圖像進(jìn)行優(yōu)化。

可以這么理解:生成器的作用的就是盡可能生成逼近與real的fake,由于判別器判斷的結(jié)果p就是表示圖像為real的概率,那么生成器就希望p越高越好。而在訓(xùn)練判別器時(shí),判別器對(duì)real的優(yōu)化就是讓其p越高越好,即盡可能的區(qū)分real和fake。

因此在更新生成器時(shí),fake處的損失與更新判別器在real處的損失在邏輯上是一致的。

criterion = torch.nn.BCELoss()
fake_img = G(z) ?# 隨機(jī)噪聲放入生成網(wǎng)絡(luò)中,生成一張假的圖片
fake_out = D(fake_img) ?# 判別器判斷假的圖片,
G_loss = criterion(fake_out, 1) ?# 假樣本的損失
?
?
optimizer_G.zero_grad()
G_loss .backward()
optimizer_G.step()

3.3 小結(jié)

在GAN網(wǎng)絡(luò)中,由于輸出網(wǎng)絡(luò)只有一個(gè)節(jié)點(diǎn),表示圖像屬于real的logit或者prob,因此一般使用類(lèi)型二和類(lèi)型三的損失函數(shù)。

兩類(lèi)函數(shù)的實(shí)現(xiàn)如下:

torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
torch.nn.BCELoss(p, y) = - (ylog(prob) + (1-y)log(1-prob))

因?yàn)樯鲜鰧?shí)現(xiàn):

  • 在更新判別器時(shí):real圖像后面label為1,fake圖像后面label為0。分別計(jì)算real和fake的損失相加。
  • 在更新判別器時(shí):與real圖像無(wú)關(guān),fake圖像后面label為1,更新。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Pycharm中安裝Pygal并使用Pygal模擬擲骰子(推薦)

    Pycharm中安裝Pygal并使用Pygal模擬擲骰子(推薦)

    這篇文章主要介紹了Pycharm中安裝Pygal并使用Pygal模擬擲骰子,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-04-04
  • Python實(shí)現(xiàn)數(shù)字圖像處理染色體計(jì)數(shù)示例

    Python實(shí)現(xiàn)數(shù)字圖像處理染色體計(jì)數(shù)示例

    這篇文章主要為大家介紹了Python實(shí)現(xiàn)數(shù)字圖像處理染色體計(jì)數(shù)示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-06-06
  • Python面向?qū)ο蟪绦蛟O(shè)計(jì)之繼承、多態(tài)原理與用法詳解

    Python面向?qū)ο蟪绦蛟O(shè)計(jì)之繼承、多態(tài)原理與用法詳解

    這篇文章主要介紹了Python面向?qū)ο蟪绦蛟O(shè)計(jì)之繼承、多態(tài),結(jié)合實(shí)例形式分析了Python面向?qū)ο蟪绦蛟O(shè)計(jì)中繼承、多態(tài)的相關(guān)概念、原理、用法及操作注意事項(xiàng),需要的朋友可以參考下
    2020-03-03
  • 淺析Python中線(xiàn)程以及線(xiàn)程阻塞

    淺析Python中線(xiàn)程以及線(xiàn)程阻塞

    這篇文章主要為大家簡(jiǎn)單介紹一下Python中線(xiàn)程以及線(xiàn)程阻塞的相關(guān)知識(shí),文中的示例代碼講解詳細(xì),具有一定的學(xué)習(xí)價(jià)值,感興趣的小伙伴可以了解一下
    2023-04-04
  • python IDLE 背景以及字體大小的修改方法

    python IDLE 背景以及字體大小的修改方法

    這篇文章主要介紹了python IDLE 背景以及字體的修改方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-07-07
  • Django中Middleware中的函數(shù)詳解

    Django中Middleware中的函數(shù)詳解

    這篇文章主要介紹了Django中Middleware中的函數(shù)詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-07-07
  • Python-jenkins 獲取job構(gòu)建信息方式

    Python-jenkins 獲取job構(gòu)建信息方式

    這篇文章主要介紹了Python-jenkins 獲取job構(gòu)建信息方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2020-05-05
  • Python中6種中文文本情感分析的方法詳解

    Python中6種中文文本情感分析的方法詳解

    中文文本情感分析是一種將自然語(yǔ)言處理技術(shù)應(yīng)用于文本數(shù)據(jù)的方法,它可以幫助我們了解文本中所表達(dá)的情感傾向,Python中就有多種方法可以進(jìn)行中文文本情感分析,下面就來(lái)和大家簡(jiǎn)單講講
    2023-06-06
  • python實(shí)現(xiàn)層次聚類(lèi)的方法

    python實(shí)現(xiàn)層次聚類(lèi)的方法

    層次聚類(lèi)就是一層一層的進(jìn)行聚類(lèi),可以由上向下把大的類(lèi)別(cluster)分割,叫作分裂法,這篇文章主要介紹了python實(shí)現(xiàn)層次聚類(lèi)的方法,需要的朋友可以參考下
    2021-11-11
  • python中的格式化輸出方法

    python中的格式化輸出方法

    這篇文章主要介紹了python中的格式化輸出方法,?數(shù)據(jù)可以以人類(lèi)可讀的形式打印,或?qū)懭胛募怨?lái)使用,甚至可以以某種其他指定的形式。?用戶(hù)通常希望對(duì)輸出格式進(jìn)行更多控制,而不是簡(jiǎn)單地打印以空格分隔的值,更多格式化輸出方式需要的朋友可以參考下面文章內(nèi)容
    2022-03-03

最新評(píng)論