欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch自定義CNN網(wǎng)絡(luò)實(shí)現(xiàn)貓狗分類(lèi)詳解過(guò)程

 更新時(shí)間:2022年12月08日 15:04:30   作者:專(zhuān)業(yè)女神殺手  
PyTorch是一個(gè)開(kāi)源的Python機(jī)器學(xué)習(xí)庫(kù),基于Torch,用于自然語(yǔ)言處理等應(yīng)用程序。它不僅能夠?qū)崿F(xiàn)強(qiáng)大的GPU加速,同時(shí)還支持動(dòng)態(tài)神經(jīng)網(wǎng)絡(luò)。本文將介紹PyTorch自定義CNN網(wǎng)絡(luò)實(shí)現(xiàn)貓狗分類(lèi),感興趣的可以學(xué)習(xí)一下

前言

數(shù)據(jù)集下載地址:

鏈接: https://pan.baidu.com/s/17aglKyKFvMvcug0xrOqJdQ?pwd=6i7m 

Dogs vs. Cats(貓狗大戰(zhàn))來(lái)源Kaggle上的一個(gè)競(jìng)賽題,任務(wù)為給定一個(gè)數(shù)據(jù)集,設(shè)計(jì)一種算法中的貓狗圖片進(jìn)行判別。

數(shù)據(jù)集包括25000張帶標(biāo)簽的訓(xùn)練集圖片,貓和狗各125000張,標(biāo)簽都是以cat or dog命名的。圖像為RGB格式j(luò)pg圖片,size不一樣。截圖如下:

一. 數(shù)據(jù)預(yù)處理

pytorch的數(shù)據(jù)預(yù)處理部分要寫(xiě)成一個(gè)類(lèi),這個(gè)類(lèi)繼承Dataset類(lèi),并必須要實(shí)現(xiàn)三個(gè)函數(shù)。

from torch.utils.data import DataLoader,Dataset
from torchvision import transforms as T
import matplotlib.pyplot as plt
import os
from PIL import Image
class DogCat(Dataset):
    def __init__(self, root, transforms=None, train=True):
        imgs = [os.path.join(root,img) for img in os.listdir(root)]
        imgs_num = len(imgs)
        if train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.3 * imgs_num):]
        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
            ])
        else:
            self.transforms = transforms
    def __getitem__(self, index):
        img_path = self.imgs[index]
        # dog label : 1           cat label : 0
        label = 1 if "dog" in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data,label
    def __len__(self):
        return len(self.imgs)

__init__為構(gòu)造函數(shù),我這里用力定義數(shù)據(jù)路徑,數(shù)據(jù)集劃分,transforms。

__getitem__為迭代函數(shù),用來(lái)return單個(gè)數(shù)據(jù)的data和label。

__len__返回?cái)?shù)據(jù)集的長(zhǎng)度。

二. 定義網(wǎng)絡(luò)

