欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch建模過程中的DataLoader與Dataset示例詳解

 更新時間:2023年01月05日 08:51:28   作者:奧辰  
這篇文章主要介紹了Pytorch建模過程中的DataLoader與Dataset,同時PyTorch針對不同的專業(yè)領(lǐng)域,也提供有不同的模塊,例如?TorchText,?TorchVision,?TorchAudio,這些模塊中也都包含一些真實數(shù)據(jù)集示例,本文給大家介紹的非常詳細,需要的朋友參考下吧

處理數(shù)據(jù)樣本的代碼會因為處理過程繁雜而變得混亂且難以維護,在理想情況下,我們希望數(shù)據(jù)預(yù)處理過程代碼與我們的模型訓(xùn)練代碼分離,以獲得更好的可讀性和模塊化,為此,PyTorch提供了torch.utils.data.DataLoader 和 torch.utils.data.Dataset兩個類用于數(shù)據(jù)處理。其中torch.utils.data.DataLoader用于將數(shù)據(jù)集進行打包封裝成一個可迭代對象,torch.utils.data.Dataset存儲有一些常用的數(shù)據(jù)集示例以及相關(guān)標(biāo)簽。

同時PyTorch針對不同的專業(yè)領(lǐng)域,也提供有不同的模塊,例如 TorchText(自然語言處理), TorchVision(計算機視覺), TorchAudio(音頻),這些模塊中也都包含一些真實數(shù)據(jù)集示例。例如TorchVision模塊中提供了CIFAR, COCO, FashionMNIST 數(shù)據(jù)集。

1 定義數(shù)據(jù)集

pytorch中提供兩種風(fēng)格的數(shù)據(jù)集定義方式:

  • 字典映射風(fēng)格。之所以稱為映射風(fēng)格,是因為在后續(xù)加載數(shù)據(jù)迭代時,pytorch將自動使用迭代索引作為key,通過字典索引的方式獲取value,本質(zhì)就是將數(shù)據(jù)集定義為一個字典,使用這種風(fēng)格時,需要繼承Dataset類。

In [54]:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [56]:

dataset = {0: '張三', 1:'李四', 2:'王五', 3:'趙六', 4:'陳七'}
dataloader = DataLoader(dataset, batch_size=2)
for i, value in enumerate(dataloader):
    print(i, value)
0 ['張三', '李四']
1 ['王五', '趙六']
2 ['陳七']
  • 迭代器風(fēng)格。在自定義數(shù)據(jù)集類中,實現(xiàn)__iter____next__方法,即定義為迭代器,在后續(xù)加載數(shù)據(jù)迭代時,pytorch將依次獲取value,使用這種風(fēng)格時,需要繼承IterableDataset類。這種方法在數(shù)據(jù)量巨大,無法一下全部加載到內(nèi)存時非常實用。

In [57]:

from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset

In [58]:

dataset = [i for i in range(10)]
dataloader = DataLoader(dataset=dataset, batch_size=3, shuffle=True) 
for i, item in enumerate(dataloader): # 迭代輸出
    print(i, item)
0 tensor([3, 1, 2])
1 tensor([9, 7, 5])
2 tensor([0, 8, 4])
3 tensor([6])

如下所示,我們有一個螞蟻蜜蜂圖像分類數(shù)據(jù)集,目錄結(jié)構(gòu)如下所示,下面我們結(jié)合這個數(shù)據(jù)集,分別介紹如何使用這兩個類定義真實數(shù)據(jù)集。

data
└── hymenoptera_data
    ├── train
    │?? ├── ants
    │?? │?? ├── 0013035.jpg
    │   │   ……
    │?? └── bees
    │??     ├── 1092977343_cb42b38d62.jpg
    │       ……
    └── val
        ├── ants
        │?? ├── 10308379_1b6c72e180.jpg
        │?? ……
        └── bees
            ├── 1032546534_06907fe3b3.jpg
            ……

1.2 Dataset類

自定義一個Dataset類,繼承torch.utils.data.Dataset,且必須實現(xiàn)下面三個方法:

  • Dataset類里面的__init__函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件。

  • Dataset類里面的__getitem__函數(shù),映射取值是調(diào)用的方法,獲取單個的數(shù)據(jù),訓(xùn)練迭代時將會調(diào)用這個方法。

  • Dataset類里面的__len__函數(shù)獲取數(shù)據(jù)的總量。

In [211]:

