利用pytorch實現(xiàn)對CIFAR-10數(shù)據(jù)集的分類
步驟如下:
1.使用torchvision加載并預(yù)處理CIFAR-10數(shù)據(jù)集、
2.定義網(wǎng)絡(luò)
3.定義損失函數(shù)和優(yōu)化器
4.訓(xùn)練網(wǎng)絡(luò)并更新網(wǎng)絡(luò)參數(shù)
5.測試網(wǎng)絡(luò)
運行環(huán)境:
windows+python3.6.3+pycharm+pytorch0.3.0 import torchvision as tv import torchvision.transforms as transforms import torch as t from torchvision.transforms import ToPILImage show=ToPILImage() #把Tensor轉(zhuǎn)成Image,方便可視化 import matplotlib.pyplot as plt import torchvision import numpy as np ###############數(shù)據(jù)加載與預(yù)處理 transform = transforms.Compose([transforms.ToTensor(),#轉(zhuǎn)為tensor transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#歸一化 ]) #訓(xùn)練集 trainset=tv.datasets.CIFAR10(root='/python projects/test/data/', train=True, download=True, transform=transform) trainloader=t.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0) #測試集 testset=tv.datasets.CIFAR10(root='/python projects/test/data/', train=False, download=True, transform=transform) testloader=t.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=0) classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck') (data,label)=trainset[100] print(classes[label]) show((data+1)/2).resize((100,100)) # dataiter=iter(trainloader) # images,labels=dataiter.next() # print(''.join('11%s'%classes[labels[j]] for j in range(4))) # show(tv.utils.make_grid(images+1)/2).resize((400,100)) def imshow(img): img = img / 2 + 0.5 npimg = img.numpy() plt.imshow(np.transpose(npimg, (1, 2, 0))) dataiter = iter(trainloader) images, labels = dataiter.next() print(images.size()) imshow(torchvision.utils.make_grid(images)) plt.show()#關(guān)掉圖片才能往后繼續(xù)算 #########################定義網(wǎng)絡(luò) import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv1=nn.Conv2d(3,6,5) self.conv2=nn.Conv2d(6,16,5) self.fc1=nn.Linear(16*5*5,120) self.fc2=nn.Linear(120,84) self.fc3=nn.Linear(84,10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)),2) x = F.max_pool2d(F.relu(self.conv2(x)),2) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net=Net() print(net) #############定義損失函數(shù)和優(yōu)化器 from torch import optim criterion=nn.CrossEntropyLoss() optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9) ##############訓(xùn)練網(wǎng)絡(luò) from torch.autograd import Variable import time start_time = time.time() for epoch in range(2): running_loss=0.0 for i,data in enumerate(trainloader,0): #輸入數(shù)據(jù) inputs,labels=data inputs,labels=Variable(inputs),Variable(labels) #梯度清零 optimizer.zero_grad() outputs=net(inputs) loss=criterion(outputs,labels) loss.backward() #更新參數(shù) optimizer.step() # 打印log running_loss += loss.data[0] if i % 2000 == 1999: print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 print('finished training') end_time = time.time() print("Spend time:", end_time - start_time)
以上這篇利用pytorch實現(xiàn)對CIFAR-10數(shù)據(jù)集的分類就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Pandas之Dropna濾除缺失數(shù)據(jù)的實現(xiàn)方法
這篇文章主要介紹了Pandas之Dropna濾除缺失數(shù)據(jù)的實現(xiàn)方法,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-06-06使用Pandas?實現(xiàn)MySQL日期函數(shù)的解決方法
這篇文章主要介紹了用Pandas?實現(xiàn)MySQL日期函數(shù)的效果,Python是很靈活的語言,達成同一個目標或有多種途徑,我提供的只是其中一種解決方法,需要的朋友可以參考下2023-02-02解決使用python print打印函數(shù)返回值多一個None的問題
這篇文章主要介紹了解決使用python print打印函數(shù)返回值多一個None的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04