pytorch加載自己的圖片數(shù)據(jù)集的2種方法詳解
pytorch加載圖片數(shù)據(jù)集有兩種方法。
1.ImageFolder 適合于分類(lèi)數(shù)據(jù)集,并且每一個(gè)類(lèi)別的圖片在同一個(gè)文件夾, ImageFolder加載的數(shù)據(jù)集, 訓(xùn)練數(shù)據(jù)為文件件下的圖片, 訓(xùn)練標(biāo)簽是對(duì)應(yīng)的文件夾, 每個(gè)文件夾為一個(gè)類(lèi)別
導(dǎo)入ImageFolder()包 from torchvision.datasets import ImageFolder

在Flower_Orig_dataset文件夾下有flower_orig 和 sunflower這兩個(gè)文件夾, 這兩個(gè)文件夾下放著同一個(gè)類(lèi)別的圖片。 使用 ImageFolder 加載的圖片, 就會(huì)返回圖片信息和對(duì)應(yīng)的label信息, 但是label信息是根據(jù)文件夾給出的, 如flower_orig就是標(biāo)簽0, sunflower就是標(biāo)簽1。
ImageFolder 加載數(shù)據(jù)集
1. 導(dǎo)入包和設(shè)置transform
import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
transforms = transforms.Compose([
transforms.Resize(256), # 將圖片短邊縮放至256,長(zhǎng)寬比保持不變:
transforms.CenterCrop(224), #將圖片從中心切剪成3*224*224大小的圖片
transforms.ToTensor() #把圖片進(jìn)行歸一化,并把數(shù)據(jù)轉(zhuǎn)換成Tensor類(lèi)型
]) 2.加載數(shù)據(jù)集: 將分類(lèi)圖片的父目錄作為路徑傳遞給ImageFolder(), 并傳入transform。這樣就有了要加載的數(shù)據(jù)集, 之后就可以使用DataLoader加載數(shù)據(jù), 并構(gòu)建網(wǎng)絡(luò)訓(xùn)練。
path = r'D:\數(shù)據(jù)集\Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
images, labels = data
print(images.shape)
print(labels.shape)
break使用pytorch提供的Dataset類(lèi)創(chuàng)建自己的數(shù)據(jù)集。
具體步驟:
1. 首先要有一個(gè)txt文件, 這個(gè)文件格式是: 圖片路徑 標(biāo)簽. 這樣的格式, 所以使用os庫(kù), 遍歷自己的圖片名, 并把標(biāo)簽和圖片路徑寫(xiě)入txt文件。
2. 有了這個(gè)txt文件, 我們就可以在類(lèi)里面構(gòu)造我們的數(shù)據(jù)集.
2.1 把圖片路徑和圖片標(biāo)簽分割開(kāi), 有兩個(gè)列表, 一個(gè)列表是圖片路徑名, 一個(gè)列表是標(biāo)簽號(hào), 有一點(diǎn)就是第 i 個(gè)圖片列表和 第 i 個(gè)標(biāo)簽是對(duì)應(yīng)的
3. 重寫(xiě)__len__方法 和 __getitem__方法
3.1 getitem方法中, 獲得對(duì)應(yīng)的圖片路徑,并用PIL庫(kù)讀取文件把圖片transfrom后, 在getitem函數(shù)中返回讀取的圖片和標(biāo)簽即可
4.就可以構(gòu)建數(shù)據(jù)集實(shí)例和加載數(shù)據(jù)集.
定義一個(gè)用來(lái)生成[ 圖片路徑 標(biāo)簽] 這樣的txt文件函數(shù)
def make_txt(root, file_name, label):
path = os.path.join(root, file_name)
data = os.listdir(path)
f = open(path+'\\'+'f.txt', 'w')
for line in data:
f.write(line+' '+str(label)+'\n')
f.close()
#調(diào)用函數(shù)生成兩個(gè)文件夾下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)將連個(gè)txt文件合并成一個(gè)txt文件,表示數(shù)據(jù)集所有的圖片和標(biāo)簽
def link_txt(file1, file2):
txt_list = []
path = r'D:\數(shù)據(jù)集\Flower_Orig_dataset\data.txt'
f = open(path, 'a')
f1 = open(file1, 'r')
data1 = f1.readlines()
for line in data1:
txt_list.append(line)
f2 = open(file2, 'r')
data2 = f2.readlines()
for line in data2:
txt_list.append(line)
for line in txt_list:
f.write(line)
f.close()
f1.close()
f2.close()
#調(diào)用函數(shù), 將兩個(gè)文件夾下的txt文件合并
file1 = r'D:\數(shù)據(jù)集\Flower_Orig_dataset\flower_orig\f.txt'
file2 = r'D:\數(shù)據(jù)集\Flower_Orig_dataset\sunflower\f.txt'
link_txt(file1=file1, file2=file2)現(xiàn)在我們已經(jīng)有了我們制作數(shù)據(jù)集所需要的txt文件, 接下來(lái)要做的即使繼承Dataset類(lèi), 來(lái)構(gòu)建自己的數(shù)據(jù)集 , 別忘了前面說(shuō)的 構(gòu)建數(shù)據(jù)集步驟, 在__getitem__函數(shù)中, 需要拿到圖片路徑和標(biāo)簽, 并且用PIL庫(kù)方法讀取圖片,對(duì)圖片進(jìn)行transform轉(zhuǎn)換后,返回圖片信息和標(biāo)簽信息
Dataset加載數(shù)據(jù)集
我們讀取圖片的根目錄, 在根目錄下有所有圖片的txt文件, 拿到txt文件后, 先讀取txt文件, 之后遍歷txt文件中的每一行, 首先去除掉尾部的換行符, 在以空格切分,前半部分是圖片名稱(chēng), 后半部分是圖片標(biāo)簽, 當(dāng)圖片名稱(chēng)和根目錄結(jié)合,就得到了我們的圖片路徑
class MyDataset(Dataset):
def __init__(self, img_path, transform=None):
super(MyDataset, self).__init__()
self.root = img_path
self.txt_root = self.root + 'data.txt'
f = open(self.txt_root, 'r')
data = f.readlines()
imgs = []
labels = []
for line in data:
line = line.rstrip()
word = line.split()
imgs.append(os.path.join(self.root, word[1], word[0]))
labels.append(word[1])
self.img = imgs
self.label = labels
self.transform = transform
def __len__(self):
return len(self.label)
def __getitem__(self, item):
img = self.img[item]
label = self.label[item]
img = Image.open(img).convert('RGB')
#此時(shí)img是PIL.Image類(lèi)型 label是str類(lèi)型
if transforms is not None:
img = self.transform(img)
label = np.array(label).astype(np.int64)
label = torch.from_numpy(label)
return img, label加載我們的數(shù)據(jù)集:
path = r'D:\數(shù)據(jù)集\Flower_Orig_dataset' dataset = MyDataset(path, transform=transform) data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
接下來(lái)我們就可以構(gòu)建我們的網(wǎng)絡(luò)架構(gòu):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3,16,3)
self.maxpool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(16,5,3)
self.relu = nn.ReLU()
self.fc1 = nn.Linear(55*55*5, 1200)
self.fc2 = nn.Linear(1200,64)
self.fc3 = nn.Linear(64,2)
def forward(self,x):
x = self.maxpool(self.relu(self.conv1(x))) #113
x = self.maxpool(self.relu(self.conv2(x))) #55
x = x.view(-1, self.num_flat_features(x))
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features
訓(xùn)練我們的網(wǎng)絡(luò):
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 10
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(data_loader):
images, label = data
out = model(images)
loss = criterion(out, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if(i+1)%10 == 0:
print('[%d %5d] loss: %.3f'%(epoch+1, i+1, running_loss/100))
running_loss = 0.0
print('finished train')保存網(wǎng)絡(luò)模型(這里不止是保存參數(shù),還保存了網(wǎng)絡(luò)結(jié)構(gòu))
#保存模型
torch.save(net, 'model_name.pth') #保存的是模型, 不止是w和b權(quán)重值
# 讀取模型
model = torch.load('model_name.pth')總結(jié)
到此這篇關(guān)于pytorch加載自己的圖片數(shù)據(jù)集的2種方法的文章就介紹到這了,更多相關(guān)pytorch加載圖片數(shù)據(jù)集內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
在Pycharm中修改文件默認(rèn)打開(kāi)方式的方法
今天小編就為大家分享一篇在Pycharm中修改文件默認(rèn)打開(kāi)方式的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01
tensorflow 分類(lèi)損失函數(shù)使用小記
這篇文章主要介紹了tensorflow 分類(lèi)損失函數(shù)使用小記,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-02-02
Python Web框架Flask信號(hào)機(jī)制(signals)介紹
這篇文章主要介紹了Python Web框架Flask信號(hào)機(jī)制(signals)介紹,本文介紹Flask的信號(hào)機(jī)制,講述信號(hào)的用途,并給出創(chuàng)建信號(hào)、訂閱信號(hào)、發(fā)送信號(hào)的方法,需要的朋友可以參考下2015-01-01
Python Numpy中數(shù)據(jù)的常用保存與讀取方法
這篇文章主要介紹了Python Numpy中數(shù)據(jù)的常用保存與讀取方法,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-04-04