import os
import pandas as pd
from PIL import Image
from torchvision.transforms import ToTensor, Lambda
from torchvision import transforms
import torchvision
class AntBeeDataset(Dataset):
    # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, transform=None, target_transform=None):
        """
        root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data
        transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize
        target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        
        # 獲取文件夾下所有圖片的名稱和對應(yīng)的標(biāo)簽
        self.img_lst = []
        for label in ['ants', 'bees']:
            path = os.path.join(root_dir, label)
            for img_name in os.listdir(path):
                self.img_lst.append((os.path.join(root_dir, label, img_name), label))
        
    def __getitem__(self, idx):
        img_path, label = self.img_lst[idx]
        img = Image.open(img_path).convert('RGB')
        
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)
        # 這個地方要注意,我們在計算loss的時候用交叉熵nn.CrossEntropyLoss()
        # 交叉熵的輸入有兩個,一個是模型的輸出outputs,一個是標(biāo)簽targets,注意targets是一維tensor
        # 例如batchsize如果是2,ants的targets的應(yīng)該[0,0],而不是[[0][0]]
        # 因此label要返回0,而不是[0]
        return img, label

    def __len__(self):
        return len(self.img_lst)

In [310]:

train_transform = transforms.Compose([
    
    transforms.RandomResizedCrop(224),  # 將給定圖像隨機裁剪為不同的大小和寬高比,然后縮放所裁剪得到的圖像為制定的大小
    transforms.RandomHorizontalFlip(),  # 以給定的概率隨機水平旋轉(zhuǎn)給定的PIL的圖像,默認為0.5
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# 驗證集并不需要做與訓(xùn)練集相同的處理,所有,通常使用更加簡單的transformer
val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# 根據(jù)標(biāo)簽?zāi)夸浀拿Q來確定圖片是哪一類,如果是"ants",標(biāo)簽設(shè)置為0,如果是"bees",標(biāo)簽設(shè)置為1
target_transform = transforms.Lambda(lambda y: 0 if y == "ants" else 1)

In [311]:

train_dataset = AntBeeDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform)
val_dataset = AntBeeDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)

1.2 Dataset數(shù)據(jù)集常用操作

1. 查看數(shù)據(jù)集大小:

In [221]:

len(train_dataset), len(val_dataset)

Out[221]:

(245, 153)

2. 合并數(shù)據(jù)集

In [222]:

dataset = train_dataset + val_dataset

In [223]:

len(dataset)

Out[223]:

398

3. 劃分訓(xùn)練集、測試集

In [224]:

from torch.utils.data import random_split
# random_split 不能直接使用百分比劃分,必須指定具體數(shù)字
train_size = int( len(dataset) * 0.8)
test_size = len(dataset) - train_size

In [225]:

train_dataset, val_dataset = random_split(dataset, [train_size, test_size])

In [226]:

len(train_dataset), len(val_dataset)

Out[226]:

(318, 80)

1.3 IterableDataset類

使用迭代器風(fēng)格時,必須繼承IterableDataset類,且實現(xiàn)下面兩個方法:

  • __init__,函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件,在數(shù)據(jù)量過大時,通常只是獲取操作句柄、數(shù)據(jù)庫連接。

  • __iter__,獲取迭代器。

雖然只需要實現(xiàn)這兩個方法,但是通常還需要在迭代過程中對數(shù)據(jù)進行處理。IterableDataset類實現(xiàn)自定義數(shù)據(jù)集,本質(zhì)就是創(chuàng)建一個數(shù)據(jù)集類,且實現(xiàn)__iter__返回一個迭代器。一下提供兩種方法通過IterableDataset類自定義數(shù)據(jù)集:

方法一:

In [289]:

class AntBeeIterableDataset(IterableDataset):
    # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, transform=None, target_transform=None):
        """
        root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data
        transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize
        target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        
        # 獲取文件夾下所有圖片的名稱和對應(yīng)的標(biāo)簽
        self.img_lst = []
        for label in ['ants', 'bees']:
            path = os.path.join(root_dir, label)
            for img_name in os.listdir(path):
                self.img_lst.append((os.path.join(root_dir, label, img_name), label))
                
    def __iter__(self):
        for img_path, label in self.img_lst:
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            if self.target_transform:
                label = self.target_transform(label)
            yield img, label

方法二:

In [285]:

class AntBeeIterableDataset(IterableDataset):
    # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸?,這是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, transform=None, target_transform=None):
        """
        root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data
        transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize
        target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        
        # 獲取文件夾下所有圖片的名稱和對應(yīng)的標(biāo)簽
        self.img_lst = []
        for label in ['ants', 'bees']:
            path = os.path.join(root_dir, label)
            for img_name in os.listdir(path):
                self.img_lst.append((os.path.join(root_dir, label, img_name), label))
        self.index = 0
                
    def __iter__(self):
        return self
    
    def __next__(self):
        try:
            img_path, label = self.img_lst[self.index]
            self.index += 1
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            if self.target_transform:
                label = self.target_transform(label)
            return img, label
        except IndexError:
            raise StopIteration()

In [290]:

train_dataset = AntBeeIterableDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform)
val_dataset = AntBeeIterableDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)

在處理大數(shù)據(jù)集時,IterableDataset會比Dataset更有優(yōu)勢,例如數(shù)據(jù)存儲在文件或者數(shù)據(jù)庫中,只需要在自定義的IterableDataset之類中獲取文件操作句柄或者數(shù)據(jù)庫連接和游標(biāo)驚喜迭代,每次只返回一條數(shù)據(jù)即可。我們把上文中螞蟻蜜蜂數(shù)據(jù)集的所有圖片、標(biāo)簽這里后寫入hymenoptera_data.txt中,內(nèi)容如下所示,假設(shè)有數(shù)億行,那么,就不能直接將數(shù)據(jù)加載到內(nèi)存了:

data/hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg, ants
data/hymenoptera_data/train/ants/2278278459_6b99605e50.jpg, ants
data/hymenoptera_data/train/ants/543417860_b14237f569.jpg, ants
...
...

可以參考一下方式定義IterableDataset子類:

In [299]:

class AntBeeIterableDataset(IterableDataset):
    # 把圖片所在的文件夾路徑分成兩個部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因為標(biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, filepath, transform=None, target_transform=None):
        """
        filepath:hymenoptera_data.txt完整路徑
        transform: 對圖像數(shù)據(jù)進行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進行resize
        target_transform:對標(biāo)簽數(shù)據(jù)進行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.filepath = filepath
        self.transform = transform
        self.target_transform = target_transform

                
    def __iter__(self):
        with open(self.filepath, 'r') as f:
            for line in f:
                img_path, label = line.replace('\n', '').split(', ')
                img = Image.open(img_path).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                if self.target_transform:
                    label = self.target_transform(label)
                yield img, label

In [307]:

train_dataset = AntBeeIterableDataset('hymenoptera_data.txt', transform=train_transform, target_transform=target_transform)

注意,IterableDataset方法在處理大數(shù)據(jù)集時確實比Dataset更有優(yōu)勢,但是,IterableDataset在迭代過程中,樣本輸出順序是固定的,在使用DataLoader進行加載時,無法使用shuffle進行打亂,同時,因為在IterableDataset中并未強制限定必須實現(xiàn)__len__()方法(很多時候確實也沒法獲取數(shù)據(jù)總量),不能通過len()方法獲取數(shù)據(jù)總量。

2 DataLoad

DataLoader的功能是構(gòu)建可迭代的數(shù)據(jù)裝載器,在訓(xùn)練的時候,每一個for循環(huán),每一次Iteration,就是從DataLoader中獲取一個batch_size大小的數(shù)據(jù),節(jié)省內(nèi)存的同時,它還可以實現(xiàn)多進程、數(shù)據(jù)打亂等處理。我們通過一張圖來了解DataLoader數(shù)據(jù)讀取機制:

首先,在for循環(huán)中使用了DataLoader,進入DataLoader后,首先根據(jù)是否使用多進程DataLoaderIter,做出判斷之后單線程還是多線程,接著使用Sampler得索引Index,然后將索引給到DatasetFetcher,在這里面調(diào)用Dataset,根據(jù)索引,通過getitem得到實際的數(shù)據(jù)和標(biāo)簽,得到一個batch size大小的數(shù)據(jù)后,通過collate_fn函數(shù)整理成一個Batch Data的形式輸入到模型去訓(xùn)練。

在pytorch建模的數(shù)據(jù)處理、加載流程中,DataLoader應(yīng)該算是最核心的一步操作DataLoader有很多參數(shù),這里我們列出常用的幾個:

  • dataset:表示Dataset類,它決定了數(shù)據(jù)從哪讀取以及如何讀?。?/li>
  • batch_size:表示批大??;
  • num_works:表示是否多進程讀取數(shù)據(jù);
  • shuffle:表示每個epoch是否亂序;
  • drop_last:表示當(dāng)樣本數(shù)不能被batch_size整除時,是否舍棄最后一批數(shù)據(jù);
  • num_workers:啟動多少個進程來加載數(shù)據(jù)。

我們重點說說多進程模式下使用DataLoader,在多進程模式下,每次 DataLoader 創(chuàng)建 iterator 時(遍歷DataLoader時,例如,當(dāng)調(diào)用時enumerate(dataloader)),都會創(chuàng)建 num_workers 工作進程。dataset, collate_fn, worker_init_fn 都會被傳到每個worker中,每個worker都用獨立的進程。

