欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch的torch.utils.data中Dataset以及DataLoader示例詳解

 更新時(shí)間:2023年08月23日 15:05:55   作者:心無旁騖~  
torch.utils.data?是?PyTorch?提供的一個(gè)模塊,用于處理和加載數(shù)據(jù),該模塊提供了一系列工具類和函數(shù),用于創(chuàng)建、操作和批量加載數(shù)據(jù)集,這篇文章主要介紹了Pytorch的torch.utils.data中Dataset以及DataLoader等詳解,需要的朋友可以參考下

在我們進(jìn)行深度學(xué)習(xí)的過程中,不免要用到數(shù)據(jù)集,那么數(shù)據(jù)集是如何加載到我們的模型中進(jìn)行訓(xùn)練的呢?以往我們大多數(shù)初學(xué)者肯定都是拿網(wǎng)上的代碼直接用,但是它底層的原理到底是什么還是不太清楚。所以今天就從內(nèi)置的Dataset函數(shù)和自定義的Dataset函數(shù)做一個(gè)詳細(xì)的解析。

前言

torch.utils.data PyTorch 提供的一個(gè)模塊,用于處理和加載數(shù)據(jù)。該模塊提供了一系列工具類和函數(shù),用于創(chuàng)建、操作和批量加載數(shù)據(jù)集。

下面是 torch.utils.data 模塊中一些常用的類和函數(shù):

  • Dataset : 定義了抽象的數(shù)據(jù)集類,用戶可以通過繼承該類來構(gòu)建自己的數(shù)據(jù)集。 Dataset 類提供了兩個(gè)必須實(shí)現(xiàn)的方法: __getitem__ 用于訪問單個(gè)樣本, __len__ 用于返回?cái)?shù)據(jù)集的大小。
  • TensorDataset : 繼承自 Dataset 類,用于將張量數(shù)據(jù)打包成數(shù)據(jù)集。它接受多個(gè)張量作為輸入,并按照第一個(gè)輸入張量的大小來確定數(shù)據(jù)集的大小。
  • DataLoader : 數(shù)據(jù)加載器類,用于批量加載數(shù)據(jù)集。它接受一個(gè)數(shù)據(jù)集對(duì)象作為輸入,并提供多種數(shù)據(jù)加載和預(yù)處理的功能,如設(shè)置批量大小、多線程數(shù)據(jù)加載和數(shù)據(jù)打亂等。
  • Subset : 數(shù)據(jù)集的子集類,用于從數(shù)據(jù)集中選擇指定的樣本。
  • random_split : 將一個(gè)數(shù)據(jù)集隨機(jī)劃分為多個(gè)子集,可以指定劃分的比例或指定每個(gè)子集的大小。
  • ConcatDataset : 將多個(gè)數(shù)據(jù)集連接在一起形成一個(gè)更大的數(shù)據(jù)集。
  • get_worker_info : 獲取當(dāng)前數(shù)據(jù)加載器所在的進(jìn)程信息。

除了上述的類和函數(shù)之外, torch.utils.data 還提供了一些常用的數(shù)據(jù)預(yù)處理的工具,如隨機(jī)裁剪、隨機(jī)旋轉(zhuǎn)、標(biāo)準(zhǔn)化等。

通過 torch.utils.data 模塊提供的類和函數(shù),可以方便地加載、處理和批量加載數(shù)據(jù),為模型訓(xùn)練和驗(yàn)證提供了便利。但是,我們最常用的兩個(gè)類還是 Dataset DataLoader 類。

1、自定義Dataset類

torch.utils.data.Dataset 是 PyTorch 中用于表示數(shù)據(jù)集的抽象類,用于定義數(shù)據(jù)集的訪問方式和樣本數(shù)量。

Dataset 類是一個(gè)基類,我們可以通過繼承該類并實(shí)現(xiàn)下面兩個(gè)方法來創(chuàng)建自定義的數(shù)據(jù)集類:

getitem(self, index): 根據(jù)給定的索引 index,返回對(duì)應(yīng)的樣本數(shù)據(jù)。索引可以是一個(gè)整數(shù),表示按順序獲取樣本,也可以是其他方式,如通過文件名獲取樣本等。len(self): 返回?cái)?shù)據(jù)集中樣本的數(shù)量。

