Pytorch建模過程中的DataLoader與Dataset示例詳解
處理數(shù)據(jù)樣本的代碼會因為處理過程繁雜而變得混亂且難以維護,在理想情況下,我們希望數(shù)據(jù)預(yù)處理過程代碼與我們的模型訓(xùn)練代碼分離,以獲得更好的可讀性和模塊化,為此,PyTorch提供了torch.utils.data.DataLoader
和 torch.utils.data.Dataset
兩個類用于數(shù)據(jù)處理。其中torch.utils.data.DataLoader
用于將數(shù)據(jù)集進行打包封裝成一個可迭代對象,torch.utils.data.Dataset
存儲有一些常用的數(shù)據(jù)集示例以及相關(guān)標(biāo)簽。
同時PyTorch針對不同的專業(yè)領(lǐng)域,也提供有不同的模塊,例如 TorchText
(自然語言處理), TorchVision
(計算機視覺), TorchAudio
(音頻),這些模塊中也都包含一些真實數(shù)據(jù)集示例。例如TorchVision
模塊中提供了CIFAR, COCO, FashionMNIST 數(shù)據(jù)集。
1 定義數(shù)據(jù)集
pytorch中提供兩種風(fēng)格的數(shù)據(jù)集定義方式:
- 字典映射風(fēng)格。之所以稱為映射風(fēng)格,是因為在后續(xù)加載數(shù)據(jù)迭代時,pytorch將自動使用迭代索引作為key,通過字典索引的方式獲取value,本質(zhì)就是將數(shù)據(jù)集定義為一個字典,使用這種風(fēng)格時,需要繼承
Dataset
類。
In [54]:
from torch.utils.data import Dataset from torch.utils.data import DataLoader
In [56]:
dataset = {0: '張三', 1:'李四', 2:'王五', 3:'趙六', 4:'陳七'} dataloader = DataLoader(dataset, batch_size=2) for i, value in enumerate(dataloader): print(i, value)
0 ['張三', '李四'] 1 ['王五', '趙六'] 2 ['陳七']
- 迭代器風(fēng)格。在自定義數(shù)據(jù)集類中,實現(xiàn)
__iter__
和__next__
方法,即定義為迭代器,在后續(xù)加載數(shù)據(jù)迭代時,pytorch將依次獲取value,使用這種風(fēng)格時,需要繼承IterableDataset
類。這種方法在數(shù)據(jù)量巨大,無法一下全部加載到內(nèi)存時非常實用。
In [57]:
from torch.utils.data import DataLoader from torch.utils.data import IterableDataset
In [58]:
dataset = [i for i in range(10)] dataloader = DataLoader(dataset=dataset, batch_size=3, shuffle=True) for i, item in enumerate(dataloader): # 迭代輸出 print(i, item)
0 tensor([3, 1, 2]) 1 tensor([9, 7, 5]) 2 tensor([0, 8, 4]) 3 tensor([6])
如下所示,我們有一個螞蟻蜜蜂圖像分類數(shù)據(jù)集,目錄結(jié)構(gòu)如下所示,下面我們結(jié)合這個數(shù)據(jù)集,分別介紹如何使用這兩個類定義真實數(shù)據(jù)集。
data └── hymenoptera_data ├── train │?? ├── ants │?? │?? ├── 0013035.jpg │ │ …… │?? └── bees │?? ├── 1092977343_cb42b38d62.jpg │ …… └── val ├── ants │?? ├── 10308379_1b6c72e180.jpg │?? …… └── bees ├── 1032546534_06907fe3b3.jpg ……
1.2 Dataset類
自定義一個Dataset類,繼承torch.utils.data.Dataset,且必須實現(xiàn)下面三個方法:
Dataset類里面的
__init__
函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件。Dataset類里面的
__getitem__
函數(shù),映射取值是調(diào)用的方法,獲取單個的數(shù)據(jù),訓(xùn)練迭代時將會調(diào)用這個方法。Dataset類里面的
__len__
函數(shù)獲取數(shù)據(jù)的總量。
In [211]:
import os import pandas as pd from PIL import Image from torchvision.transforms import ToTensor, Lambda from torchvision import transforms import torchvision class AntBeeDataset(Dataset): # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, root_dir, transform=None, target_transform=None): """ root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.root_dir = root_dir self.transform = transform self.target_transform = target_transform # 獲取文件夾下所有圖片的名稱和對應(yīng)的標(biāo)簽 self.img_lst = [] for label in ['ants', 'bees']: path = os.path.join(root_dir, label) for img_name in os.listdir(path): self.img_lst.append((os.path.join(root_dir, label, img_name), label)) def __getitem__(self, idx): img_path, label = self.img_lst[idx] img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) # 這個地方要注意,我們在計算loss的時候用交叉熵nn.CrossEntropyLoss() # 交叉熵的輸入有兩個,一個是模型的輸出outputs,一個是標(biāo)簽targets,注意targets是一維tensor # 例如batchsize如果是2,ants的targets的應(yīng)該[0,0],而不是[[0][0]] # 因此label要返回0,而不是[0] return img, label def __len__(self): return len(self.img_lst)
In [310]:
train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), # 將給定圖像隨機裁剪為不同的大小和寬高比,然后縮放所裁剪得到的圖像為制定的大小 transforms.RandomHorizontalFlip(), # 以給定的概率隨機水平旋轉(zhuǎn)給定的PIL的圖像,默認為0.5 transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 驗證集并不需要做與訓(xùn)練集相同的處理,所有,通常使用更加簡單的transformer val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 根據(jù)標(biāo)簽?zāi)夸浀拿Q來確定圖片是哪一類,如果是"ants",標(biāo)簽設(shè)置為0,如果是"bees",標(biāo)簽設(shè)置為1 target_transform = transforms.Lambda(lambda y: 0 if y == "ants" else 1)
In [311]:
train_dataset = AntBeeDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform) val_dataset = AntBeeDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)
1.2 Dataset數(shù)據(jù)集常用操作
1. 查看數(shù)據(jù)集大小:
In [221]:
len(train_dataset), len(val_dataset)
Out[221]:
(245, 153)
2. 合并數(shù)據(jù)集
In [222]:
dataset = train_dataset + val_dataset
In [223]:
len(dataset)
Out[223]:
398
3. 劃分訓(xùn)練集、測試集
In [224]:
from torch.utils.data import random_split # random_split 不能直接使用百分比劃分,必須指定具體數(shù)字 train_size = int( len(dataset) * 0.8) test_size = len(dataset) - train_size
In [225]:
train_dataset, val_dataset = random_split(dataset, [train_size, test_size])
In [226]:
len(train_dataset), len(val_dataset)
Out[226]:
(318, 80)
1.3 IterableDataset類
使用迭代器風(fēng)格時,必須繼承IterableDataset
類,且實現(xiàn)下面兩個方法:
__init__
,函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件,在數(shù)據(jù)量過大時,通常只是獲取操作句柄、數(shù)據(jù)庫連接。__iter__
,獲取迭代器。
雖然只需要實現(xiàn)這兩個方法,但是通常還需要在迭代過程中對數(shù)據(jù)進行處理。IterableDataset類實現(xiàn)自定義數(shù)據(jù)集,本質(zhì)就是創(chuàng)建一個數(shù)據(jù)集類,且實現(xiàn)__iter__
返回一個迭代器。一下提供兩種方法通過IterableDataset類自定義數(shù)據(jù)集:
方法一:
In [289]:
class AntBeeIterableDataset(IterableDataset): # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, root_dir, transform=None, target_transform=None): """ root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.root_dir = root_dir self.transform = transform self.target_transform = target_transform # 獲取文件夾下所有圖片的名稱和對應(yīng)的標(biāo)簽 self.img_lst = [] for label in ['ants', 'bees']: path = os.path.join(root_dir, label) for img_name in os.listdir(path): self.img_lst.append((os.path.join(root_dir, label, img_name), label)) def __iter__(self): for img_path, label in self.img_lst: img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) yield img, label
方法二:
In [285]:
class AntBeeIterableDataset(IterableDataset): # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, root_dir, transform=None, target_transform=None): """ root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.root_dir = root_dir self.transform = transform self.target_transform = target_transform # 獲取文件夾下所有圖片的名稱和對應(yīng)的標(biāo)簽 self.img_lst = [] for label in ['ants', 'bees']: path = os.path.join(root_dir, label) for img_name in os.listdir(path): self.img_lst.append((os.path.join(root_dir, label, img_name), label)) self.index = 0 def __iter__(self): return self def __next__(self): try: img_path, label = self.img_lst[self.index] self.index += 1 img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) return img, label except IndexError: raise StopIteration()
In [290]:
train_dataset = AntBeeIterableDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform) val_dataset = AntBeeIterableDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)
在處理大數(shù)據(jù)集時,IterableDataset會比Dataset更有優(yōu)勢,例如數(shù)據(jù)存儲在文件或者數(shù)據(jù)庫中,只需要在自定義的IterableDataset之類中獲取文件操作句柄或者數(shù)據(jù)庫連接和游標(biāo)驚喜迭代,每次只返回一條數(shù)據(jù)即可。我們把上文中螞蟻蜜蜂數(shù)據(jù)集的所有圖片、標(biāo)簽這里后寫入hymenoptera_data.txt中,內(nèi)容如下所示,假設(shè)有數(shù)億行,那么,就不能直接將數(shù)據(jù)加載到內(nèi)存了:
data/hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg, ants data/hymenoptera_data/train/ants/2278278459_6b99605e50.jpg, ants data/hymenoptera_data/train/ants/543417860_b14237f569.jpg, ants ... ...
可以參考一下方式定義IterableDataset子類:
In [299]:
class AntBeeIterableDataset(IterableDataset): # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到 def __init__(self, filepath, transform=None, target_transform=None): """ filepath:hymenoptera_data.txt完整路徑 transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值 """ self.filepath = filepath self.transform = transform self.target_transform = target_transform def __iter__(self): with open(self.filepath, 'r') as f: for line in f: img_path, label = line.replace('\n', '').split(', ') img = Image.open(img_path).convert('RGB') if self.transform: img = self.transform(img) if self.target_transform: label = self.target_transform(label) yield img, label
In [307]:
train_dataset = AntBeeIterableDataset('hymenoptera_data.txt', transform=train_transform, target_transform=target_transform)
注意,IterableDataset方法在處理大數(shù)據(jù)集時確實比Dataset更有優(yōu)勢,但是,IterableDataset在迭代過程中,樣本輸出順序是固定的,在使用DataLoader進行加載時,無法使用shuffle進行打亂,同時,因為在IterableDataset中并未強制限定必須實現(xiàn)__len__()
方法(很多時候確實也沒法獲取數(shù)據(jù)總量),不能通過len()
方法獲取數(shù)據(jù)總量。
2 DataLoad
DataLoader的功能是構(gòu)建可迭代的數(shù)據(jù)裝載器,在訓(xùn)練的時候,每一個for循環(huán),每一次Iteration,就是從DataLoader中獲取一個batch_size大小的數(shù)據(jù),節(jié)省內(nèi)存的同時,它還可以實現(xiàn)多進程、數(shù)據(jù)打亂等處理。我們通過一張圖來了解DataLoader數(shù)據(jù)讀取機制:
首先,在for循環(huán)中使用了DataLoader,進入DataLoader后,首先根據(jù)是否使用多進程DataLoaderIter,做出判斷之后單線程還是多線程,接著使用Sampler得索引Index,然后將索引給到DatasetFetcher,在這里面調(diào)用Dataset,根據(jù)索引,通過getitem得到實際的數(shù)據(jù)和標(biāo)簽,得到一個batch size大小的數(shù)據(jù)后,通過collate_fn函數(shù)整理成一個Batch Data的形式輸入到模型去訓(xùn)練。
在pytorch建模的數(shù)據(jù)處理、加載流程中,DataLoader應(yīng)該算是最核心的一步操作DataLoader有很多參數(shù),這里我們列出常用的幾個:
- dataset:表示Dataset類,它決定了數(shù)據(jù)從哪讀取以及如何讀?。?/li>
- batch_size:表示批大??;
- num_works:表示是否多進程讀取數(shù)據(jù);
- shuffle:表示每個epoch是否亂序;
- drop_last:表示當(dāng)樣本數(shù)不能被batch_size整除時,是否舍棄最后一批數(shù)據(jù);
- num_workers:啟動多少個進程來加載數(shù)據(jù)。
我們重點說說多進程模式下使用DataLoader,在多進程模式下,每次 DataLoader 創(chuàng)建 iterator 時(遍歷DataLoader時,例如,當(dāng)調(diào)用時enumerate(dataloader)),都會創(chuàng)建 num_workers 工作進程。dataset, collate_fn, worker_init_fn 都會被傳到每個worker中,每個worker都用獨立的進程。
對于映射風(fēng)格的數(shù)據(jù)集,即Dataset子類,主線程會用Sampler(采樣器)產(chǎn)生indice,并將它們送到進程里。因此,shuffle是在主線程做的
對于迭代器風(fēng)格的數(shù)據(jù)集,即IterableDataset子類,因為每個進程都有相同的data復(fù)制樣本,并在各個進程里進行不同的操作,以防止每個進程輸出的數(shù)據(jù)是重復(fù)的,所以一般用 torch.utils.data.get_worker_info() 來進行輔助處理。
這里,torch.utils.data.get_worker_info() 返回worker進程的一些信息(id, dataset, num_workers, seed),如果在主線程跑的話返回None
注意,通常不建議在多進程加載中返回CUDA張量,因為在使用CUDA和在多處理中共享CUDA張量時存在許多微妙之處(文檔中提出:只要接收過程保留張量的副本,就需要發(fā)送過程來保留原始張量)。建議采用 pin_memory=True ,以將數(shù)據(jù)快速傳輸?shù)街С諧UDA的GPU。簡而言之,不建議在使用多線程的情況下返回CUDA的tensor。
In [313]:
dataload = DataLoader(train_dataset, batch_size=2)
In [315]:
img, label = next(iter(dataload))
In [316]:
img.shape, label
Out[316]:
(torch.Size([2, 3, 224, 224]), tensor([0, 0]))
到此這篇關(guān)于Pytorch建模過程中的DataLoader與Dataset的文章就介紹到這了,更多相關(guān)Pytorch建模內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- Pytorch建模過程中的DataLoader與Dataset示例詳解
- Pytorch如何加載自己的數(shù)據(jù)集(使用DataLoader讀取Dataset)
- PyTorch?Dataset與DataLoader使用超詳細講解
- Pytorch數(shù)據(jù)讀取之Dataset和DataLoader知識總結(jié)
- Pytorch自定義Dataset和DataLoader去除不存在和空數(shù)據(jù)的操作
- pytorch Dataset,DataLoader產(chǎn)生自定義的訓(xùn)練數(shù)據(jù)案例
- PyTorch實現(xiàn)重寫/改寫Dataset并載入Dataloader
- 一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關(guān)系
- PyTorch 解決Dataset和Dataloader遇到的問題
相關(guān)文章
python中的線程threading.Thread()使用詳解
這篇文章主要介紹了python中的線程threading.Thread()使用詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12python lambda表達式(匿名函數(shù))寫法解析
這篇文章主要介紹了python lambda表達式(匿名函數(shù))寫法解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-09-09