對于映射風(fēng)格的數(shù)據(jù)集,即Dataset子類,主線程會用Sampler(采樣器)產(chǎn)生indice,并將它們送到進程里。因此,shuffle是在主線程做的

對于迭代器風(fēng)格的數(shù)據(jù)集,即IterableDataset子類,因為每個進程都有相同的data復(fù)制樣本,并在各個進程里進行不同的操作,以防止每個進程輸出的數(shù)據(jù)是重復(fù)的,所以一般用 torch.utils.data.get_worker_info() 來進行輔助處理。

這里,torch.utils.data.get_worker_info() 返回worker進程的一些信息(id, dataset, num_workers, seed),如果在主線程跑的話返回None

注意,通常不建議在多進程加載中返回CUDA張量,因為在使用CUDA和在多處理中共享CUDA張量時存在許多微妙之處(文檔中提出:只要接收過程保留張量的副本,就需要發(fā)送過程來保留原始張量)。建議采用 pin_memory=True ,以將數(shù)據(jù)快速傳輸?shù)街С諧UDA的GPU。簡而言之,不建議在使用多線程的情況下返回CUDA的tensor。

In [313]:

dataload = DataLoader(train_dataset, batch_size=2)

In [315]:

img, label = next(iter(dataload))

In [316]:

img.shape, label

Out[316]:

(torch.Size([2, 3, 224, 224]), tensor([0, 0]))

到此這篇關(guān)于Pytorch建模過程中的DataLoader與Dataset的文章就介紹到這了,更多相關(guān)Pytorch建模內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Django 登陸驗證碼和中間件的實現(xiàn)

    Django 登陸驗證碼和中間件的實現(xiàn)

    這篇文章主要介紹了Django 登陸驗證碼和中間件的實現(xiàn),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-08-08
  • python處理變量交換與字符串及判斷的小妙招

    python處理變量交換與字符串及判斷的小妙招

    本文記錄一些 Python 日常編程中的小妙招,并使用 IPython 進行交互測試,讓我們更好的了解和學(xué)習(xí) Python 的一些特性,對大家的學(xué)習(xí)或工作具有一定的價值,需要的朋友可以參考下
    2021-09-09
  • Pipenv輕量級虛擬環(huán)境管理工具使用指南

    Pipenv輕量級虛擬環(huán)境管理工具使用指南

    這篇文章主要為大家介紹了Pipenv輕量級虛擬環(huán)境管理工具使用指南,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-02-02
  • Python?中enum的使用方法總結(jié)

    Python?中enum的使用方法總結(jié)

    這篇文章主要介紹了Python?中enum的使用方法總結(jié),枚舉在許多編程語言中常被表示為一種基礎(chǔ)的數(shù)據(jù)結(jié)構(gòu)使用,下文更多詳細內(nèi)容需要的小伙伴可以參考一下
    2022-03-03
  • Python多繼承以及MRO順序的使用

    Python多繼承以及MRO順序的使用

    這篇文章主要介紹了Python多繼承以及MRO順序的使用,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-11-11
  • python中的線程threading.Thread()使用詳解

    python中的線程threading.Thread()使用詳解

    這篇文章主要介紹了python中的線程threading.Thread()使用詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-12-12
  • Python海象運算符代碼分析及知識點總結(jié)

    Python海象運算符代碼分析及知識點總結(jié)

    在本篇內(nèi)容里小編給大家總結(jié)了關(guān)于Python海象運算符的使用的相關(guān)內(nèi)容及代碼,有興趣的朋友們跟著學(xué)習(xí)下。
    2022-11-11
  • python lambda表達式(匿名函數(shù))寫法解析

    python lambda表達式(匿名函數(shù))寫法解析

    這篇文章主要介紹了python lambda表達式(匿名函數(shù))寫法解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-09-09
  • 使用python制作一個壓縮圖片小程序

    使用python制作一個壓縮圖片小程序

    這篇文章主要為大家詳細介紹了如何使用python制作一個壓縮圖片小程序,文中的示例代碼簡潔易懂,具有一定的學(xué)習(xí)價值,感興趣的小伙伴可以了解下
    2023-10-10
  • python實現(xiàn)自動化辦公郵件合并功能

    python實現(xiàn)自動化辦公郵件合并功能

    這篇文章主要介紹了python實現(xiàn)自動化辦公郵件合并功能,本文通過實例代碼給大家介紹的非常詳細,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2021-07-07

最新評論