Pytorch的torch.utils.data中Dataset以及DataLoader示例詳解
在我們進(jìn)行深度學(xué)習(xí)的過程中,不免要用到數(shù)據(jù)集,那么數(shù)據(jù)集是如何加載到我們的模型中進(jìn)行訓(xùn)練的呢?以往我們大多數(shù)初學(xué)者肯定都是拿網(wǎng)上的代碼直接用,但是它底層的原理到底是什么還是不太清楚。所以今天就從內(nèi)置的Dataset函數(shù)和自定義的Dataset函數(shù)做一個(gè)詳細(xì)的解析。
前言
torch.utils.data
是 PyTorch
提供的一個(gè)模塊,用于處理和加載數(shù)據(jù)。該模塊提供了一系列工具類和函數(shù),用于創(chuàng)建、操作和批量加載數(shù)據(jù)集。
下面是 torch.utils.data
模塊中一些常用的類和函數(shù):
Dataset
: 定義了抽象的數(shù)據(jù)集類,用戶可以通過繼承該類來構(gòu)建自己的數(shù)據(jù)集。Dataset
類提供了兩個(gè)必須實(shí)現(xiàn)的方法:__getitem__
用于訪問單個(gè)樣本,__len__
用于返回?cái)?shù)據(jù)集的大小。TensorDataset
: 繼承自Dataset
類,用于將張量數(shù)據(jù)打包成數(shù)據(jù)集。它接受多個(gè)張量作為輸入,并按照第一個(gè)輸入張量的大小來確定數(shù)據(jù)集的大小。DataLoader
: 數(shù)據(jù)加載器類,用于批量加載數(shù)據(jù)集。它接受一個(gè)數(shù)據(jù)集對(duì)象作為輸入,并提供多種數(shù)據(jù)加載和預(yù)處理的功能,如設(shè)置批量大小、多線程數(shù)據(jù)加載和數(shù)據(jù)打亂等。Subset
: 數(shù)據(jù)集的子集類,用于從數(shù)據(jù)集中選擇指定的樣本。random_split
: 將一個(gè)數(shù)據(jù)集隨機(jī)劃分為多個(gè)子集,可以指定劃分的比例或指定每個(gè)子集的大小。ConcatDataset
: 將多個(gè)數(shù)據(jù)集連接在一起形成一個(gè)更大的數(shù)據(jù)集。get_worker_info
: 獲取當(dāng)前數(shù)據(jù)加載器所在的進(jìn)程信息。
除了上述的類和函數(shù)之外, torch.utils.data
還提供了一些常用的數(shù)據(jù)預(yù)處理的工具,如隨機(jī)裁剪、隨機(jī)旋轉(zhuǎn)、標(biāo)準(zhǔn)化等。
通過 torch.utils.data
模塊提供的類和函數(shù),可以方便地加載、處理和批量加載數(shù)據(jù),為模型訓(xùn)練和驗(yàn)證提供了便利。但是,我們最常用的兩個(gè)類還是 Dataset
和 DataLoader
類。
1、自定義Dataset類
torch.utils.data.Dataset
是 PyTorch 中用于表示數(shù)據(jù)集的抽象類,用于定義數(shù)據(jù)集的訪問方式和樣本數(shù)量。
Dataset 類是一個(gè)基類,我們可以通過繼承該類并實(shí)現(xiàn)下面兩個(gè)方法來創(chuàng)建自定義的數(shù)據(jù)集類:
getitem(self, index): 根據(jù)給定的索引 index,返回對(duì)應(yīng)的樣本數(shù)據(jù)。索引可以是一個(gè)整數(shù),表示按順序獲取樣本,也可以是其他方式,如通過文件名獲取樣本等。len(self): 返回?cái)?shù)據(jù)集中樣本的數(shù)量。
import torch from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data): self.data = data def __getitem__(self, index): # 根據(jù)索引獲取樣本 return self.data[index] def __len__(self): # 返回?cái)?shù)據(jù)集大小 return len(self.data) # 創(chuàng)建數(shù)據(jù)集對(duì)象 data = [1, 2, 3, 4, 5] dataset = MyDataset(data) # 根據(jù)索引獲取樣本 sample = dataset[2] print(sample) # 3
上面的代碼樣例主要實(shí)現(xiàn)的是一個(gè) 自定義Dataset數(shù)據(jù)集類
的方法,這一般都是在我們需要訓(xùn)練自己的數(shù)據(jù)時(shí)候需要定義的。但是一般我們作為深度學(xué)習(xí)初學(xué)者來講,使用的都是MNIST、CIFAR-10等 內(nèi)置數(shù)據(jù)集
,這時(shí)候就不需要再自己定義Dataset類了。至于為什么,我們下面進(jìn)行詳解。
2、torchvision.datasets
如果要使用PyTorch中的內(nèi)置數(shù)據(jù)集,通常是通過 torchvision.datasets
模塊來實(shí)現(xiàn)。 torchvision.datasets
模塊提供了許多常用的計(jì)算機(jī)視覺數(shù)據(jù)集,如MNIST、CIFAR10、ImageNet等。
下面是使用內(nèi)置數(shù)據(jù)集的示例代碼:
import torch from torchvision import datasets, transforms # 定義數(shù)據(jù)轉(zhuǎn)換 transform = transforms.Compose([ transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量 transforms.Normalize((0.5,), (0.5,)) # 標(biāo)準(zhǔn)化圖像 ]) # 加載MNIST數(shù)據(jù)集 train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
在上述代碼中,我們實(shí)現(xiàn)的便是一個(gè)內(nèi)置MNIST(手寫數(shù)字)數(shù)據(jù)集的加載和使用。可以看到,我們?cè)谶@里面并未用到上面所提到的 torch.utils.data.Dataset
類,這是為什么呢?
這是因?yàn)樵?torchvision.datasets
模塊中,內(nèi)置的數(shù)據(jù)集類已經(jīng)實(shí)現(xiàn)了 torch.utils.data.Dataset
接口,并直接返回一個(gè)可用的數(shù)據(jù)集對(duì)象。因此,在使用內(nèi)置數(shù)據(jù)集時(shí),我們可以直接實(shí)例化內(nèi)置數(shù)據(jù)集類,而不需要顯式地繼承 torch.utils.data.Dataset
類。
內(nèi)置數(shù)據(jù)集類(如 torchvision.datasets.MNIST
)的實(shí)現(xiàn)已經(jīng)包含了對(duì) __getitem__
和 __len__
方法的定義,這使得我們可以直接從內(nèi)置數(shù)據(jù)集對(duì)象中獲取樣本和確定數(shù)據(jù)集的大小。這樣,我們?cè)谑褂脙?nèi)置數(shù)據(jù)集時(shí)可以直接將內(nèi)置數(shù)據(jù)集對(duì)象傳遞給 torch.utils.data.DataLoader
進(jìn)行數(shù)據(jù)加載和批量處理。
在內(nèi)置數(shù)據(jù)集的背后,它們?nèi)匀皇腔?torch.utils.data.Dataset
類進(jìn)行實(shí)現(xiàn),只是為了方便使用和提供更多功能,PyTorch 將這些常用數(shù)據(jù)集封裝成了內(nèi)置的數(shù)據(jù)集類。
為此,我專門到pytorch官網(wǎng)去查看了該內(nèi)置數(shù)據(jù)集的加載代碼,如下圖所示:
可以看出,確實(shí)以及內(nèi)置了Dataset數(shù)據(jù)集類。
3、DataLoader
torch.utils.data.DataLoader
是 PyTorch 中用于批量加載數(shù)據(jù)的工具類。它接受一個(gè)數(shù)據(jù)集對(duì)象(如 torch.utils.data.Dataset
的子類)并提供多種功能,如數(shù)據(jù)加載、批量處理、數(shù)據(jù)打亂等。
以下是 torch.utils.data.DataLoader
的常用參數(shù)和功能:
dataset
: 數(shù)據(jù)集對(duì)象,可以是torch.utils.data.Dataset
的子類對(duì)象。batch_size
: 每個(gè)批次的樣本數(shù)量,默認(rèn)為 1。shuffle
: 是否對(duì)數(shù)據(jù)進(jìn)行打亂,默認(rèn)為False
。在每個(gè) epoch 時(shí)會(huì)重新打亂數(shù)據(jù)。num_workers
: 使用多少個(gè)子進(jìn)程加載數(shù)據(jù),默認(rèn)為 0,表示在主進(jìn)程中加載數(shù)據(jù)。其實(shí)在Windows系統(tǒng)里面都設(shè)置為0,但是在Linux中可以設(shè)置成大于0的數(shù)。collate_fn
: 在返回批次數(shù)據(jù)之前,對(duì)每個(gè)樣本進(jìn)行處理的函數(shù)。如果為None
,默認(rèn)使用torch.utils.data._utils.collate.default_collate
函數(shù)進(jìn)行處理。drop_last
: 是否丟棄最后一個(gè)樣本數(shù)量不足一個(gè)批次的數(shù)據(jù),默認(rèn)為False
。pin_memory
: 是否將加載的數(shù)據(jù)存放在 CUDA 對(duì)應(yīng)的固定內(nèi)存中,默認(rèn)為False
。prefetch_factor
: 預(yù)取因子,用于預(yù)取數(shù)據(jù)到設(shè)備,默認(rèn)為 2。persistent_workers
: 如果為True
,則在每個(gè) epoch 中使用持久的子進(jìn)程進(jìn)行數(shù)據(jù)加載,默認(rèn)為False
。
示例代碼如下:
import torch from torchvision import datasets, transforms # 定義數(shù)據(jù)轉(zhuǎn)換 transform = transforms.Compose([ transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量 transforms.Normalize((0.5,), (0.5,)) # 標(biāo)準(zhǔn)化圖像 ]) # 加載MNIST數(shù)據(jù)集 train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 創(chuàng)建數(shù)據(jù)加載器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4) # 使用數(shù)據(jù)加載器迭代樣本 for images, labels in train_loader: # 訓(xùn)練模型的代碼 ...
4、torchvision.transforms
torchvision.transforms
模塊是PyTorch中用于圖像數(shù)據(jù)預(yù)處理的功能模塊。它提供了一系列的轉(zhuǎn)換函數(shù),用于在加載、訓(xùn)練或推斷圖像數(shù)據(jù)時(shí)進(jìn)行各種常見的數(shù)據(jù)變換和增強(qiáng)操作。下面是一些常用的轉(zhuǎn)換函數(shù)的詳細(xì)解釋:
Resize:調(diào)整圖像大小
Resize(size)
:將圖像調(diào)整為給定的尺寸??梢越邮芤粋€(gè)整數(shù)作為較短邊的大小,也可以接受一個(gè)元組或列表作為圖像的目標(biāo)大小。
ToTensor:將圖像轉(zhuǎn)換為張量
ToTensor()
:將圖像轉(zhuǎn)換為張量,像素值范圍從0-255映射到0-1。適用于將圖像數(shù)據(jù)傳遞給深度學(xué)習(xí)模型。
Normalize:標(biāo)準(zhǔn)化圖像數(shù)據(jù)
Normalize(mean, std)
:對(duì)圖像數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化處理。傳入的mean和std是用于像素值歸一化的均值和標(biāo)準(zhǔn)差。需要注意的是,mean和std需要與之前使用的數(shù)據(jù)集相對(duì)應(yīng)。
RandomHorizontalFlip:隨機(jī)水平翻轉(zhuǎn)圖像
RandomHorizontalFlip(p=0.5)
:以給定的概率對(duì)圖像進(jìn)行隨機(jī)水平翻轉(zhuǎn)。概率p控制翻轉(zhuǎn)的概率,默認(rèn)為0.5。
RandomCrop:隨機(jī)裁剪圖像
RandomCrop(size, padding=None)
:隨機(jī)裁剪圖像為給定的尺寸。可以提供一個(gè)元組或整數(shù)作為目標(biāo)尺寸,并可選地提供填充值。
ColorJitter:顏色調(diào)整
ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
:隨機(jī)調(diào)整圖像的亮度、對(duì)比度、飽和度和色調(diào)??梢酝ㄟ^設(shè)置不同的參數(shù)來調(diào)整圖像的樣貌。
在使用的時(shí)候,我們常常通過 transforms.Compose
來對(duì)這些數(shù)據(jù)處理操作進(jìn)行一個(gè)組合,使用的時(shí)候,直接調(diào)用該組合即可。
示例代碼如下:
from torchvision import transforms # 定義圖像預(yù)處理操作 transform = transforms.Compose([ transforms.Resize((256, 256)), # 縮放圖像大小為 (256, 256) transforms.RandomCrop((224, 224)), # 隨機(jī)裁剪圖像為 (224, 224) transforms.RandomHorizontalFlip(), # 隨機(jī)水平翻轉(zhuǎn)圖像 transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標(biāo)準(zhǔn)化圖像 ]) # 對(duì)圖像進(jìn)行預(yù)處理 image = transform(image)
5、圖像分類中Dataset數(shù)據(jù)集類的定義
就拿眼疾數(shù)據(jù)集來說(詳細(xì)可看深度學(xué)習(xí)實(shí)戰(zhàn)基礎(chǔ)案例——卷積神經(jīng)網(wǎng)絡(luò)(CNN)基于SqueezeNet的眼疾識(shí)別|第1例),其中我們對(duì)數(shù)據(jù)集進(jìn)行標(biāo)簽劃分以后,生成了train.txt以及valid.txt文件,該文件中分別為兩列,第一列為數(shù)據(jù)集的路徑,第二列為數(shù)據(jù)集的標(biāo)簽(也就是類別),具體如下:
這時(shí)候我們就可以定義自己的數(shù)據(jù)集讀取類,具體代碼如下:
import os.path from PIL import Image from torch.utils.data import DataLoader, Dataset from torchvision.transforms import transforms transform_BZ = transforms.Normalize( mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] ) class MyDataset(Dataset): def __init__(self, txt_path, train_flag=True): self.imgs_info = self.get_images(txt_path) self.train_flag = train_flag self.train_tf = transforms.Compose([ transforms.Resize(224), # 調(diào)整圖像大小為224x224 transforms.RandomHorizontalFlip(), # 隨機(jī)左右翻轉(zhuǎn)圖像 transforms.RandomVerticalFlip(), # 隨機(jī)上下翻轉(zhuǎn)圖像 transforms.ToTensor(), # 將PIL Image或numpy.ndarray轉(zhuǎn)換為tensor,并歸一化到[0,1]之間 transform_BZ # 執(zhí)行某些復(fù)雜變換操作 ]) self.val_tf = transforms.Compose([ transforms.Resize(224), # 調(diào)整圖像大小為224x224 transforms.ToTensor(), # 將PIL Image或numpy.ndarray轉(zhuǎn)換為tensor,并歸一化到[0,1]之間 transform_BZ # 執(zhí)行某些復(fù)雜變換操作 ]) def get_images(self, txt_path): with open(txt_path, 'r', encoding='utf-8') as f: imgs_info = f.readlines() imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info)) return imgs_info def __getitem__(self, index): img_path, label = self.imgs_info[index] img_path = os.path.join('', img_path) img = Image.open(img_path) img = img.convert("RGB") if self.train_flag: img = self.train_tf(img) else: img = self.val_tf(img) label = int(label) return img, label def __len__(self): return len(self.imgs_info)
定義完我們自己的數(shù)據(jù)集讀取類以后,就可以將我們的txt文件傳入進(jìn)行數(shù)據(jù)集的預(yù)處理以及讀取工作。在我們的自定義dataset類里面,最重要的三個(gè)方法是__init__()、getitem()以及__len__(),這三個(gè)缺一不可。同時(shí),transforms的數(shù)據(jù)增強(qiáng)操作也不是必須的,這不過是提高模型性能的一個(gè)方法而已,但是我們現(xiàn)在的模型訓(xùn)練過程一般都會(huì)加上數(shù)據(jù)增強(qiáng)操作。
# 加載訓(xùn)練集和驗(yàn)證集 train_data = MyDataset(r"F:\SqueezeNet\train.txt", True) train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True, shuffle=True, num_workers=0) test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False) test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True, shuffle=True, num_workers=0)
上面,我們通過自定義的MyDataset類,分別加載了我們的train.txt文件以及valid.txt文件(后面的True參數(shù)代表我們要進(jìn)行訓(xùn)練集的數(shù)據(jù)增強(qiáng),而False代表進(jìn)行驗(yàn)證集的數(shù)據(jù)增強(qiáng))。然后,我們?cè)偻ㄟ^我們的DataLoader來進(jìn)行數(shù)據(jù)集的批量加載,之后就可以直接把加載好的 train_dl
和 test_dl
扔進(jìn)模型里面訓(xùn)練。
具體實(shí)例可參考:
深度學(xué)習(xí)實(shí)戰(zhàn)基礎(chǔ)案例——卷積神經(jīng)網(wǎng)絡(luò)(CNN)基于SqueezeNet的眼疾識(shí)別|第1例
Xception算法解析-鳥類識(shí)別實(shí)戰(zhàn)-Paddle實(shí)戰(zhàn)
到此這篇關(guān)于Pytorch的torch.utils.data中Dataset以及DataLoader等詳解的文章就介紹到這了,更多相關(guān)Pytorch Dataset及DataLoader內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python數(shù)據(jù)結(jié)構(gòu)之棧、隊(duì)列及二叉樹定義與用法淺析
這篇文章主要介紹了Python數(shù)據(jù)結(jié)構(gòu)之棧、隊(duì)列及二叉樹定義與用法,結(jié)合具體實(shí)例形式分析了Python數(shù)據(jù)結(jié)構(gòu)中棧、隊(duì)列及二叉樹的定義與使用相關(guān)操作技巧,需要的朋友可以參考下2018-12-12python格式化字符串的實(shí)戰(zhàn)教程(使用占位符、format方法)
我們經(jīng)常會(huì)用到%-formatting和str.format()來格式化,下面這篇文章主要給大家介紹了關(guān)于python格式化字符串的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-08-08Python3.0 實(shí)現(xiàn)決策樹算法的流程
這篇文章主要介紹了Python3.0 實(shí)現(xiàn)決策樹算法的流程,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-08-08解決python3.6 右鍵沒有 Edit with IDLE的問題
這篇文章主要介紹了解決python3.6 右鍵沒有 Edit with IDLE的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-03-03