Pytorch加載數(shù)據(jù)集的方式總結(jié)及補(bǔ)充
前言
在用Pytorch加載數(shù)據(jù)集時(shí),看GitHub上的代碼經(jīng)常會(huì)用到ImageFolder、DataLoader等一系列方法,而這些方法又是來自于torchvision、torch.utils.data。除加載數(shù)據(jù)集外,還有torchvision中的transforms對數(shù)據(jù)集預(yù)處理…等等等等。這個(gè)data,那個(gè)dataset…這一系列下來,不加注意的話實(shí)在有點(diǎn)打腦殼??磩e人的代碼加載數(shù)據(jù)集挺簡單,但是自己用的時(shí)候,尤其是加載自己所制作的數(shù)據(jù)集的時(shí)候,就會(huì)茫然無措。別無他法,抱著硬啃的心態(tài),查閱了其他博文,通過代碼實(shí)驗(yàn),終于是理清楚了思路。
Pytorch加載數(shù)據(jù)集可以分兩種大的情況:一、自己重寫定義; 二、用Pytorch自帶的類。第二種里面又有多種不同的方法(datasets、 ImageFolder等),但這些方法都有相同的處理規(guī)律。我理解的,無論是哪種情況,加載數(shù)據(jù)集都需要構(gòu)造數(shù)據(jù)加載器和數(shù)據(jù)裝載器(后者生成的是可迭代的數(shù)據(jù))。現(xiàn)將這兩種情況一一說明。
一、自己重寫定義(Dataset、DataLoader)
目前我們有自己制作的數(shù)據(jù)以及數(shù)據(jù)標(biāo)簽,但是有時(shí)候感覺不太適合直接用Pytorch自帶加載數(shù)據(jù)集的方法。我們可以自己來重寫定義一個(gè)類,這個(gè)類繼承于 torch.utils.data.Dataset,同時(shí)我們需要重寫這個(gè)類里面的兩個(gè)方法 _ getitem__ () 和__ len()__函數(shù)。
如下所示。這兩種方法如何構(gòu)造以及具體的細(xì)節(jié)可以查看其他的博客。len方法必須返回?cái)?shù)據(jù)的長度,getitem方法必須返回?cái)?shù)據(jù)以及標(biāo)簽。
import torch import numpy as np # 定義GetLoader類,繼承Dataset方法,并重寫__getitem__()和__len__()方法 class GetLoader(torch.utils.data.Dataset): # 初始化函數(shù),得到數(shù)據(jù) def __init__(self, data_root, data_label): self.data = data_root self.label = data_label # index是根據(jù)batchsize劃分?jǐn)?shù)據(jù)后得到的索引,最后將data和對應(yīng)的labels進(jìn)行一起返回 def __getitem__(self, index): data = self.data[index] labels = self.label[index] return data, labels # 該函數(shù)返回?cái)?shù)據(jù)大小長度,目的是DataLoader方便劃分,如果不知道大小,DataLoader會(huì)一臉懵逼 def __len__(self): return len(self.data) # 隨機(jī)生成數(shù)據(jù),大小為10 * 20列 source_data = np.random.rand(10, 20) # 隨機(jī)生成標(biāo)簽,大小為10 * 1列 source_label = np.random.randint(0,2,(10, 1)) # 通過GetLoader將數(shù)據(jù)進(jìn)行加載,返回Dataset對象,包含data和labels torch_data = GetLoader(source_data, source_label)
通過上述的程序,我們構(gòu)造了一個(gè)數(shù)據(jù)加載器torch_data,但是還是不能直接傳入網(wǎng)絡(luò)中。接下來需要構(gòu)造數(shù)據(jù)裝載器,產(chǎn)生可迭代的數(shù)據(jù),再傳入網(wǎng)絡(luò)中。DataLoader類完成這個(gè)工作。
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
參數(shù)解釋:
1.dataset : 加載torch.utils.data.Dataset對象數(shù)據(jù)
2.batch_size : 每個(gè)batch的大小,將我們的數(shù)據(jù)分批輸入到網(wǎng)絡(luò)中
3.shuffle : 是否對數(shù)據(jù)進(jìn)行打亂
4.drop_last : 是否對無法整除的最后一個(gè)datasize進(jìn)行丟棄
5.num_workers : 表示加載的時(shí)候子進(jìn)程數(shù)
結(jié)合我們自己定義的加載數(shù)據(jù)集類,可以如下使用。后面將data和label傳入我們定義的模型中。
... torch_data = GetLoader(source_data, source_label) from torch.utils.data import DataLoader datas = DataLoader(torch_data, batch_size = 4, shuffle = True, drop_last = False, num_workers = 2) for i, (data, label) in enumerate(datas): # i表示第幾個(gè)batch, data表示batch_size個(gè)原始的數(shù)據(jù),label代表batch_size個(gè)數(shù)據(jù)的標(biāo)簽 print("第 {} 個(gè)Batch \n{}".format(i, data))
二、用Pytorch自帶的類(ImageFolder、datasets、DataLoader)
2.1 加載自己的數(shù)據(jù)集
2.1.1 ImageFolder介紹
和第一種情況不一樣,我們不需要在代碼上自己定義數(shù)據(jù)集類了,而是將數(shù)據(jù)集按照一定的格式擺放,調(diào)用ImageFolder類即可。這種是在調(diào)用Pytorch內(nèi)部的API,所以我們自己的數(shù)據(jù)集得需要按照API內(nèi)部所規(guī)定的存放格式。torchvision.datasets.ImageFolder 要求數(shù)據(jù)集按照如下方式組織。根目錄 root 下存儲(chǔ)的是類別文件夾(如cat,dog),每個(gè)類別文件夾下存儲(chǔ)相應(yīng)類別的圖像(如xxx.png)
A generic data loader where the images are arranged in this way:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.pngroot/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
torchvision.datasets.ImageFolder有以下參數(shù):
dataset=torchvision.datasets.ImageFolder( root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
參數(shù)解釋:
1.root:根目錄,在root目錄下,應(yīng)該有不同類別的子文件夾;
|--data(root)
|--train
|--cat
|--dog
|--valid
|--cat
|--dog
2.transform:對圖片進(jìn)行預(yù)處理的操作,原始圖像作為一個(gè)輸入,返回的是transform變換后的圖片;
3.target_transform:對圖片類別進(jìn)行預(yù)處理的操作,輸入為 target,輸出對其的轉(zhuǎn)換。 如果不傳該參數(shù),即對target不做任何轉(zhuǎn)換,返回的順序索引 0,1, 2…
4.loader:表示數(shù)據(jù)集加載方式,通常默認(rèn)加載方式即可;
5.is_valid_file:獲取圖像文件的路徑并檢查該文件是否為有效文件的函數(shù)(用于檢查損壞文件)
作為torchvision.datasets.ImageFolder的返回,會(huì)有以下三種屬性:
(1)self.classes:用一個(gè) list 保存類別名稱
(2)self.class_to_idx:類別對應(yīng)的索引,與不做任何轉(zhuǎn)換返回的 target 對應(yīng)
(3)self.imgs:保存(img_path, class) tuple的list
以貓狗類別舉例,各屬性輸出如下所示:
print(dataset.classes) #根據(jù)分的文件夾的名字來確定的類別 print(dataset.class_to_idx) #按順序?yàn)檫@些類別定義索引為0,1... print(dataset.imgs) #返回從所有文件夾中得到的圖片的路徑以及其類別 ''' 輸出: ['cat', 'dog'] {'cat': 0, 'dog': 1} [('./data/train\\cat\\1.jpg', 0), ('./data/train\\cat\\2.jpg', 0), ('./data/train\\dog\\1.jpg', 1), ('./data/train\\dog\\2.jpg', 1)] '''
2.2.2 ImageFolder加載數(shù)據(jù)集完整例子
# 5. 將文件夾數(shù)據(jù)導(dǎo)入 train_loader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle=True, num_workers = 2) # 6. 傳入網(wǎng)絡(luò)進(jìn)行訓(xùn)練 for epoch in range(epochs): train_bar = tqdm(train_loader, file = sys.stdout) for step, data in enumerate(train_bar): ...
和第一種情況自己重寫定義一樣,上述的代碼僅僅完成了數(shù)據(jù)加載器的定義。這樣是不能直接傳入網(wǎng)絡(luò)中進(jìn)行訓(xùn)練的,需要再構(gòu)造一個(gè)可迭代的數(shù)據(jù)裝載器。DataLoader類的使用方式上文中有詳細(xì)介紹。
# 5. 將文件夾數(shù)據(jù)導(dǎo)入 train_loader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle=True, num_workers = 2) # 6. 傳入網(wǎng)絡(luò)進(jìn)行訓(xùn)練 for epoch in range(epochs): train_bar = tqdm(train_loader, file = sys.stdout) for step, data in enumerate(train_bar): ...
2.2 加載常見的數(shù)據(jù)集
有些數(shù)據(jù)集是公共的,比如常見的MNIST,CIFAR10,SVHN等等。這些數(shù)據(jù)集在Pytorch中可以通過代碼就可以下載、加載。如下代碼所示。用torchvision中的datasets類下載數(shù)據(jù)集,并還是結(jié)合DataLoader來構(gòu)建可直接傳入網(wǎng)絡(luò)的數(shù)據(jù)裝載器。
from torch.utils.data import DataLoader from torchvision import datasets, transforms def dataloader(dataset, input_size, batch_size, split='train'): transform = transforms.Compose([ transforms.Resize((input_size, input_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) if dataset == 'mnist': data_loader = DataLoader( datasets.MNIST('data/mnist', train=True, download=True, transform=transform), batch_size=batch_size, shuffle=True) elif dataset == 'fashion-mnist': data_loader = DataLoader( datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform), batch_size=batch_size, shuffle=True) elif dataset == 'cifar10': data_loader = DataLoader( datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform), batch_size=batch_size, shuffle=True) elif dataset == 'svhn': data_loader = DataLoader( datasets.SVHN('data/svhn', split=split, download=True, transform=transform), batch_size=batch_size, shuffle=True) elif dataset == 'stl10': data_loader = DataLoader( datasets.STL10('data/stl10', split=split, download=True, transform=transform), batch_size=batch_size, shuffle=True) elif dataset == 'lsun-bed': data_loader = DataLoader( datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform), batch_size=batch_size, shuffle=True) return data_loader
三、總結(jié)
至于覺得加載數(shù)據(jù)集比較難的很大的原因,我感覺是Dataset、datasets、DataLoader以及torch.utils.data、torchvision種類太多,有點(diǎn)混亂。上面的梳理,我的理解是無論是哪種方式,終端還是需要DataLoader整合。作為加載數(shù)據(jù)集的前端,用自己定義的、用ImageFolder的、還是用datasets加載常用數(shù)據(jù)集,都是在構(gòu)造數(shù)據(jù)加載器,而且構(gòu)造起來也并不復(fù)雜。梳理清晰后,相信對Pytorch加載數(shù)據(jù)集有了更進(jìn)一步的理解。
四、transforms變換講解
torchvision.transforms是Pytorch中的圖像預(yù)處理包。一般定義在加載數(shù)據(jù)集之前,用transforms中的Compose類把多個(gè)步驟整合到一起,而這些步驟是transforms中的函數(shù)。
transforms中的函數(shù)有這些:
函數(shù) | 含義 |
---|---|
transforms.Resize | 把給定的圖片resize到given size |
transforms.Normalize | 用均值和標(biāo)準(zhǔn)差歸一化張量圖像 |
transforms.Totensor | 可以將PIL和numpy格式的數(shù)據(jù)從[0,255]范圍轉(zhuǎn)換到[0,1] ; <br /另外原始數(shù)據(jù)的shape是(H x W x C),通過transforms.ToTensor() 后shape會(huì)變?yōu)椋– x H x W) |
transforms.RandomGrayscale | 將圖像以一定的概率轉(zhuǎn)換為灰度圖像 |
transforms.ColorJitter | 隨機(jī)改變圖像的亮度對比度和飽和度 |
transforms.Centercrop | 在圖片的中間區(qū)域進(jìn)行裁剪 |
transforms.RandomCrop | 在一個(gè)隨機(jī)的位置進(jìn)行裁剪 |
transforms.FiceCrop | 把圖像裁剪為四個(gè)角和一個(gè)中心 |
transforms.RandomResizedCrop | 將PIL圖像裁剪成任意大小和縱橫比 |
transforms.ToPILImage | convert a tensor to PIL image |
transforms.RandomHorizontalFlip | 以0.5的概率水平翻轉(zhuǎn)給定的PIL圖像 |
transforms.RandomVerticalFlip | 以0.5的概率豎直翻轉(zhuǎn)給定的PIL圖像 |
transforms.Grayscale | 將圖像轉(zhuǎn)換為灰度圖像 |
不同函數(shù)對應(yīng)有不同的屬性,用transforms.Compose將不同的操作整合在一起,如下所示。
transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
五、DataLoader的補(bǔ)充
數(shù)據(jù)加載器,結(jié)合了數(shù)據(jù)集和取樣器,并且可以提供多個(gè)線程處理數(shù)據(jù)集。
在訓(xùn)練模型時(shí)使用到此函數(shù),用來把訓(xùn)練數(shù)據(jù)分成多個(gè)小組,此函數(shù)每次拋出一組數(shù)據(jù)。直至把所有的數(shù)據(jù)都拋出。就是做一個(gè)數(shù)據(jù)的初始化。
用下面的例子測試:
""" 批訓(xùn)練,把數(shù)據(jù)變成一小批一小批數(shù)據(jù)進(jìn)行訓(xùn)練。 DataLoader就是用來包裝所使用的數(shù)據(jù),每次拋出一批數(shù)據(jù) """ import torch import torch.utils.data as Data BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) y = torch.linspace(10, 1, 10) # 把數(shù)據(jù)放在數(shù)據(jù)庫中 torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader( # 從數(shù)據(jù)庫中每次抽出batch size個(gè)樣本 dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, ) def show_batch(): for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): # training print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) if __name__ == '__main__': show_batch()
結(jié)果如下所示。仔細(xì)觀察:
每一個(gè)step,batch_x是不會(huì)重合的,batch_y里面的值也是不會(huì)重合的(第一個(gè)step中,batch_x:tensor([ 3., 10., 6., 2., 8.]);第二個(gè)step中batch_x:tensor([5., 9., 7., 4., 1.])),說明DataLoader將數(shù)據(jù)打亂后,每次選用其中的Batch_size個(gè)數(shù)據(jù)且不會(huì)重復(fù);
其二,batch_x 和 batch_y對應(yīng)的索引之和相等,這說明DataLoader對圖像和標(biāo)簽打亂順序時(shí),同時(shí)按照某一規(guī)律打亂,并不會(huì)造成標(biāo)簽和圖像出現(xiàn)不對應(yīng)的情況。
其三,在不同的epoch之間,每次數(shù)據(jù)也是不同的,說明DataLoader每次被調(diào)用時(shí),都會(huì)重新打亂一次。
steop:0, batch_x:tensor([ 3., 10., 6., 2., 8.]), batch_y:tensor([8., 1., 5., 9., 3.]) steop:1, batch_x:tensor([5., 9., 7., 4., 1.]), batch_y:tensor([ 6., 2., 4., 7., 10.]) steop:0, batch_x:tensor([8., 3., 1., 2., 9.]), batch_y:tensor([ 3., 8., 10., 9., 2.]) steop:1, batch_x:tensor([10., 5., 4., 7., 6.]), batch_y:tensor([1., 6., 7., 4., 5.]) steop:0, batch_x:tensor([5., 8., 4., 3., 7.]), batch_y:tensor([6., 3., 7., 8., 4.]) steop:1, batch_x:tensor([ 2., 10., 6., 9., 1.]), batch_y:tensor([ 9., 1., 5., 2., 10.])
總結(jié)
到此這篇關(guān)于Pytorch加載數(shù)據(jù)集的方式總結(jié)及補(bǔ)充的文章就介紹到這了,更多相關(guān)Pytorch加載數(shù)據(jù)集內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python網(wǎng)絡(luò)編程之?dāng)?shù)據(jù)傳輸U(kuò)DP實(shí)例分析
這篇文章主要介紹了python網(wǎng)絡(luò)編程之?dāng)?shù)據(jù)傳輸U(kuò)DP實(shí)現(xiàn)方法,實(shí)例分析了Python基于UDP協(xié)議的數(shù)據(jù)傳輸實(shí)現(xiàn)方法,需要的朋友可以參考下2015-05-05Python破解BiliBili滑塊驗(yàn)證碼的思路詳解(完美避開人機(jī)識別)
這篇文章主要介紹了Python破解BiliBili滑塊驗(yàn)證碼的思路,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-02-02Python中的遠(yuǎn)程調(diào)試與性能優(yōu)化技巧分享
Python 是一種簡單易學(xué)、功能強(qiáng)大的編程語言,廣泛應(yīng)用于各種領(lǐng)域,包括網(wǎng)絡(luò)編程、數(shù)據(jù)分析、人工智能等,在開發(fā)過程中,我們經(jīng)常會(huì)遇到需要遠(yuǎn)程調(diào)試和性能優(yōu)化的情況,本文將介紹如何利用遠(yuǎn)程調(diào)試工具和性能優(yōu)化技巧來提高 Python 應(yīng)用程序的效率和性能2024-05-05python中g(shù)etopt()函數(shù)用法詳解
這篇文章主要介紹了python中g(shù)etopt()函數(shù)用法,通過getopt模塊中的getopt(?)方法,我們可以獲取和解析命令行傳入的參數(shù),需要的朋友可以參考下2022-12-12python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲的示例代碼(可以自定義子彈速度)
這篇文章主要介紹了python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲---可以自定義子彈速度,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-08-08