import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __getitem__(self, index):
        # 根據(jù)索引獲取樣本
        return self.data[index]
    def __len__(self):
        # 返回?cái)?shù)據(jù)集大小
        return len(self.data)
# 創(chuàng)建數(shù)據(jù)集對(duì)象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 根據(jù)索引獲取樣本
sample = dataset[2]
print(sample)
# 3

上面的代碼樣例主要實(shí)現(xiàn)的是一個(gè) 自定義Dataset數(shù)據(jù)集類 的方法,這一般都是在我們需要訓(xùn)練自己的數(shù)據(jù)時(shí)候需要定義的。但是一般我們作為深度學(xué)習(xí)初學(xué)者來講,使用的都是MNIST、CIFAR-10等 內(nèi)置數(shù)據(jù)集 ,這時(shí)候就不需要再自己定義Dataset類了。至于為什么,我們下面進(jìn)行詳解。

2、torchvision.datasets

如果要使用PyTorch中的內(nèi)置數(shù)據(jù)集,通常是通過 torchvision.datasets 模塊來實(shí)現(xiàn)。 torchvision.datasets 模塊提供了許多常用的計(jì)算機(jī)視覺數(shù)據(jù)集,如MNIST、CIFAR10、ImageNet等。

下面是使用內(nèi)置數(shù)據(jù)集的示例代碼:

import torch
from torchvision import datasets, transforms
# 定義數(shù)據(jù)轉(zhuǎn)換
transform = transforms.Compose([
    transforms.ToTensor(),  # 將圖像轉(zhuǎn)換為張量
    transforms.Normalize((0.5,), (0.5,))  # 標(biāo)準(zhǔn)化圖像
])
# 加載MNIST數(shù)據(jù)集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

在上述代碼中,我們實(shí)現(xiàn)的便是一個(gè)內(nèi)置MNIST(手寫數(shù)字)數(shù)據(jù)集的加載和使用。可以看到,我們?cè)谶@里面并未用到上面所提到的 torch.utils.data.Dataset 類,這是為什么呢?

這是因?yàn)樵?torchvision.datasets 模塊中,內(nèi)置的數(shù)據(jù)集類已經(jīng)實(shí)現(xiàn)了 torch.utils.data.Dataset 接口,并直接返回一個(gè)可用的數(shù)據(jù)集對(duì)象。因此,在使用內(nèi)置數(shù)據(jù)集時(shí),我們可以直接實(shí)例化內(nèi)置數(shù)據(jù)集類,而不需要顯式地繼承 torch.utils.data.Dataset 類。

內(nèi)置數(shù)據(jù)集類(如 torchvision.datasets.MNIST )的實(shí)現(xiàn)已經(jīng)包含了對(duì) __getitem__ __len__ 方法的定義,這使得我們可以直接從內(nèi)置數(shù)據(jù)集對(duì)象中獲取樣本和確定數(shù)據(jù)集的大小。這樣,我們?cè)谑褂脙?nèi)置數(shù)據(jù)集時(shí)可以直接將內(nèi)置數(shù)據(jù)集對(duì)象傳遞給 torch.utils.data.DataLoader 進(jìn)行數(shù)據(jù)加載和批量處理。

在內(nèi)置數(shù)據(jù)集的背后,它們?nèi)匀皇腔?torch.utils.data.Dataset 類進(jìn)行實(shí)現(xiàn),只是為了方便使用和提供更多功能,PyTorch 將這些常用數(shù)據(jù)集封裝成了內(nèi)置的數(shù)據(jù)集類。

為此,我專門到pytorch官網(wǎng)去查看了該內(nèi)置數(shù)據(jù)集的加載代碼,如下圖所示:

在這里插入圖片描述

可以看出,確實(shí)以及內(nèi)置了Dataset數(shù)據(jù)集類。

3、DataLoader

torch.utils.data.DataLoader 是 PyTorch 中用于批量加載數(shù)據(jù)的工具類。它接受一個(gè)數(shù)據(jù)集對(duì)象(如 torch.utils.data.Dataset 的子類)并提供多種功能,如數(shù)據(jù)加載、批量處理、數(shù)據(jù)打亂等。

