Pytorch的torch.utils.data中Dataset以及DataLoader示例詳解
在我們進行深度學習的過程中,不免要用到數(shù)據(jù)集,那么數(shù)據(jù)集是如何加載到我們的模型中進行訓練的呢?以往我們大多數(shù)初學者肯定都是拿網(wǎng)上的代碼直接用,但是它底層的原理到底是什么還是不太清楚。所以今天就從內(nèi)置的Dataset函數(shù)和自定義的Dataset函數(shù)做一個詳細的解析。
前言
torch.utils.data 是 PyTorch 提供的一個模塊,用于處理和加載數(shù)據(jù)。該模塊提供了一系列工具類和函數(shù),用于創(chuàng)建、操作和批量加載數(shù)據(jù)集。
下面是 torch.utils.data 模塊中一些常用的類和函數(shù):
Dataset: 定義了抽象的數(shù)據(jù)集類,用戶可以通過繼承該類來構(gòu)建自己的數(shù)據(jù)集。Dataset類提供了兩個必須實現(xiàn)的方法:__getitem__用于訪問單個樣本,__len__用于返回數(shù)據(jù)集的大小。TensorDataset: 繼承自Dataset類,用于將張量數(shù)據(jù)打包成數(shù)據(jù)集。它接受多個張量作為輸入,并按照第一個輸入張量的大小來確定數(shù)據(jù)集的大小。DataLoader: 數(shù)據(jù)加載器類,用于批量加載數(shù)據(jù)集。它接受一個數(shù)據(jù)集對象作為輸入,并提供多種數(shù)據(jù)加載和預處理的功能,如設置批量大小、多線程數(shù)據(jù)加載和數(shù)據(jù)打亂等。Subset: 數(shù)據(jù)集的子集類,用于從數(shù)據(jù)集中選擇指定的樣本。random_split: 將一個數(shù)據(jù)集隨機劃分為多個子集,可以指定劃分的比例或指定每個子集的大小。ConcatDataset: 將多個數(shù)據(jù)集連接在一起形成一個更大的數(shù)據(jù)集。get_worker_info: 獲取當前數(shù)據(jù)加載器所在的進程信息。
除了上述的類和函數(shù)之外, torch.utils.data 還提供了一些常用的數(shù)據(jù)預處理的工具,如隨機裁剪、隨機旋轉(zhuǎn)、標準化等。
通過 torch.utils.data 模塊提供的類和函數(shù),可以方便地加載、處理和批量加載數(shù)據(jù),為模型訓練和驗證提供了便利。但是,我們最常用的兩個類還是 Dataset 和 DataLoader 類。
1、自定義Dataset類
torch.utils.data.Dataset 是 PyTorch 中用于表示數(shù)據(jù)集的抽象類,用于定義數(shù)據(jù)集的訪問方式和樣本數(shù)量。
Dataset 類是一個基類,我們可以通過繼承該類并實現(xiàn)下面兩個方法來創(chuàng)建自定義的數(shù)據(jù)集類:
getitem(self, index): 根據(jù)給定的索引 index,返回對應的樣本數(shù)據(jù)。索引可以是一個整數(shù),表示按順序獲取樣本,也可以是其他方式,如通過文件名獲取樣本等。len(self): 返回數(shù)據(jù)集中樣本的數(shù)量。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
# 根據(jù)索引獲取樣本
return self.data[index]
def __len__(self):
# 返回數(shù)據(jù)集大小
return len(self.data)
# 創(chuàng)建數(shù)據(jù)集對象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 根據(jù)索引獲取樣本
sample = dataset[2]
print(sample)
# 3上面的代碼樣例主要實現(xiàn)的是一個 自定義Dataset數(shù)據(jù)集類 的方法,這一般都是在我們需要訓練自己的數(shù)據(jù)時候需要定義的。但是一般我們作為深度學習初學者來講,使用的都是MNIST、CIFAR-10等 內(nèi)置數(shù)據(jù)集 ,這時候就不需要再自己定義Dataset類了。至于為什么,我們下面進行詳解。
2、torchvision.datasets
如果要使用PyTorch中的內(nèi)置數(shù)據(jù)集,通常是通過 torchvision.datasets 模塊來實現(xiàn)。 torchvision.datasets 模塊提供了許多常用的計算機視覺數(shù)據(jù)集,如MNIST、CIFAR10、ImageNet等。
下面是使用內(nèi)置數(shù)據(jù)集的示例代碼:
import torch
from torchvision import datasets, transforms
# 定義數(shù)據(jù)轉(zhuǎn)換
transform = transforms.Compose([
transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量
transforms.Normalize((0.5,), (0.5,)) # 標準化圖像
])
# 加載MNIST數(shù)據(jù)集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)在上述代碼中,我們實現(xiàn)的便是一個內(nèi)置MNIST(手寫數(shù)字)數(shù)據(jù)集的加載和使用??梢钥吹剑覀冊谶@里面并未用到上面所提到的 torch.utils.data.Dataset 類,這是為什么呢?
這是因為在 torchvision.datasets 模塊中,內(nèi)置的數(shù)據(jù)集類已經(jīng)實現(xiàn)了 torch.utils.data.Dataset 接口,并直接返回一個可用的數(shù)據(jù)集對象。因此,在使用內(nèi)置數(shù)據(jù)集時,我們可以直接實例化內(nèi)置數(shù)據(jù)集類,而不需要顯式地繼承 torch.utils.data.Dataset 類。
內(nèi)置數(shù)據(jù)集類(如 torchvision.datasets.MNIST )的實現(xiàn)已經(jīng)包含了對 __getitem__ 和 __len__ 方法的定義,這使得我們可以直接從內(nèi)置數(shù)據(jù)集對象中獲取樣本和確定數(shù)據(jù)集的大小。這樣,我們在使用內(nèi)置數(shù)據(jù)集時可以直接將內(nèi)置數(shù)據(jù)集對象傳遞給 torch.utils.data.DataLoader 進行數(shù)據(jù)加載和批量處理。
在內(nèi)置數(shù)據(jù)集的背后,它們?nèi)匀皇腔?torch.utils.data.Dataset 類進行實現(xiàn),只是為了方便使用和提供更多功能,PyTorch 將這些常用數(shù)據(jù)集封裝成了內(nèi)置的數(shù)據(jù)集類。
為此,我專門到pytorch官網(wǎng)去查看了該內(nèi)置數(shù)據(jù)集的加載代碼,如下圖所示:

可以看出,確實以及內(nèi)置了Dataset數(shù)據(jù)集類。
3、DataLoader
torch.utils.data.DataLoader 是 PyTorch 中用于批量加載數(shù)據(jù)的工具類。它接受一個數(shù)據(jù)集對象(如 torch.utils.data.Dataset 的子類)并提供多種功能,如數(shù)據(jù)加載、批量處理、數(shù)據(jù)打亂等。
以下是 torch.utils.data.DataLoader 的常用參數(shù)和功能:
dataset: 數(shù)據(jù)集對象,可以是torch.utils.data.Dataset的子類對象。batch_size: 每個批次的樣本數(shù)量,默認為 1。shuffle: 是否對數(shù)據(jù)進行打亂,默認為False。在每個 epoch 時會重新打亂數(shù)據(jù)。num_workers: 使用多少個子進程加載數(shù)據(jù),默認為 0,表示在主進程中加載數(shù)據(jù)。其實在Windows系統(tǒng)里面都設置為0,但是在Linux中可以設置成大于0的數(shù)。collate_fn: 在返回批次數(shù)據(jù)之前,對每個樣本進行處理的函數(shù)。如果為None,默認使用torch.utils.data._utils.collate.default_collate函數(shù)進行處理。drop_last: 是否丟棄最后一個樣本數(shù)量不足一個批次的數(shù)據(jù),默認為False。pin_memory: 是否將加載的數(shù)據(jù)存放在 CUDA 對應的固定內(nèi)存中,默認為False。prefetch_factor: 預取因子,用于預取數(shù)據(jù)到設備,默認為 2。persistent_workers: 如果為True,則在每個 epoch 中使用持久的子進程進行數(shù)據(jù)加載,默認為False。
示例代碼如下:
import torch
from torchvision import datasets, transforms
# 定義數(shù)據(jù)轉(zhuǎn)換
transform = transforms.Compose([
transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量
transforms.Normalize((0.5,), (0.5,)) # 標準化圖像
])
# 加載MNIST數(shù)據(jù)集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 創(chuàng)建數(shù)據(jù)加載器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# 使用數(shù)據(jù)加載器迭代樣本
for images, labels in train_loader:
# 訓練模型的代碼
...4、torchvision.transforms
torchvision.transforms 模塊是PyTorch中用于圖像數(shù)據(jù)預處理的功能模塊。它提供了一系列的轉(zhuǎn)換函數(shù),用于在加載、訓練或推斷圖像數(shù)據(jù)時進行各種常見的數(shù)據(jù)變換和增強操作。下面是一些常用的轉(zhuǎn)換函數(shù)的詳細解釋:
Resize:調(diào)整圖像大小
Resize(size):將圖像調(diào)整為給定的尺寸??梢越邮芤粋€整數(shù)作為較短邊的大小,也可以接受一個元組或列表作為圖像的目標大小。
ToTensor:將圖像轉(zhuǎn)換為張量
ToTensor():將圖像轉(zhuǎn)換為張量,像素值范圍從0-255映射到0-1。適用于將圖像數(shù)據(jù)傳遞給深度學習模型。
Normalize:標準化圖像數(shù)據(jù)
Normalize(mean, std):對圖像數(shù)據(jù)進行標準化處理。傳入的mean和std是用于像素值歸一化的均值和標準差。需要注意的是,mean和std需要與之前使用的數(shù)據(jù)集相對應。
RandomHorizontalFlip:隨機水平翻轉(zhuǎn)圖像
RandomHorizontalFlip(p=0.5):以給定的概率對圖像進行隨機水平翻轉(zhuǎn)。概率p控制翻轉(zhuǎn)的概率,默認為0.5。
RandomCrop:隨機裁剪圖像
RandomCrop(size, padding=None):隨機裁剪圖像為給定的尺寸??梢蕴峁┮粋€元組或整數(shù)作為目標尺寸,并可選地提供填充值。
ColorJitter:顏色調(diào)整
ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):隨機調(diào)整圖像的亮度、對比度、飽和度和色調(diào)??梢酝ㄟ^設置不同的參數(shù)來調(diào)整圖像的樣貌。
在使用的時候,我們常常通過 transforms.Compose 來對這些數(shù)據(jù)處理操作進行一個組合,使用的時候,直接調(diào)用該組合即可。
示例代碼如下:
from torchvision import transforms
# 定義圖像預處理操作
transform = transforms.Compose([
transforms.Resize((256, 256)), # 縮放圖像大小為 (256, 256)
transforms.RandomCrop((224, 224)), # 隨機裁剪圖像為 (224, 224)
transforms.RandomHorizontalFlip(), # 隨機水平翻轉(zhuǎn)圖像
transforms.ToTensor(), # 將圖像轉(zhuǎn)換為張量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 標準化圖像
])
# 對圖像進行預處理
image = transform(image)5、圖像分類中Dataset數(shù)據(jù)集類的定義
就拿眼疾數(shù)據(jù)集來說(詳細可看深度學習實戰(zhàn)基礎案例——卷積神經(jīng)網(wǎng)絡(CNN)基于SqueezeNet的眼疾識別|第1例),其中我們對數(shù)據(jù)集進行標簽劃分以后,生成了train.txt以及valid.txt文件,該文件中分別為兩列,第一列為數(shù)據(jù)集的路徑,第二列為數(shù)據(jù)集的標簽(也就是類別),具體如下:

這時候我們就可以定義自己的數(shù)據(jù)集讀取類,具體代碼如下:
import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
transform_BZ = transforms.Normalize(
mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]
)
class MyDataset(Dataset):
def __init__(self, txt_path, train_flag=True):
self.imgs_info = self.get_images(txt_path)
self.train_flag = train_flag
self.train_tf = transforms.Compose([
transforms.Resize(224), # 調(diào)整圖像大小為224x224
transforms.RandomHorizontalFlip(), # 隨機左右翻轉(zhuǎn)圖像
transforms.RandomVerticalFlip(), # 隨機上下翻轉(zhuǎn)圖像
transforms.ToTensor(), # 將PIL Image或numpy.ndarray轉(zhuǎn)換為tensor,并歸一化到[0,1]之間
transform_BZ # 執(zhí)行某些復雜變換操作
])
self.val_tf = transforms.Compose([
transforms.Resize(224), # 調(diào)整圖像大小為224x224
transforms.ToTensor(), # 將PIL Image或numpy.ndarray轉(zhuǎn)換為tensor,并歸一化到[0,1]之間
transform_BZ # 執(zhí)行某些復雜變換操作
])
def get_images(self, txt_path):
with open(txt_path, 'r', encoding='utf-8') as f:
imgs_info = f.readlines()
imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
return imgs_info
def __getitem__(self, index):
img_path, label = self.imgs_info[index]
img_path = os.path.join('', img_path)
img = Image.open(img_path)
img = img.convert("RGB")
if self.train_flag:
img = self.train_tf(img)
else:
img = self.val_tf(img)
label = int(label)
return img, label
def __len__(self):
return len(self.imgs_info)定義完我們自己的數(shù)據(jù)集讀取類以后,就可以將我們的txt文件傳入進行數(shù)據(jù)集的預處理以及讀取工作。在我們的自定義dataset類里面,最重要的三個方法是__init__()、getitem()以及__len__(),這三個缺一不可。同時,transforms的數(shù)據(jù)增強操作也不是必須的,這不過是提高模型性能的一個方法而已,但是我們現(xiàn)在的模型訓練過程一般都會加上數(shù)據(jù)增強操作。
# 加載訓練集和驗證集
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
shuffle=True, num_workers=0)上面,我們通過自定義的MyDataset類,分別加載了我們的train.txt文件以及valid.txt文件(后面的True參數(shù)代表我們要進行訓練集的數(shù)據(jù)增強,而False代表進行驗證集的數(shù)據(jù)增強)。然后,我們再通過我們的DataLoader來進行數(shù)據(jù)集的批量加載,之后就可以直接把加載好的 train_dl 和 test_dl 扔進模型里面訓練。
具體實例可參考:
深度學習實戰(zhàn)基礎案例——卷積神經(jīng)網(wǎng)絡(CNN)基于SqueezeNet的眼疾識別|第1例
Xception算法解析-鳥類識別實戰(zhàn)-Paddle實戰(zhàn)
到此這篇關于Pytorch的torch.utils.data中Dataset以及DataLoader等詳解的文章就介紹到這了,更多相關Pytorch Dataset及DataLoader內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python數(shù)據(jù)結(jié)構(gòu)之棧、隊列及二叉樹定義與用法淺析
這篇文章主要介紹了Python數(shù)據(jù)結(jié)構(gòu)之棧、隊列及二叉樹定義與用法,結(jié)合具體實例形式分析了Python數(shù)據(jù)結(jié)構(gòu)中棧、隊列及二叉樹的定義與使用相關操作技巧,需要的朋友可以參考下2018-12-12
python格式化字符串的實戰(zhàn)教程(使用占位符、format方法)
我們經(jīng)常會用到%-formatting和str.format()來格式化,下面這篇文章主要給大家介紹了關于python格式化字符串的相關資料,文中通過實例代碼介紹的非常詳細,需要的朋友可以參考下2022-08-08
解決python3.6 右鍵沒有 Edit with IDLE的問題
這篇文章主要介紹了解決python3.6 右鍵沒有 Edit with IDLE的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-03-03

