欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch中DataLoader的使用方法詳解

 更新時(shí)間:2022年09月08日 10:34:01   作者:生信小兔  
在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個(gè)函數(shù),用來(lái)處理加載數(shù)據(jù)集,這篇文章主要介紹了Pytorch中DataLoader的使用方法,需要的朋友可以參考下

在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數(shù)據(jù)集的兩個(gè)函數(shù),用來(lái)處理加載數(shù)據(jù)集。通常情況下,使用的關(guān)鍵在于構(gòu)建dataset類。

一:dataset類構(gòu)建。

在構(gòu)建數(shù)據(jù)集類時(shí),除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個(gè)方法,這三個(gè)是必不可少的,至于其它用于數(shù)據(jù)處理的函數(shù),可以任意定義。

class dataset:
    def __init__(self,...):
        ...
    def __len__(self,...):
        return n
    def __getitem__(self,item):
        return data[item]

正常情況下,該數(shù)據(jù)集是要繼承Pytorch中Dataset類的,但實(shí)際操作中,即使不繼承,數(shù)據(jù)集類構(gòu)建后仍可以用Dataloader()加載的。

在dataset類中,__len__(self)返回?cái)?shù)據(jù)集中數(shù)據(jù)個(gè)數(shù),__getitem__(self,item)表示每次返回第item條數(shù)據(jù)。

二:DataLoader使用

在構(gòu)建dataset類后,即可使用DataLoader加載。DataLoader中常用參數(shù)如下:

1.dataset:需要載入的數(shù)據(jù)集,如前面構(gòu)造的dataset類。

2.batch_size:批大小,在神經(jīng)網(wǎng)絡(luò)訓(xùn)練時(shí)我們很少逐條數(shù)據(jù)訓(xùn)練,而是幾條數(shù)據(jù)作為一個(gè)batch進(jìn)行訓(xùn)練。

3.shuffle:是否在打亂數(shù)據(jù)集樣本順序。True為打亂,F(xiàn)alse反之。

4.drop_last:是否舍去最后一個(gè)batch的數(shù)據(jù)(很多情況下數(shù)據(jù)總數(shù)N與batch size不整除,導(dǎo)致最后一個(gè)batch不為batch size)。True為舍去,F(xiàn)alse反之。

三:舉例

兔兔以指標(biāo)為1,數(shù)據(jù)個(gè)數(shù)為100的數(shù)據(jù)為例。

import torch
from torch.utils.data import DataLoader
 
class dataset:
    def __init__(self):
        self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
        self.y=(torch.sin(self.x)+1)/2
    def __len__(self):
        return 100
    def __getitem__(self, item):
        return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
    print(batch)

當(dāng)然,利用這個(gè)數(shù)據(jù)集可以進(jìn)行簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)訓(xùn)練。

from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
                 nn.Sigmoid(),
                 nn.Linear(5,1),
                 nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
    print('the {} epoch'.format(epoch))
    for batch in data:
        yp=bp(batch[0])
        loss=Loss(yp,batch[1])
        optim.zero_grad()
        loss.backward()
        optim.step()

ps:下面再給大家補(bǔ)充介紹下Pytorch中DataLoader的使用。

前言

最近開(kāi)始接觸pytorch,從跑別人寫好的代碼開(kāi)始,今天需要把輸入數(shù)據(jù)根據(jù)每個(gè)batch的最長(zhǎng)輸入數(shù)據(jù),填充到一樣的長(zhǎng)度(之前是將所有的數(shù)據(jù)直接填充到一樣的長(zhǎng)度再輸入)。
剛開(kāi)始是想偷懶,沒(méi)有去認(rèn)真了解輸入的機(jī)制,結(jié)果一直報(bào)錯(cuò)…還是要認(rèn)真學(xué)習(xí)呀!

加載數(shù)據(jù)

pytorch中加載數(shù)據(jù)的順序是:
①創(chuàng)建一個(gè)dataset對(duì)象
②創(chuàng)建一個(gè)dataloader對(duì)象
③循環(huán)dataloader對(duì)象,將data,label拿到模型中去訓(xùn)練

dataset

你需要自己定義一個(gè)class,里面至少包含3個(gè)函數(shù):
①__init__:傳入數(shù)據(jù),或者像下面一樣直接在函數(shù)里加載數(shù)據(jù)
②__len__:返回這個(gè)數(shù)據(jù)集一共有多少個(gè)item
③__getitem__:返回一條訓(xùn)練數(shù)據(jù),并將其轉(zhuǎn)換成tensor

import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
    def __init__(self):
        a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
        b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
        d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
        c = np.load("D:/Python/nlp/NRE/c.npy")
        self.x = list(zip(a,b,d,c))
    def __getitem__(self, idx):
        
        assert idx < len(self.x)
        return self.x[idx]
    def __len__(self):
        
        return len(self.x)

dataloader

參數(shù):
dataset:傳入的數(shù)據(jù)
shuffle = True:是否打亂數(shù)據(jù)
collate_fn:使用這個(gè)參數(shù)可以自己操作每個(gè)batch的數(shù)據(jù)

dataset = Mydata()
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)

下面是將每個(gè)batch的數(shù)據(jù)填充到該batch的最大長(zhǎng)度