在這個(gè)例子中,我們用一個(gè)簡(jiǎn)單的4層卷積,2層全連接,最后跟一個(gè)sigmoid輸出二分類(lèi)的概率的CNN網(wǎng)絡(luò)。

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.conv4 = nn.Conv2d(128, 128, 3)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # 12*12 for size(224,224)    7*7 for size(150,150)
        self.fc1 = nn.Linear(128*12*12, 512)
        self.fc2 = nn.Linear(512, 1)
    def forward(self, x):
        in_size = x.size(0)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv4(x)
        x = self.relu(x)
        x = self.max_pool(x)
        # 展開(kāi)
        x = x.view(in_size, -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

pytorch定義網(wǎng)絡(luò)時(shí),必須實(shí)現(xiàn)兩個(gè)函數(shù),構(gòu)造函數(shù)主要定義一些網(wǎng)絡(luò)塊,forward函數(shù)實(shí)現(xiàn)前向推理過(guò)程。且在后續(xù)代碼中,如果定義對(duì)象model: ConvNet和數(shù)據(jù)image,可以直接通過(guò)model(image)來(lái)調(diào)用froward函數(shù)(python真的很神奇,C++出身的我理解這些騷操作好難)

三. 訓(xùn)練模型

數(shù)據(jù)準(zhǔn)備好了,模型網(wǎng)絡(luò)定義好了,下一步當(dāng)然是訓(xùn)練權(quán)重了。

import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from dataset import DogCat
from network import ConvNet
from draw import draw_acc,draw_loss
train_data_root = "/home/elvis/workfile/dataset/dataset_kaggledogvscat/train"
batch_size = 256
# 1. prepare dataset
train_data = DogCat(train_data_root, train=True)
val_data = DogCat(train_data_root, train=False)
train_dataloader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
val_dataloader = DataLoader(val_data,batch_size=batch_size,shuffle=True)
# 2. load model
model = ConvNet()
if torch.cuda.is_available():
    model.cuda()
# 3. prepare super parameters
criterion = nn.BCELoss()
learning_rate = 1e-3
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 4. train
train_loss_epoch = []
train_acc_epoch = []
val_loss_epoch = []
val_acc_epoch = []
for epoch in range(1, 10):
    model.train()
    train_loss = 0;
    train_acc = 0;
    for batch_idx, (data, target) in enumerate(train_dataloader):
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda().float().unsqueeze(-1)
        else:
            data, target = data, target.float().unsqueeze(-1)
        optimizer.zero_grad()
        output = model(data)
        # print(output)
        loss = criterion(output, target)
        train_loss += loss.item();
        pred = torch.tensor([[1] if num[0] >= 0.5 else [0] for num in output]).cuda();
        train_acc += pred.eq(target.long()).sum().item();
        loss.backward()
        optimizer.step()
        if(batch_idx+1)%10 == 0: 
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx+1) * len(data), len(train_dataloader.dataset),
                100. * (batch_idx+1) / len(train_dataloader), loss.item()))
    train_loss_epoch.append(train_loss / len(train_dataloader));
    train_acc_epoch.append(train_acc / len(train_dataloader.dataset));
    print('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(train_loss / len(train_dataloader), train_acc, len(train_dataloader.dataset),
                                                                                    100. * train_acc / len(train_dataloader.dataset)));
    # val
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_dataloader):
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda().float().unsqueeze(-1)
            else:
                data, target = data, target.float().unsqueeze(-1)
            output = model(data)
            # print(output)
            test_loss += criterion(output, target).item(); #每個(gè)批次平均,一個(gè)epoch里所有批次求和
            pred = torch.tensor([[1] if num[0] >= 0.5 else [0] for num in output]).cuda()
            correct += pred.eq(target.long()).sum().item()
    print('Valid set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss/len(val_dataloader), correct, len(val_dataloader.dataset),
                                                                                    100. * correct / len(val_dataloader.dataset)));
    val_loss_epoch.append(test_loss / len(val_dataloader));
    val_acc_epoch.append(correct / len(val_dataloader.dataset));
    # Save model
    val_acc_rate = correct / len(val_dataloader.dataset);
    save = True
    best = "best.pt"
    last = "last.pt"
    if save:
        # Save last, best and delete
        torch.save(model.state_dict(), last)
        if val_acc_rate == max(val_acc_epoch):
            torch.save(model.state_dict(), best)
            print("save epoch {} model".format(epoch))
# 5. drawing
draw_loss(train_loss_epoch, val_loss_epoch)
draw_acc(train_acc_epoch,val_acc_epoch)

第一步,準(zhǔn)備數(shù)據(jù)。先用我們之前定義的DogCat類(lèi)來(lái)加載數(shù)據(jù),但這個(gè)類(lèi)繼承自dataset,是加載一條數(shù)據(jù)的。如果要批量加載數(shù)據(jù),還要用pytorch內(nèi)部的另一個(gè)類(lèi)DataLoader,然后在構(gòu)造函數(shù)里傳入batchsize就可以批量加載數(shù)據(jù)了。注意這里的類(lèi)對(duì)象實(shí)際是一個(gè)生成器,后續(xù)通過(guò)循環(huán)就可以一直批量的去取數(shù)據(jù)了。

第二步,定義模型對(duì)象,有用顯卡就把模型放在顯卡上,沒(méi)有的話就用cpu跑。

第三步,定義一些超參數(shù)。因?yàn)槭嵌诸?lèi),網(wǎng)絡(luò)最后一層為sigmoid輸出類(lèi)別的概率值,所以選用二分類(lèi)交叉熵?fù)p失函數(shù)。再設(shè)置一下學(xué)習(xí)率和優(yōu)化器。

第四步,訓(xùn)練n個(gè)epoch。在每一個(gè)epoch里計(jì)算訓(xùn)練集準(zhǔn)去率,驗(yàn)證集準(zhǔn)確率,并保存模型。

最后結(jié)果像這樣

