欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch 如何將CIFAR100數(shù)據(jù)按類標歸類保存

 更新時間:2021年05月10日 09:12:51   作者:Xie_learning  
這篇文章主要介紹了PyTorch 將CIFAR100數(shù)據(jù)按類標歸類保存的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

few-shot learning的采樣

Few-shot learning 基于任務對模型進行訓練,在N-way-K-shot中,一個任務中的meta-training中含有N類,每一類抽取K個樣本構成support set, query set則是在剛才抽取的N類剩余的樣本中sample一定數(shù)量的樣本(可以是均勻采樣,也可以是不均勻采樣)。

對數(shù)據(jù)按類標歸類

針對上述情況,我們需要使用不同類別放置在不同文件夾的數(shù)據(jù)集。但有時,數(shù)據(jù)并沒有按類放置,這時就需要對數(shù)據(jù)進行處理。

下面以CIFAR100為列(不含N-way-k-shot的采樣):

import os
from skimage import io
import torchvision as tv
import numpy as np
import torch
def Cifar100(root):
    character = [[] for i in range(100)]
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)
    dataset = []
    for (X, Y) in zip(train_set.train_data, train_set.train_labels):  # 將train_set的數(shù)據(jù)和label讀入列表
        dataset.append(list((X, Y)))
    for (X, Y) in zip(test_set.test_data, test_set.test_labels):  # 將test_set的數(shù)據(jù)和label讀入列表
        dataset.append(list((X, Y)))
    for X, Y in dataset:
        character[Y].append(X)  # 32*32*3
    character = np.array(character)
    character = torch.from_numpy(character)
    # 按類打亂
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]
    # shape = self.character.shape
    # self.character = self.character.view(shape[0], shape[1], shape[4], shape[2], shape[3])  # 將數(shù)據(jù)轉(zhuǎn)成channel在前
    meta_training, meta_validation, meta_testing = \
    character[:64], character[64:80], character[80:]  # meta_training : meta_validation : Meta_testing = 64類:16類:20類
    dataset = []  # 釋放內(nèi)存
    character = []
    os.mkdir(os.path.join(root, 'meta_training'))
    for i, per_class in enumerate(meta_training):
        character_path = os.path.join(root, 'meta_training', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
    os.mkdir(os.path.join(root, 'meta_validation'))
    for i, per_class in enumerate(meta_validation):
        character_path = os.path.join(root, 'meta_validation', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
    os.mkdir(os.path.join(root, 'meta_testing'))
    for i, per_class in enumerate(meta_testing):
        character_path = os.path.join(root, 'meta_testing', 'character_' + str(i))
        os.mkdir(character_path)
        for j, img in enumerate(per_class):
            img_path = character_path + '/' + str(j) + ".jpg"
            io.imsave(img_path, img)
if __name__ == '__main__':
    root = '/home/xie/文檔/datasets/cifar_100'
    Cifar100(root)
    print("-----------------")

補充:使用Pytorch對數(shù)據(jù)集CIFAR-10進行分類

主要是以下幾個步驟:

1、下載并預處理數(shù)據(jù)集

2、定義網(wǎng)絡結(jié)構

3、定義損失函數(shù)和優(yōu)化器

4、訓練網(wǎng)絡并更新參數(shù)

5、測試網(wǎng)絡效果

#數(shù)據(jù)加載和預處理
#使用CIFAR-10數(shù)據(jù)進行分類實驗
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor轉(zhuǎn)成Image,方便可視化
 
#定義對數(shù)據(jù)的預處理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),  #歸一化
])
 
#訓練集
trainset = tv.datasets.CIFAR10(
    root = './data/',
    train = True,
    download = True,
    transform = transform
)
 
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size = 4,
    shuffle = True,
    num_workers = 2,
)
 
#測試集
testset = tv.datasets.CIFAR10(
    root = './data/',
    train = False,
    download = True,
    transform = transform,
)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size = 4,
    shuffle = False,
    num_workers = 2,
)
 
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

初次下載需要一些時間,運行結(jié)束后,顯示如下:

import torch.nn as nn
import torch.nn.functional as F
import time
start = time.time()#計時
#定義網(wǎng)絡結(jié)構
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
        
    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        
        x = x.view(x.size()[0],-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Net()
print(net)

顯示net結(jié)構如下:

#定義優(yōu)化和損失
loss_func = nn.CrossEntropyLoss()  #交叉熵損失函數(shù)
optimizer = t.optim.SGD(net.parameters(),lr = 0.001,momentum = 0.9)
 
#訓練網(wǎng)絡
for epoch in range(2):
    running_loss = 0
    for i,data in enumerate(trainloader,0):
        inputs,labels = data
       
        outputs = net(inputs)
        loss = loss_func(outputs,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss +=loss.item()
        if i%2000 ==1999:
            print('epoch:',epoch+1,'|i:',i+1,'|loss:%.3f'%(running_loss/2000))
            running_loss = 0.0
end = time.time()
time_using = end - start
print('finish training')
print('time:',time_using)

結(jié)果如下:

下一步進行使用測試集進行網(wǎng)絡測試:

#測試網(wǎng)絡
correct = 0 #定義的預測正確的圖片數(shù)
total = 0#總共圖片個數(shù)
with t.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _,predict = t.max(outputs,1)
        total += labels.size(0)
        correct += (predict == labels).sum()
print('測試集中的準確率為:%d%%'%(100*correct/total))

結(jié)果如下:

簡單的網(wǎng)絡訓練確實要比10%的比例高一點:)

在GPU中訓練:

#在GPU中訓練
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
 
net.to(device)
images = images.to(device)
labels = labels.to(device)
 
output = net(images)
loss = loss_func(output,labels)
 
loss

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。

相關文章

  • django2用iframe標簽完成網(wǎng)頁內(nèi)嵌播放b站視頻功能

    django2用iframe標簽完成網(wǎng)頁內(nèi)嵌播放b站視頻功能

    這篇文章主要介紹了django2 用iframe標簽完成 網(wǎng)頁內(nèi)嵌播放b站視頻功能,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-06-06
  • python 多線程爬取壁紙網(wǎng)站的示例

    python 多線程爬取壁紙網(wǎng)站的示例

    這篇文章主要介紹了python 多線程爬取壁紙網(wǎng)站的示例,幫助大家更好的理解和學習使用python,感興趣的朋友可以了解下
    2021-02-02
  • Python實現(xiàn)數(shù)據(jù)庫編程方法詳解

    Python實現(xiàn)數(shù)據(jù)庫編程方法詳解

    這篇文章主要介紹了Python實現(xiàn)數(shù)據(jù)庫編程方法,較為詳細的總結(jié)了Python數(shù)據(jù)庫編程涉及的各種常用技巧與相關組件,需要的朋友可以參考下
    2015-06-06
  • Python如何讀寫字節(jié)數(shù)據(jù)

    Python如何讀寫字節(jié)數(shù)據(jù)

    這篇文章主要介紹了Python如何讀寫字節(jié)數(shù)據(jù),文中講解非常細致,代碼幫助大家更好的理解和學習,感興趣的朋友可以了解下
    2020-08-08
  • 如何解決Pycharm運行報錯No Python interpreter selected問題

    如何解決Pycharm運行報錯No Python interpreter selected

    這篇文章主要介紹了如何解決Pycharm運行時No Python interpreter selected問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2024-05-05
  • 使用Python腳本備份華為交換機的配置信息

    使用Python腳本備份華為交換機的配置信息

    在現(xiàn)代網(wǎng)絡管理中,備份交換機的配置信息是一項至關重要的任務,備份可以確保在交換機發(fā)生故障或配置錯誤時,能夠迅速恢復到之前的工作狀態(tài),本文將詳細介紹如何使用Python腳本備份華為交換機的配置信息,需要的朋友可以參考下
    2024-06-06
  • python類繼承與子類實例初始化用法分析

    python類繼承與子類實例初始化用法分析

    這篇文章主要介紹了python類繼承與子類實例初始化用法,實例分析了Python類的使用技巧,具有一定參考借鑒價值,需要的朋友可以參考下
    2015-04-04
  • Python標準庫之os模塊詳解

    Python標準庫之os模塊詳解

    Python的os模塊是用于與操作系統(tǒng)進行交互的模塊,它提供了許多函數(shù)和方法來執(zhí)行文件和目錄操作、進程管理、環(huán)境變量訪問等,本文詳細介紹了Python標準庫中os模塊,感興趣的同學跟著小編一起來看看吧
    2023-08-08
  • python調(diào)用函數(shù)、類和文件操作簡單實例總結(jié)

    python調(diào)用函數(shù)、類和文件操作簡單實例總結(jié)

    這篇文章主要介紹了python調(diào)用函數(shù)、類和文件操作,結(jié)合簡單實例形式總結(jié)分析了Python調(diào)用函數(shù)、類和文件操作的各種常見操作技巧,需要的朋友可以參考下
    2019-11-11
  • python高級搜索實現(xiàn)高效搜索GitHub資源

    python高級搜索實現(xiàn)高效搜索GitHub資源

    這篇文章主要為大家介紹了python高級搜索來高效搜索GitHub,從而高效獲取所需資源,有需要的朋友可以借鑒參考下,希望能夠有所幫助
    2021-11-11

最新評論