Python中的Dataset和Dataloader詳解
Dataset,Dataloader是什么?
- Dataset:負責可被Pytorch使用的數(shù)據(jù)集的創(chuàng)建
- Dataloader:向模型中傳遞數(shù)據(jù)
為什么要了解Dataloader
? 因為你的神經(jīng)網(wǎng)絡(luò)表現(xiàn)不佳的主要原因之一可能是由于數(shù)據(jù)不佳或理解不足。
因此,以更直觀的方式理解、預(yù)處理數(shù)據(jù)并將其加載到網(wǎng)絡(luò)中非常重要。
? 通常,我們在默認或知名數(shù)據(jù)集(如 MNIST 或 CIFAR)上訓練神經(jīng)網(wǎng)絡(luò),可以輕松地實現(xiàn)預(yù)測和分類類型問題的超過 90% 的準確度。
但是那是因為這些數(shù)據(jù)集組織整齊且易于預(yù)處理。
但是處理自己的數(shù)據(jù)集時,我們常常無法達到這樣高的準確率
Dataloader 的使用
載入相關(guān)類
from torch.utils.data import Dataloader
設(shè)置相關(guān)參數(shù)
from torch.utils.data import DataLoader DataLoader( dataset, batch_size=1, shuffle=False, num_workers=0, collate_fn=None, pin_memory=False, ) """ dataset:是數(shù)據(jù)集 batch_size:是指一次迭代中使用的訓練樣本數(shù)。通常我們將數(shù)據(jù)分成訓練集和測試集,并且我們可能有不同的批量大小。 shuffle:是傳遞給 DataLoader 類的另一個參數(shù)。該參數(shù)采用布爾值(真/假)。如果 shuffle 設(shè)置為 True,則所有樣本都被打亂并分批加載。否則,它們會被一個接一個地發(fā)送,而不會進行任何洗牌。 num_workers:允許多處理來增加同時運行的進程數(shù) collate_fn:合并數(shù)據(jù)集 pin_memory:鎖頁內(nèi)存:將張量固定在內(nèi)存中 """
以minist為例子
# Import MNIST from torchvision.datasets import MNIST # Download and Save MNIST data_train = MNIST('~/mnist_data', train=True, download=True) # Print Data print(data_train) print(data_train[12]) #Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)
現(xiàn)在讓嘗試提取元組,其中第一個值對應(yīng)于圖像,第二個值對應(yīng)于其各自的標簽。
下面是代碼片段:
import matplotlib.pyplot as plt random_image = data_train[0][0] random_image_label = data_train[0][1] # Print the Image using Matplotlib plt.imshow(random_image) print("The label of the image is:", random_image_label)
讓我們使用 DataLoader 類來加載數(shù)據(jù)集,如下所示。
import torch from torchvision import transforms data_train = torch.utils.data.DataLoader( MNIST( '~/mnist_data', train=True, download=True, transform = transforms.Compose([ transforms.ToTensor() ])), batch_size=64, shuffle=True ) for batch_idx, samples in enumerate(data_train): print(batch_idx, samples)
這就是我們使用 DataLoader 加載簡單數(shù)據(jù)集的方式。 但是,我們不能總是對每個數(shù)據(jù)集都依賴已經(jīng)有的數(shù)據(jù)集,要是自己的數(shù)據(jù)集怎么辦。
定義自己的數(shù)據(jù)集
我們將創(chuàng)建一個由數(shù)字和文本組成的簡單自定義數(shù)據(jù)集
先介紹兩個方法
#__getitem__() 方法通過索引返回數(shù)據(jù)集中選定的樣本。 #__len__() 方法返回數(shù)據(jù)集的總大小。例如,如果您的數(shù)據(jù)集包含 1,00,000 個樣本,則 len 方法應(yīng)返回 1,00,000。 class Dataset(object): def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError
? 創(chuàng)建自定義數(shù)據(jù)集并不復(fù)雜,但作為加載數(shù)據(jù)的典型過程的附加步驟,有必要構(gòu)建一個接口以獲得良好的抽象(至少可以說是一個很好的語法糖)。
現(xiàn)在我們將創(chuàng)建一個包含數(shù)字及其平方值的新數(shù)據(jù)集。 讓我們將數(shù)據(jù)集稱為 SquareDataset。 其目的是返回 [a,b] 范圍內(nèi)的值的平方。
下面是相關(guān)代碼:
import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision import datasets, transforms class SquareDataset(Dataset): def __init__(self, a=0, b=1): super(Dataset, self).__init__() assert a <= b self.a = a self.b = b def __len__(self): return self.b - self.a + 1 def __getitem__(self, index): assert self.a <= index <= self.b return index, index**2 data_train = SquareDataset(a=1,b=64) data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True) print(len(data_train))
? 在上面的代碼塊中,我們創(chuàng)建了一個名為 SquareDataset 的 Python 類,它繼承了 PyTorch 的 Dataset 類。
接下來,我們調(diào)用了一個 init() 構(gòu)造函數(shù),其中 a 和 b 分別被初始化為 0 和 1。 超類用于從繼承的 Dataset 類中訪問 len 和 get_item 方法。
接下來我們使用 assert 語句來檢查 a 是否小于或等于 b,因為我們想要創(chuàng)建一個數(shù)據(jù)集,其中值將位于 a 和 b 之間。
? 然后,我們使用 SquareDataset 類創(chuàng)建了一個數(shù)據(jù)集,其中數(shù)據(jù)值的范圍為 1 到 64。我們將其加載到名為 data_train 的變量中。
最后,Dataloader 類在 data_train_loader 中存儲的數(shù)據(jù)上創(chuàng)建了一個迭代器,batch_size 初始化為 64,shuffle 設(shè)置為 True。
如何使用transform
? 當你學會怎么定義自己的數(shù)據(jù)集的時候,你可能會想要更近 一步的操作,對于你自己的數(shù)據(jù)集進行剪切或者變換
? 以CIFAR10為例子
- 將所有圖像調(diào)整為 32×32
- 對圖像應(yīng)用中心裁剪變換
- 將裁剪后的圖像轉(zhuǎn)換為張量
- 標準化圖像
導入必要的模塊
import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as np
接下來,我們將定義一個名為 transforms 的變量,我們在其中按順序編寫所有預(yù)處理步驟。我們使用 Compose 類將所有轉(zhuǎn)換操作鏈接在一起。
transform = transforms.Compose([ # resize transforms.Resize(32), # center-crop transforms.CenterCrop(32), # to-tensor transforms.ToTensor(), # normalize transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) """ resize:此調(diào)整大小轉(zhuǎn)換將所有圖像轉(zhuǎn)換為定義的大小。在這種情況下,我們要將所有圖像的大小調(diào)整為 32×32。因此,我們將 32 作為參數(shù)傳遞。 center-crop:接下來,我們使用 CenterCrop 變換裁剪圖像。 我們發(fā)送的參數(shù)也是分辨率/大小,但由于我們已經(jīng)將圖像大小調(diào)整為 32x32,因此圖像將與此裁剪中心對齊。 這意味著圖像將從中心裁剪 32 個單位(垂直和水平)。 to-tensor:我們使用 ToTensor() 方法將圖像轉(zhuǎn)換為張量數(shù)據(jù)類型。 normalize:這將張量中的所有值歸一化,使它們位于 0.5 和 1 之間。 """
在下一步中,在執(zhí)行我們剛剛定義的轉(zhuǎn)換之后,我們將使用 trainloader 將 CIFAR 數(shù)據(jù)集加載到訓練集中。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False)
到此這篇關(guān)于Python中的Dataset和Dataloader詳解的文章就介紹到這了,更多相關(guān)Dataset和Dataloader詳解內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實現(xiàn)多條件篩選目標數(shù)據(jù)功能【測試可用】
這篇文章主要介紹了Python實現(xiàn)多條件篩選目標數(shù)據(jù)功能,結(jié)合實例形式總結(jié)分析了Python3使用內(nèi)建函數(shù)filter、pandas包以及for循環(huán)三種方法對比分析了列表進行條件篩選操作相關(guān)實現(xiàn)技巧與運行效率,需要的朋友可以參考下2018-06-06深入解析Python中BeautifulSoup4的基礎(chǔ)知識與實戰(zhàn)應(yīng)用
BeautifulSoup4正是一款功能強大的解析器,能夠輕松解析HTML和XML文檔,本文將介紹BeautifulSoup4的基礎(chǔ)知識,并通過實際代碼示例進行演示,感興趣的可以了解下2024-02-02Tensorflow 實現(xiàn)線性回歸模型的示例代碼
這篇文章主要介紹了Tensorflow 實現(xiàn)線性回歸模型,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2022-05-05Python中關(guān)于面向?qū)ο笾欣^承的詳細講解
面向?qū)ο缶幊?(OOP) 語言的一個主要功能就是“繼承”。繼承是指這樣一種能力:它可以使用現(xiàn)有類的所有功能,并在無需重新編寫原來的類的情況下對這些功能進行擴展2021-10-10淺談多卡服務(wù)器下隱藏部分 GPU 和 TensorFlow 的顯存使用設(shè)置
這篇文章主要介紹了淺談多卡服務(wù)器下隱藏部分 GPU 和 TensorFlow 的顯存使用設(shè)置,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06