def mycollate(data):
        a = []
        b = []
        c = []
        d = []
        max_len = len(data[0][0])
        for i in data:
            if len(i[0])>max_len:
                max_len = len(i[0])
            if len(i[1])>max_len:
                max_len = len(i[1])
            if len(i[2])>max_len:
                max_len = len(i[2])
        print(max_len)
        # 填充
        for i in data:
            if len(i[0])<max_len:
                i[0].extend([27] * (max_len-len(i[0])))
            if len(i[1])<max_len:
                i[1].extend([27] * (max_len-len(i[1])))
            if len(i[2])<max_len:
                i[2].extend([27] * (max_len-len(i[2])))  
            a.append(i[0])
            b.append(i[1])
            d.append(i[2])
            c.extend(i[3])
        # 這里要自己轉(zhuǎn)成tensor
        a = torch.Tensor(a)
        b = torch.Tensor(b)
        c = torch.Tensor(c)
        d = torch.Tensor(d)
        data1 = [a,b,d,c]
        print("data1",data1)
        return data1

結(jié)果:

在這里插入圖片描述

最后循環(huán)該dataloader ,拿到數(shù)據(jù)放入模型進(jìn)行訓(xùn)練:

 for ii, data in enumerate(test_data_loader):

        if opt.use_gpu: 
            data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data)) 
        else: 
            data = list(map(lambda x: torch.LongTensor(x.long()), data))

        out = model(data[:-1]) #數(shù)據(jù)data[:-1]
        loss = F.cross_entropy(out, data[-1])# 最后一列是標(biāo)簽

寫在最后:建議像我一樣剛開(kāi)始不太熟練的小伙伴,在處理數(shù)據(jù)輸入的時(shí)候可以打印出來(lái)仔細(xì)查看。

到此這篇關(guān)于Pytorch中DataLoader的使用方法的文章就介紹到這了,更多相關(guān)Pytorch DataLoader內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • 在Python中操作列表之List.pop()方法的使用

    在Python中操作列表之List.pop()方法的使用

    這篇文章主要介紹了在Python中操作列表之List.pop()方法的使用,是Python入門中的基礎(chǔ)知識(shí),尤其該方法的返回值在Python編程中經(jīng)常被靈活運(yùn)用,需要的朋友可以參考下
    2015-05-05
  • python  logging日志打印過(guò)程解析

    python logging日志打印過(guò)程解析

    這篇文章主要介紹了python logging日志打印過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-10-10
  • 詳解基于python的圖像Gabor變換及特征提取

    詳解基于python的圖像Gabor變換及特征提取

    這篇文章主要介紹了基于python的圖像Gabor變換及特征提取,本文通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧
    2020-10-10
  • 通過(guò)VS下載的NuGet包修改其下載存放路徑的操作方法

    通過(guò)VS下載的NuGet包修改其下載存放路徑的操作方法

    這篇文章主要介紹了通過(guò)VS下載的NuGet包如何修改其下載存放路徑,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-09-09
  • Python實(shí)現(xiàn)簡(jiǎn)易端口掃描器代碼實(shí)例

    Python實(shí)現(xiàn)簡(jiǎn)易端口掃描器代碼實(shí)例

    本篇文章主要介紹了Python實(shí)現(xiàn)簡(jiǎn)易端口掃描器的相關(guān)代碼,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下。
    2017-03-03
  • Python中的eval()函數(shù)使用詳解

    Python中的eval()函數(shù)使用詳解

    這篇文章主要介紹了Python中的eval()函數(shù)使用詳解,eval()函數(shù)是用來(lái)執(zhí)行一個(gè)字符串表達(dá)式,并返回表達(dá)式的值,可以把字符串轉(zhuǎn)化為list,dict ,tuple,需要的朋友可以參考下
    2023-12-12
  • Python中的迭代器詳解

    Python中的迭代器詳解

    這篇文章主要介紹迭代器,看完文章你可以了解到什么是可迭代對(duì)象、啥是迭代器、如何自定義迭代器、使用迭代器的優(yōu)勢(shì),文中有詳細(xì)的代碼示例,需要的朋友可以參考下
    2023-08-08
  • K-近鄰算法的python實(shí)現(xiàn)代碼分享

    K-近鄰算法的python實(shí)現(xiàn)代碼分享

    這篇文章主要介紹了K-近鄰算法的python實(shí)現(xiàn)代碼分享,具有一定借鑒價(jià)值,需要的朋友可以參考下。
    2017-12-12
  • python使用PyCharm進(jìn)行遠(yuǎn)程開(kāi)發(fā)和調(diào)試

    python使用PyCharm進(jìn)行遠(yuǎn)程開(kāi)發(fā)和調(diào)試

    這篇文章主要介紹了python使用PyCharm進(jìn)行遠(yuǎn)程開(kāi)發(fā)和調(diào)試,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧
    2017-11-11
  • python3?cookbook解壓可迭代對(duì)象賦值給多個(gè)變量的問(wèn)題及解決方案

    python3?cookbook解壓可迭代對(duì)象賦值給多個(gè)變量的問(wèn)題及解決方案

    這篇文章主要介紹了python3?cookbook-解壓可迭代對(duì)象賦值給多個(gè)變量,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2024-01-01

最新評(píng)論