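PyTorch CNN: Recognizing Handwritten Characters with a Self-Built Image Dataset
This article shows how to train a PyTorch CNN to recognize handwritten characters using a self-built image dataset. The details are as follows: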
# library
# standard library
import os
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable  # deprecated no-op wrapper since PyTorch 0.4; plain tensors can be passed to the model directly
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1) # reproducible
# Hyper Parameters
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001 # learning rate
root = "./mnist/raw/"
def default_loader(path):
    # Open the image as stored on disk; MyDataset converts it to grayscale.
    return Image.open(path)


class MyDataset(Dataset):
    """Dataset backed by a text file with one "<image path> <label>" pair per line."""

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        with open(txt, 'r') as fh:  # the file is closed even if parsing fails
            for line in fh:
                words = line.split()  # split() also drops the trailing newline
                if len(words) < 2:  # skip blank or malformed lines
                    continue
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn).convert('L')  # single-channel grayscale, matching in_channels=1
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)
train_data = MyDataset(txt=root + 'train.txt', transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = MyDataset(txt=root + 'test.txt', transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # number of input channels (grayscale)
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # to keep the same width and height after Conv2d, use padding=(kernel_size-1)/2 when stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x                    # return x for visualization
cnn = CNN()
print(cnn) # net architecture
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):   # ToTensor has already scaled each batch to [0, 1]
        output = cnn(b_x)[0]            # cnn output (logits)
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            cnn.eval()                  # switch to evaluation mode
            eval_loss = 0.
            eval_acc = 0.
            with torch.no_grad():       # no gradients are needed for evaluation
                for t_x, t_y in test_loader:
                    output = cnn(t_x)[0]
                    loss = loss_func(output, t_y)
                    # loss is the batch mean, so weight it by the batch size
                    eval_loss += loss.item() * t_y.size(0)
                    pred = torch.max(output, 1)[1]
                    eval_acc += (pred == t_y).sum().item()
            acc_rate = eval_acc / float(len(test_data))
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_data), acc_rate))
            cnn.train()                 # back to training mode
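The shape comments in the CNN class, and the `32 * 7 * 7` input size of the final linear layer, follow from the usual output-size formula for convolution and pooling layers: `out = floor((in + 2 * padding - kernel) / stride) + 1`. The sketch below checks them in plain Python. The helper name `conv_out` is made up for this illustration and is not part of PyTorch; it also assumes that `nn.MaxPool2d(2)` uses a stride equal to its kernel size, which is the PyTorch default.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (hypothetical helper)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                              # input images are 28 x 28
size = conv_out(size, kernel=5, stride=1, padding=2)   # conv1 keeps 28
size = conv_out(size, kernel=2, stride=2)              # first max pool halves it to 14
size = conv_out(size, kernel=5, stride=1, padding=2)   # conv2 keeps 14
size = conv_out(size, kernel=2, stride=2)              # second max pool halves it to 7

flat_features = 32 * size * size                       # 32 channels out of conv2
print(size, flat_features)                             # 7 1568
```

If the input images had a different size, the same calculation would give the value to use in place of `32 * 7 * 7`.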
For the images and labels, see the previous article, "PyTorch: Converting the MNIST Dataset into Images and a txt File".
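Judging from how `MyDataset` parses its input, each line of `train.txt` and `test.txt` holds an image path and an integer label separated by whitespace. The following is a minimal sketch of that parsing step using only the standard library. The function name `read_image_list` and the sample paths are assumptions made for this illustration; they do not come from the previous article.

```python
import os
import tempfile


def read_image_list(txt_path):
    """Parse lines of the form "<image path> <label>" into (path, int) pairs."""
    samples = []
    with open(txt_path, "r") as fh:
        for line in fh:
            words = line.split()
            if len(words) < 2:      # skip blank or malformed lines
                continue
            samples.append((words[0], int(words[1])))
    return samples


# Write a tiny list file with hypothetical paths, then read it back.
with tempfile.TemporaryDirectory() as tmp:
    list_path = os.path.join(tmp, "train.txt")
    with open(list_path, "w") as fh:
        fh.write("./mnist/raw/train/0.jpg 5\n")
        fh.write("./mnist/raw/train/1.jpg 0\n")
        fh.write("\n")              # a trailing blank line is ignored
    samples = read_image_list(list_path)

print(samples)
```

Each pair corresponds to one item of the dataset: the path is opened as an image and the integer is used as the class label.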
Running the script prints the loss and accuracy on the test set every 50 training steps.