以下是 torch.utils.data.DataLoader 的常用參數(shù)和功能:

  • dataset : 數(shù)據(jù)集對(duì)象,可以是 torch.utils.data.Dataset 的子類對(duì)象。
  • batch_size : 每個(gè)批次的樣本數(shù)量,默認(rèn)為 1。 shuffle : 是否對(duì)數(shù)據(jù)進(jìn)行打亂,默認(rèn)為 False 。在每個(gè) epoch 時(shí)會(huì)重新打亂數(shù)據(jù)。
  • num_workers : 使用多少個(gè)子進(jìn)程加載數(shù)據(jù),默認(rèn)為 0,表示在主進(jìn)程中加載數(shù)據(jù)。其實(shí)在Windows系統(tǒng)里面都設(shè)置為0,但是在Linux中可以設(shè)置成大于0的數(shù)。 collate_fn : 在返回批次數(shù)據(jù)之前,對(duì)每個(gè)樣本進(jìn)行處理的函數(shù)。如果為 None ,默認(rèn)使用 torch.utils.data._utils.collate.default_collate 函數(shù)進(jìn)行處理。
  • drop_last : 是否丟棄最后一個(gè)樣本數(shù)量不足一個(gè)批次的數(shù)據(jù),默認(rèn)為 False 。
  • pin_memory : 是否將加載的數(shù)據(jù)存放在 CUDA 對(duì)應(yīng)的固定內(nèi)存中,默認(rèn)為 False 。
  • prefetch_factor : 預(yù)取因子,用于預(yù)取數(shù)據(jù)到設(shè)備,默認(rèn)為 2。 persistent_workers : 如果為 True ,則在每個(gè) epoch 中使用持久的子進(jìn)程進(jìn)行數(shù)據(jù)加載,默認(rèn)為 False 。

示例代碼如下:

import torch
from torchvision import datasets, transforms
# 定義數(shù)據(jù)轉(zhuǎn)換
transform = transforms.Compose([
    transforms.ToTensor(),  # 將圖像轉(zhuǎn)換為張量
    transforms.Normalize((0.5,), (0.5,))  # 標(biāo)準(zhǔn)化圖像
])
# 加載MNIST數(shù)據(jù)集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 創(chuàng)建數(shù)據(jù)加載器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# 使用數(shù)據(jù)加載器迭代樣本
for images, labels in train_loader:
    # 訓(xùn)練模型的代碼
    ...

4、torchvision.transforms

torchvision.transforms 模塊是PyTorch中用于圖像數(shù)據(jù)預(yù)處理的功能模塊。它提供了一系列的轉(zhuǎn)換函數(shù),用于在加載、訓(xùn)練或推斷圖像數(shù)據(jù)時(shí)進(jìn)行各種常見的數(shù)據(jù)變換和增強(qiáng)操作。下面是一些常用的轉(zhuǎn)換函數(shù)的詳細(xì)解釋:

Resize:調(diào)整圖像大小

  • Resize(size) :將圖像調(diào)整為給定的尺寸??梢越邮芤粋€(gè)整數(shù)作為較短邊的大小,也可以接受一個(gè)元組或列表作為圖像的目標(biāo)大小。

ToTensor:將圖像轉(zhuǎn)換為張量

  • ToTensor() :將圖像轉(zhuǎn)換為張量,像素值范圍從0-255映射到0-1。適用于將圖像數(shù)據(jù)傳遞給深度學(xué)習(xí)模型。

Normalize:標(biāo)準(zhǔn)化圖像數(shù)據(jù)

  • Normalize(mean, std) :對(duì)圖像數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化處理。傳入的mean和std是用于像素值歸一化的均值和標(biāo)準(zhǔn)差。需要注意的是,mean和std需要與之前使用的數(shù)據(jù)集相對(duì)應(yīng)。

RandomHorizontalFlip:隨機(jī)水平翻轉(zhuǎn)圖像

  • RandomHorizontalFlip(p=0.5) :以給定的概率對(duì)圖像進(jìn)行隨機(jī)水平翻轉(zhuǎn)。概率p控制翻轉(zhuǎn)的概率,默認(rèn)為0.5。

RandomCrop:隨機(jī)裁剪圖像

  • RandomCrop(size, padding=None) :隨機(jī)裁剪圖像為給定的尺寸。可以提供一個(gè)元組或整數(shù)作為目標(biāo)尺寸,并可選地提供填充值。