有條件的可以多訓(xùn)練幾個(gè)epoch試試。

到此這篇關(guān)于Pytorch自定義CNN網(wǎng)絡(luò)實(shí)現(xiàn)貓狗分類(lèi)詳解過(guò)程的文章就介紹到這了,更多相關(guān)Pytorch貓狗分類(lèi)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • django 實(shí)現(xiàn)編寫(xiě)控制登錄和訪問(wèn)權(quán)限控制的中間件方法

    django 實(shí)現(xiàn)編寫(xiě)控制登錄和訪問(wèn)權(quán)限控制的中間件方法

    今天小編就為大家分享一篇django 實(shí)現(xiàn)編寫(xiě)控制登錄和訪問(wèn)權(quán)限控制的中間件方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-01-01
  • python爬蟲(chóng)模擬瀏覽器訪問(wèn)-User-Agent過(guò)程解析

    python爬蟲(chóng)模擬瀏覽器訪問(wèn)-User-Agent過(guò)程解析

    這篇文章主要介紹了python爬蟲(chóng)模擬瀏覽器訪問(wèn)-User-Agent過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-12-12
  • python實(shí)現(xiàn)雙鏈表

    python實(shí)現(xiàn)雙鏈表

    這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)雙鏈表,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2022-05-05
  • python連接數(shù)據(jù)庫(kù)后通過(guò)占位符添加數(shù)據(jù)

    python連接數(shù)據(jù)庫(kù)后通過(guò)占位符添加數(shù)據(jù)

    在pymysql中支持對(duì)占位符的處理,開(kāi)發(fā)者需要在SQL中使用“%”定義占位符,在使用excute()方法執(zhí)行時(shí)對(duì)占位符的數(shù)據(jù)進(jìn)行填充即可,本文給大家介紹python連接數(shù)據(jù)庫(kù)后通過(guò)占位符添加數(shù)據(jù)的方法,需要的朋友參考下吧
    2021-12-12
  • Python中的Socket 與 ScoketServer 通信及遇到問(wèn)題解決方法

    Python中的Socket 與 ScoketServer 通信及遇到問(wèn)題解決方法

    Socket有一個(gè)緩沖區(qū),緩沖區(qū)是一個(gè)流,先進(jìn)先出,發(fā)送和取出的可自定義大小的,如果取出的數(shù)據(jù)未取完緩沖區(qū),則可能存在數(shù)據(jù)怠慢。本文通過(guò)實(shí)例代碼給大家介紹Python中的Socket 與 ScoketServer 通信及遇到問(wèn)題解決方法 ,需要的朋友參考下吧
    2019-04-04
  • python判斷給定的字符串是否是有效日期的方法

    python判斷給定的字符串是否是有效日期的方法

    這篇文章主要介紹了python判斷給定的字符串是否是有效日期的方法,涉及Python針對(duì)字符串與日期操作的相關(guān)技巧,需要的朋友可以參考下
    2015-05-05
  • Python解析json之ValueError: Expecting property name enclosed in double quotes: line 1 column 2(char 1)

    Python解析json之ValueError: Expecting property name enclosed in

    這篇文章主要給大家介紹了關(guān)于Python解析json報(bào)錯(cuò):ValueError: Expecting property name enclosed in double quotes: line 1 column 2(char 1)的解決方法,文中介紹的非常詳細(xì),需要的朋友們可以參考借鑒,下面來(lái)一起看看吧。
    2017-07-07
  • python實(shí)現(xiàn)的jpg格式圖片修復(fù)代碼

    python實(shí)現(xiàn)的jpg格式圖片修復(fù)代碼

    這篇文章主要介紹了python實(shí)現(xiàn)的jpg格式圖片修復(fù)代碼,本文直接給出實(shí)現(xiàn)代碼,需要的朋友可以參考下
    2015-04-04
  • Python技巧之四種多線程應(yīng)用分享

    Python技巧之四種多線程應(yīng)用分享

    這篇文章主要介紹了Python中多線程的所有方式,包括使用threading模塊、使用concurrent.futures模塊、使用multiprocessing模塊以及使用asyncio模塊,希望對(duì)大家有所幫助
    2023-05-05
  • Python 使用 docopt 解析json參數(shù)文件過(guò)程講解

    Python 使用 docopt 解析json參數(shù)文件過(guò)程講解

    這篇文章主要介紹了Python 使用 docopt 解析json參數(shù)文件過(guò)程講解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-08-08

最新評(píng)論