python深度學(xué)習(xí)之多標(biāo)簽分類(lèi)器及pytorch實(shí)現(xiàn)源碼
多標(biāo)簽分類(lèi)器
多標(biāo)簽分類(lèi)任務(wù)與多分類(lèi)任務(wù)有所不同,多分類(lèi)任務(wù)是將一個(gè)實(shí)例分到某個(gè)類(lèi)別中,多標(biāo)簽分類(lèi)任務(wù)是將某個(gè)實(shí)例分到多個(gè)類(lèi)別中。多標(biāo)簽分類(lèi)任務(wù)有有兩大特點(diǎn):
- 類(lèi)標(biāo)數(shù)量不確定,有些樣本可能只有一個(gè)類(lèi)標(biāo),有些樣本的類(lèi)標(biāo)可能高達(dá)幾十甚至上百個(gè)
- 類(lèi)標(biāo)之間相互依賴,例如包含藍(lán)天類(lèi)標(biāo)的樣本很大概率上包含白云
如下圖所示,即為一個(gè)多標(biāo)簽分類(lèi)學(xué)習(xí)的一個(gè)例子,一張圖片里有多個(gè)類(lèi)別,房子,樹(shù),云等,深度學(xué)習(xí)模型需要將其一一分類(lèi)識(shí)別出來(lái)。
多標(biāo)簽分類(lèi)器損失函數(shù)
代碼實(shí)現(xiàn)
針對(duì)圖像的多標(biāo)簽分類(lèi)器pytorch的簡(jiǎn)化代碼實(shí)現(xiàn)如下所示。因?yàn)閳D像的多標(biāo)簽分類(lèi)器的數(shù)據(jù)集比較難獲取,所以可以通過(guò)對(duì)mnist數(shù)據(jù)集中的每個(gè)圖片打上特定的多標(biāo)簽,例如類(lèi)別1的多標(biāo)簽可以為[1,1,0,1,0,1,0,0,1],然后再利用重新打標(biāo)后的數(shù)據(jù)集訓(xùn)練出一個(gè)mnist的多標(biāo)簽分類(lèi)器。
from torchvision import datasets, transforms from torch.utils.data import DataLoader, Dataset import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import os class CNN(nn.Module): def __init__(self): super().__init__() self.Sq1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16, 28, 28) # output: (16, 28, 28) nn.ReLU(), nn.MaxPool2d(kernel_size=2), # (16, 14, 14) ) self.Sq2 = nn.Sequential( nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), # (32, 14, 14) nn.ReLU(), nn.MaxPool2d(2), # (32, 7, 7) ) self.out = nn.Linear(32 * 7 * 7, 100) def forward(self, x): x = self.Sq1(x) x = self.Sq2(x) x = x.view(x.size(0), -1) x = self.out(x) ## Sigmoid activation output = F.sigmoid(x) # 1/(1+e**(-x)) return output def loss_fn(pred, target): return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum() def multilabel_generate(label): Y1 = F.one_hot(label, num_classes = 100) Y2 = F.one_hot(label+10, num_classes = 100) Y3 = F.one_hot(label+50, num_classes = 100) multilabel = Y1+Y2+Y3 return multilabel # def multilabel_generate(label): # multilabel_dict = {} # multi_list = [] # for i in range(label.shape[0]): # multi_list.append(multilabel_dict[label[i].item()]) # multilabel_tensor = torch.tensor(multi_list) # return multilabel def train(): epoches = 10 mnist_net = CNN() mnist_net.train() opitimizer = optim.SGD(mnist_net.parameters(), lr=0.002) mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor()) train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= 128, shuffle=True) for epoch in range(epoches): loss = 0 for batch_X, batch_Y in train_loader: opitimizer.zero_grad() outputs = mnist_net(batch_X) loss = loss_fn(outputs, multilabel_generate(batch_Y)) / batch_X.shape[0] loss.backward() opitimizer.step() print(loss) if __name__ == '__main__': train()
以上就是python深度學(xué)習(xí)之多標(biāo)簽分類(lèi)器及pytorch源碼的詳細(xì)內(nèi)容,更多關(guān)于多標(biāo)簽分類(lèi)器pytorch源碼的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Pandas根據(jù)條件實(shí)現(xiàn)替換列中的值
在使用Pandas的Python中,DataFrame列中的值可以通過(guò)使用各種內(nèi)置函數(shù)根據(jù)條件進(jìn)行替換,本文主要來(lái)和大家討論在Pandas中用條件替換數(shù)據(jù)集列中的值的各種方法,希望對(duì)大家有所幫助2024-01-01python 經(jīng)緯度求兩點(diǎn)距離、三點(diǎn)面積操作
這篇文章主要介紹了python 經(jīng)緯度求兩點(diǎn)距離、三點(diǎn)面積操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-06-06Python字典生成式、集合生成式、生成器用法實(shí)例分析
這篇文章主要介紹了Python字典生成式、集合生成式、生成器用法,結(jié)合實(shí)例形式分析了Python字典生成式、集合生成式、生成器相關(guān)原理、使用技巧與操作注意事項(xiàng),需要的朋友可以參考下2020-01-01python使用OS模塊操作系統(tǒng)接口及常用功能詳解
os是?Python?標(biāo)準(zhǔn)庫(kù)中的一個(gè)模塊,提供了與操作系統(tǒng)交互的功能,在本節(jié)中,我們將介紹os模塊的一些常用功能,并通過(guò)實(shí)例代碼詳細(xì)講解每個(gè)知識(shí)點(diǎn)2023-06-06Python實(shí)現(xiàn)破解猜數(shù)游戲算法示例
這篇文章主要介紹了Python實(shí)現(xiàn)破解猜數(shù)游戲算法,簡(jiǎn)單描述了猜數(shù)游戲的原理,并結(jié)合具體實(shí)例形式分析了Python破解猜數(shù)游戲的相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2017-09-09Python爬蟲(chóng)入門(mén)案例之爬取去哪兒旅游景點(diǎn)攻略以及可視化分析
讀萬(wàn)卷書(shū)不如行萬(wàn)里路,學(xué)的扎不扎實(shí)要通過(guò)實(shí)戰(zhàn)才能看出來(lái),本篇文章手把手帶你爬取去哪兒平臺(tái)的旅游景點(diǎn)攻略并進(jìn)行可視化分析,大家可以在過(guò)程中查缺補(bǔ)漏,看看自己掌握程度怎么樣2021-10-10Django-Rest-Framework 權(quán)限管理源碼淺析(小結(jié))
這篇文章主要介紹了Django-Rest-Framework 權(quán)限管理源碼淺析(小結(jié)),小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-11-11