ColorJitter:顏色調(diào)整

  • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0) :隨機(jī)調(diào)整圖像的亮度、對(duì)比度、飽和度和色調(diào)??梢酝ㄟ^設(shè)置不同的參數(shù)來調(diào)整圖像的樣貌。

在使用的時(shí)候,我們常常通過 transforms.Compose 來對(duì)這些數(shù)據(jù)處理操作進(jìn)行一個(gè)組合,使用的時(shí)候,直接調(diào)用該組合即可。

示例代碼如下:

from torchvision import transforms
# 定義圖像預(yù)處理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 縮放圖像大小為 (256, 256)
    transforms.RandomCrop((224, 224)),  # 隨機(jī)裁剪圖像為 (224, 224)
    transforms.RandomHorizontalFlip(),  # 隨機(jī)水平翻轉(zhuǎn)圖像
    transforms.ToTensor(),  # 將圖像轉(zhuǎn)換為張量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 標(biāo)準(zhǔn)化圖像
])
# 對(duì)圖像進(jìn)行預(yù)處理
image = transform(image)

5、圖像分類中Dataset數(shù)據(jù)集類的定義

就拿眼疾數(shù)據(jù)集來說(詳細(xì)可看深度學(xué)習(xí)實(shí)戰(zhàn)基礎(chǔ)案例——卷積神經(jīng)網(wǎng)絡(luò)(CNN)基于SqueezeNet的眼疾識(shí)別|第1例),其中我們對(duì)數(shù)據(jù)集進(jìn)行標(biāo)簽劃分以后,生成了train.txt以及valid.txt文件,該文件中分別為兩列,第一列為數(shù)據(jù)集的路徑,第二列為數(shù)據(jù)集的標(biāo)簽(也就是類別),具體如下:

在這里插入圖片描述

這時(shí)候我們就可以定義自己的數(shù)據(jù)集讀取類,具體代碼如下:

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
transform_BZ = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5]
)
class MyDataset(Dataset):
    def __init__(self, txt_path, train_flag=True):
        self.imgs_info = self.get_images(txt_path)
        self.train_flag = train_flag
        self.train_tf = transforms.Compose([
            transforms.Resize(224),  # 調(diào)整圖像大小為224x224
            transforms.RandomHorizontalFlip(),  # 隨機(jī)左右翻轉(zhuǎn)圖像
            transforms.RandomVerticalFlip(),  # 隨機(jī)上下翻轉(zhuǎn)圖像
            transforms.ToTensor(),  # 將PIL Image或numpy.ndarray轉(zhuǎn)換為tensor,并歸一化到[0,1]之間
            transform_BZ  # 執(zhí)行某些復(fù)雜變換操作
        ])
        self.val_tf = transforms.Compose([
            transforms.Resize(224),  # 調(diào)整圖像大小為224x224
            transforms.ToTensor(),  # 將PIL Image或numpy.ndarray轉(zhuǎn)換為tensor,并歸一化到[0,1]之間
            transform_BZ  # 執(zhí)行某些復(fù)雜變換操作
        ])
    def get_images(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))
        return imgs_info
    def __getitem__(self, index):
        img_path, label = self.imgs_info[index]
        img_path = os.path.join('', img_path)
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.train_flag:
            img = self.train_tf(img)
        else:
            img = self.val_tf(img)
        label = int(label)
        return img, label
    def __len__(self):
        return len(self.imgs_info)

定義完我們自己的數(shù)據(jù)集讀取類以后,就可以將我們的txt文件傳入進(jìn)行數(shù)據(jù)集的預(yù)處理以及讀取工作。在我們的自定義dataset類里面,最重要的三個(gè)方法是__init__()、getitem()以及__len__(),這三個(gè)缺一不可。同時(shí),transforms的數(shù)據(jù)增強(qiáng)操作也不是必須的,這不過是提高模型性能的一個(gè)方法而已,但是我們現(xiàn)在的模型訓(xùn)練過程一般都會(huì)加上數(shù)據(jù)增強(qiáng)操作。

# 加載訓(xùn)練集和驗(yàn)證集
train_data = MyDataset(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)
test_data = MyDataset(r"F:\SqueezeNet\valid.txt", False)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,
                                           shuffle=True, num_workers=0)

