Pytorch中DataLoader的使用方法詳解
在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個(gè)函數(shù),用來(lái)處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。
一:dataset類構(gòu)建。
在構(gòu)建數(shù)據(jù)集類時(shí),除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個(gè)方法,這三個(gè)是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。
class dataset: def __init__(self,...): ... def __len__(self,...): return n def __getitem__(self,item): return data[item]
正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實(shí)際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。
在dataset類中,__len__(self)返回?cái)?shù)據(jù)集中數(shù)據(jù)個(gè)數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。
二:DataLoader使用
在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:
1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。
2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓(xùn)練時(shí)我們很少逐條數(shù)據(jù)訓(xùn)練,而是幾條數(shù)據(jù)作為一個(gè)batch進(jìn)行訓(xùn)練。
3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。
4.drop_last:是否舍去最后一個(gè)batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batch size不整除,導(dǎo)致最后一個(gè)batch不為batch size)。True為舍去,F(xiàn)alse反之。
三:舉例
兔兔以指標(biāo)為1,數(shù)據(jù)個(gè)數(shù)為100的數(shù)據(jù)為例。
import torch from torch.utils.data import DataLoader class dataset: def __init__(self): self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32) self.y=(torch.sin(self.x)+1)/2 def __len__(self): return 100 def __getitem__(self, item): return self.x[item],self.y[item] data=DataLoader(dataset(),batch_size=10,shuffle=True) for batch in data: print(batch)
當(dāng)然,利用這個(gè)數(shù)據(jù)集可以進(jìn)行簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練。
from torch import nn data=DataLoader(dataset(),batch_size=10,shuffle=True) bp=nn.Sequential(nn.Linear(1,5), nn.Sigmoid(), nn.Linear(5,1), nn.Sigmoid()) optim=torch.optim.Adam(params=bp.parameters()) Loss=nn.MSELoss() for epoch in range(10): print('the {} epoch'.format(epoch)) for batch in data: yp=bp(batch[0]) loss=Loss(yp,batch[1]) optim.zero_grad() loss.backward() optim.step()
ps:下面再給大家補(bǔ)充介紹下Pytorch中DataLoader的使用。
前言
最近開(kāi)始接觸pytorch,從跑別人寫好的代碼開(kāi)始,今天需要把輸入數(shù)據(jù)根據(jù)每個(gè)batch的最長(zhǎng)輸入數(shù)據(jù),填充到一樣的長(zhǎng)度(之前是將所有的數(shù)據(jù)直接填充到一樣的長(zhǎng)度再輸入)。
剛開(kāi)始是想偷懶,沒(méi)有去認(rèn)真了解輸入的機(jī)制,結(jié)果一直報(bào)錯(cuò)…還是要認(rèn)真學(xué)習(xí)呀!
加載數(shù)據(jù)
pytorch中加載數(shù)據(jù)的順序是:
①創(chuàng)建一個(gè)dataset對(duì)象
②創(chuàng)建一個(gè)dataloader對(duì)象
③循環(huán)dataloader對(duì)象,將data,label拿到模型中去訓(xùn)練
dataset
你需要自己定義一個(gè)class,里面至少包含3個(gè)函數(shù):
①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)
②__len__:返回這個(gè)數(shù)據(jù)集一共有多少個(gè)item
③__getitem__:返回一條訓(xùn)練數(shù)據(jù),并將其轉(zhuǎn)換成tensor
import torch from torch.utils.data import Dataset class Mydata(Dataset): def __init__(self): a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True) b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True) d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True) c = np.load("D:/Python/nlp/NRE/c.npy") self.x = list(zip(a,b,d,c)) def __getitem__(self, idx): assert idx < len(self.x) return self.x[idx] def __len__(self): return len(self.x)
dataloader
參數(shù):
dataset:傳入的數(shù)據(jù)
shuffle = True:是否打亂數(shù)據(jù)
collate_fn:使用這個(gè)參數(shù)可以自己操作每個(gè)batch的數(shù)據(jù)
dataset = Mydata() dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是將每個(gè)batch的數(shù)據(jù)填充到該batch的最大長(zhǎng)度
def mycollate(data): a = [] b = [] c = [] d = [] max_len = len(data[0][0]) for i in data: if len(i[0])>max_len: max_len = len(i[0]) if len(i[1])>max_len: max_len = len(i[1]) if len(i[2])>max_len: max_len = len(i[2]) print(max_len) # 填充 for i in data: if len(i[0])<max_len: i[0].extend([27] * (max_len-len(i[0]))) if len(i[1])<max_len: i[1].extend([27] * (max_len-len(i[1]))) if len(i[2])<max_len: i[2].extend([27] * (max_len-len(i[2]))) a.append(i[0]) b.append(i[1]) d.append(i[2]) c.extend(i[3]) # 這里要自己轉(zhuǎn)成tensor a = torch.Tensor(a) b = torch.Tensor(b) c = torch.Tensor(c) d = torch.Tensor(d) data1 = [a,b,d,c] print("data1",data1) return data1
結(jié)果:
最后循環(huán)該dataloader ,拿到數(shù)據(jù)放入模型進(jìn)行訓(xùn)練:
for ii, data in enumerate(test_data_loader): if opt.use_gpu: data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data)) else: data = list(map(lambda x: torch.LongTensor(x.long()), data)) out = model(data[:-1]) #數(shù)據(jù)data[:-1] loss = F.cross_entropy(out, data[-1])# 最后一列是標(biāo)簽
寫在最后:建議像我一樣剛開(kāi)始不太熟練的小伙伴,在處理數(shù)據(jù)輸入的時(shí)候可以打印出來(lái)仔細(xì)查看。
到此這篇關(guān)于Pytorch中DataLoader的使用方法的文章就介紹到這了,更多相關(guān)Pytorch DataLoader內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
通過(guò)VS下載的NuGet包修改其下載存放路徑的操作方法
這篇文章主要介紹了通過(guò)VS下載的NuGet包如何修改其下載存放路徑,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-09-09Python實(shí)現(xiàn)簡(jiǎn)易端口掃描器代碼實(shí)例
本篇文章主要介紹了Python實(shí)現(xiàn)簡(jiǎn)易端口掃描器的相關(guān)代碼,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下。2017-03-03K-近鄰算法的python實(shí)現(xiàn)代碼分享
這篇文章主要介紹了K-近鄰算法的python實(shí)現(xiàn)代碼分享,具有一定借鑒價(jià)值,需要的朋友可以參考下。2017-12-12python使用PyCharm進(jìn)行遠(yuǎn)程開(kāi)發(fā)和調(diào)試
這篇文章主要介紹了python使用PyCharm進(jìn)行遠(yuǎn)程開(kāi)發(fā)和調(diào)試,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-11-11python3?cookbook解壓可迭代對(duì)象賦值給多個(gè)變量的問(wèn)題及解決方案
這篇文章主要介紹了python3?cookbook-解壓可迭代對(duì)象賦值給多個(gè)變量,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2024-01-01