PyTorch實(shí)現(xiàn)重寫/改寫Dataset并載入Dataloader
前言
眾所周知,Dataset和Dataloder是pytorch中進(jìn)行數(shù)據(jù)載入的部件。必須將數(shù)據(jù)載入后,再進(jìn)行深度學(xué)習(xí)模型的訓(xùn)練。在pytorch的一些案例教學(xué)中,常使用torchvision.datasets
自帶的MNIST、CIFAR-10數(shù)據(jù)集,一般流程為:
# 下載并存放數(shù)據(jù)集 train_dataset = torchvision.datasets.CIFAR10(root="數(shù)據(jù)集存放位置",download=True) # load數(shù)據(jù) train_loader = torch.utils.data.DataLoader(dataset=train_dataset)
但是,在我們自己的模型訓(xùn)練中,需要使用非官方自制的數(shù)據(jù)集。這時(shí)應(yīng)該怎么辦呢?
我們可以通過改寫torch.utils.data.Dataset
中的__getitem__
和__len__
來載入我們自己的數(shù)據(jù)集。
__getitem__
獲取數(shù)據(jù)集中的數(shù)據(jù),__len__
獲取整個(gè)數(shù)據(jù)集的長度(即個(gè)數(shù))。
改寫
采用pytorch官網(wǎng)案例中提供的一個(gè)臉部landmark數(shù)據(jù)集。數(shù)據(jù)集中含有存放landmark的csv文件,但是我們在這篇文章中不使用(其實(shí)也可以隨便下載一些圖片作數(shù)據(jù)集來實(shí)驗(yàn))。
import os import torch from skimage import io, transform import numpy as np import matplotlib.pyplot as plt from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils plt.ion() # interactive mode
torch.utils.data.Dataset
是一個(gè)抽象類,我們自己的數(shù)據(jù)集需要繼承Dataset
,然后改寫上述兩個(gè)函數(shù):
class ImageLoader(Dataset): def __init__(self, file_path, transform=None): super(ImageLoader,self).__init__() self.file_path = file_path self.transform = transform # 對輸入圖像進(jìn)行預(yù)處理,這里并沒有做,預(yù)設(shè)為None self.image_names = os.listdir(self.file_path) # 文件名的列表 def __getitem__(self,idx): image = self.image_names[idx] image = io.imread(os.path.join(self.file_path,image)) # if self.transform: # image= self.transform(image) return image def __len__(self): return len(self.image_names) # 設(shè)置自己存放的數(shù)據(jù)集位置,并plot展示 imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\") # imageloader.__len__() # 輸出數(shù)據(jù)集長度(個(gè)數(shù)),應(yīng)為71 # print(imageloader.__getitem__(0)) # 以數(shù)據(jù)形式展示 plt.imshow(imageloader.__getitem__(0)) # 以圖像形式展示 plt.show()
得到的圖片輸出:
得到的數(shù)據(jù)輸出,:
array([[[ 66, 59, 53], [ 66, 59, 53], [ 66, 59, 53], ..., [ 59, 54, 48], [ 59, 54, 48], [ 59, 54, 48]], ..., [153, 141, 129], [158, 146, 134], [158, 146, 134]]], dtype=uint8)
上面看到dytpe=uint8
,實(shí)際進(jìn)行訓(xùn)練的時(shí)候,常常需要更改成float
的數(shù)據(jù)類型。可以使用:
# 直接改成pytorch中的tensor下的float格式 # 也可以用numpy的改成普通的float格式 to_float= torch.from_numpy(imageloader.__getitem__(0)).float()
改寫完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)
載入到Dataloader
中,就可以使用了。
下面的代碼可以試著運(yùn)行一下,產(chǎn)生的是一模一樣的圖片結(jié)果。
train_loader = torch.utils.data.DataLoader(dataset=imageloader) train_loader.dataset[0] plt.imshow(train_loader.dataset[0]) plt.show()
到此這篇關(guān)于PyTorch實(shí)現(xiàn)重寫/改寫Dataset并載入Dataloader的文章就介紹到這了,更多相關(guān)PyTorch重寫/改寫Dataset 內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python 隨機(jī)打亂 圖片和對應(yīng)的標(biāo)簽方法
今天小編就為大家分享一篇python 隨機(jī)打亂 圖片和對應(yīng)的標(biāo)簽方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-12-12Python數(shù)據(jù)結(jié)構(gòu)與算法之列表(鏈表,linked list)簡單實(shí)現(xiàn)
這篇文章主要介紹了Python數(shù)據(jù)結(jié)構(gòu)與算法之列表(鏈表,linked list)簡單實(shí)現(xiàn),具有一定參考價(jià)值,需要的朋友可以了解下。2017-10-10python中將txt文件轉(zhuǎn)換為csv文件的三種方法舉例
對于大數(shù)據(jù)的處理基本都是以CSV文件為基礎(chǔ)進(jìn)行的,那么在進(jìn)行深度學(xué)習(xí)的處理之前,需要先統(tǒng)一數(shù)據(jù)文件的格式,下面這篇文章主要給大家介紹了關(guān)于python中將txt文件轉(zhuǎn)換為csv文件的三種方法,需要的朋友可以參考下2024-06-06基于Python的Post請求數(shù)據(jù)爬取的方法詳解
這篇文章主要介紹了基于Python的Post請求數(shù)據(jù)爬取的方法,需要的朋友可以參考下2019-06-06python3 實(shí)現(xiàn)在運(yùn)行的時(shí)候隱藏命令窗口
這篇文章主要介紹了python3 實(shí)現(xiàn)在運(yùn)行的時(shí)候隱藏命令窗口方式,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-05-05Python如何做點(diǎn)擊率數(shù)據(jù)預(yù)測
這篇文章主要介紹了Python做點(diǎn)擊率數(shù)據(jù)預(yù)測,在這個(gè)場景中,我們通常需要根據(jù)用戶的歷史行為、物品的特征、上下文信息等因素來預(yù)測用戶點(diǎn)擊某個(gè)特定物品(如廣告、推薦商品)的概率,需要的朋友可以參考下2024-06-06基于Python創(chuàng)建語音識(shí)別控制系統(tǒng)
這篇文章主要介紹了通過Python實(shí)現(xiàn)創(chuàng)建語音識(shí)別控制系統(tǒng),能利用語音識(shí)別識(shí)別說出來的文字,根據(jù)文字的內(nèi)容來控制圖形移動(dòng),感興趣的同學(xué)可以關(guān)注一下2021-12-12Python常見數(shù)據(jù)類型轉(zhuǎn)換操作示例
這篇文章主要介紹了Python常見數(shù)據(jù)類型轉(zhuǎn)換操作,結(jié)合實(shí)例形式分析了Python針對列表、集合、元組、字典等數(shù)據(jù)類型轉(zhuǎn)換的相關(guān)操作技巧,需要的朋友可以參考下2019-05-05