上面,我們通過自定義的MyDataset類,分別加載了我們的train.txt文件以及valid.txt文件(后面的True參數(shù)代表我們要進(jìn)行訓(xùn)練集的數(shù)據(jù)增強(qiáng),而False代表進(jìn)行驗(yàn)證集的數(shù)據(jù)增強(qiáng))。然后,我們?cè)偻ㄟ^我們的DataLoader來進(jìn)行數(shù)據(jù)集的批量加載,之后就可以直接把加載好的 train_dl test_dl 扔進(jìn)模型里面訓(xùn)練。

具體實(shí)例可參考:

深度學(xué)習(xí)實(shí)戰(zhàn)基礎(chǔ)案例——卷積神經(jīng)網(wǎng)絡(luò)(CNN)基于SqueezeNet的眼疾識(shí)別|第1例

Xception算法解析-鳥類識(shí)別實(shí)戰(zhàn)-Paddle實(shí)戰(zhàn)

到此這篇關(guān)于Pytorch的torch.utils.data中Dataset以及DataLoader等詳解的文章就介紹到這了,更多相關(guān)Pytorch Dataset及DataLoader內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python獲取目錄下所有文件的方法

    python獲取目錄下所有文件的方法

    這篇文章主要介紹了python獲取目錄下所有文件的方法,實(shí)例分析了Python中os模塊下walk方法的使用技巧,需要的朋友可以參考下
    2015-06-06
  • python封裝成exe的超詳細(xì)教程

    python封裝成exe的超詳細(xì)教程

    相信很多人都很想把python文件封裝成exe文件,下面這篇文章主要給大家介紹了關(guān)于python封裝成exe的相關(guān)資料,文中通過圖文介紹的非常詳細(xì),需要的朋友可以參考下
    2022-06-06
  • 通過Python爬蟲代理IP快速增加博客閱讀量

    通過Python爬蟲代理IP快速增加博客閱讀量

    本文主要對(duì)通過Python爬蟲代理IP快速增加博客閱讀量的方法進(jìn)行分析介紹。具有很好的參考價(jià)值,需要的朋友一起來看下吧
    2016-12-12
  • pandas讀取csv文件提示不存在的解決方法及原因分析

    pandas讀取csv文件提示不存在的解決方法及原因分析

    這篇文章主要介紹了pandas讀取csv文件提示不存在的解決方法及原因分析,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04
  • opencv基于Haar人臉檢測(cè)和眼睛檢測(cè)

    opencv基于Haar人臉檢測(cè)和眼睛檢測(cè)

    這篇文章主要為大家詳細(xì)介紹了opencv基于Haar人臉檢測(cè)和眼睛檢測(cè),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2021-09-09
  • Python數(shù)據(jù)結(jié)構(gòu)之棧、隊(duì)列及二叉樹定義與用法淺析

    Python數(shù)據(jù)結(jié)構(gòu)之棧、隊(duì)列及二叉樹定義與用法淺析

    這篇文章主要介紹了Python數(shù)據(jù)結(jié)構(gòu)之棧、隊(duì)列及二叉樹定義與用法,結(jié)合具體實(shí)例形式分析了Python數(shù)據(jù)結(jié)構(gòu)中棧、隊(duì)列及二叉樹的定義與使用相關(guān)操作技巧,需要的朋友可以參考下
    2018-12-12
  • python讀取ini配置文件過程示范

    python讀取ini配置文件過程示范

    這篇文章主要介紹了python讀取ini配置文件過程示范,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-12-12
  • python格式化字符串的實(shí)戰(zhàn)教程(使用占位符、format方法)

    python格式化字符串的實(shí)戰(zhàn)教程(使用占位符、format方法)

    我們經(jīng)常會(huì)用到%-formatting和str.format()來格式化,下面這篇文章主要給大家介紹了關(guān)于python格式化字符串的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下
    2022-08-08
  • Python3.0 實(shí)現(xiàn)決策樹算法的流程

    Python3.0 實(shí)現(xiàn)決策樹算法的流程

    這篇文章主要介紹了Python3.0 實(shí)現(xiàn)決策樹算法的流程,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2019-08-08
  • 解決python3.6 右鍵沒有 Edit with IDLE的問題

    解決python3.6 右鍵沒有 Edit with IDLE的問題

    這篇文章主要介紹了解決python3.6 右鍵沒有 Edit with IDLE的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2021-03-03

最新評(píng)論