欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch?collate_fn的基礎(chǔ)與應(yīng)用教程

 更新時(shí)間:2022年02月10日 11:07:32   作者:音程  
這篇文章主要給大家介紹了關(guān)于pytorch?collate_fn基礎(chǔ)與應(yīng)用的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下

作用

collate_fn:即用于collate的function,用于整理數(shù)據(jù)的函數(shù)。

說(shuō)到整理數(shù)據(jù),你當(dāng)然要會(huì)用數(shù)據(jù),即會(huì)用數(shù)據(jù)制作工具torch.utils.data.Dataset,雖然我們今天談的是torch.utils.data.DataLoader,但是,其實(shí):

  1. 這兩個(gè)你如何定義;
  2. 從裝載器dataloader中取數(shù)據(jù)后做什么處理;
  3. 模型的forward()中如何處理。

這三部分都是有機(jī)統(tǒng)一而不可分割的,一個(gè)地方改了,其他地方就要改。

emmm…,小小總結(jié),collate_fn籠統(tǒng)的說(shuō)就是用于整理數(shù)據(jù),通常我們不需要使用,其應(yīng)用的情形是:各個(gè)數(shù)據(jù)長(zhǎng)度不一樣的情況,比如第一張圖片大小是28*28,第二張是50*50,這樣的話就如果不自己寫(xiě)collate_fn,而使用默認(rèn)的,就會(huì)報(bào)錯(cuò)。

原則

其實(shí)說(shuō)起來(lái),我們也沒(méi)有什么原則,但是如今大多數(shù)做深度學(xué)習(xí)都是使用GPU,所以這個(gè)時(shí)候我們需要記住一個(gè)總則:只有tensor數(shù)據(jù)類型才能運(yùn)行在GPU上,list和numpy都不可以。

從而,我們什么時(shí)候?qū)⑽覀兊臄?shù)據(jù)轉(zhuǎn)化為tensor是一個(gè)問(wèn)題,我的答案是前一節(jié)中的三個(gè)部分都可以來(lái)轉(zhuǎn)化,只是我們大多數(shù)的人都習(xí)慣在部分一轉(zhuǎn)化。

基礎(chǔ)

dataset

我們必須先看看torch.utils.data.Dataset如何使用,以一個(gè)例子為例:

import torch.utils.data as Data
class mydataset(Data.Dataset):
    def __init__(self,train_inputs,train_targets):#必須有
        super(mydataset,self).__init__()
        self.inputs=train_inputs
        self.targets=train_targets
        
    def __getitem__(self, index):#必須重寫(xiě)
        return self.inputs[index],self.targets[index]
        
    def __len__(self):#必須重寫(xiě)
        return len(self.targets)
#構(gòu)造訓(xùn)練數(shù)據(jù)
datax=torch.randn(4,3)#構(gòu)造4個(gè)輸入
datay=torch.empty(4).random_(2)#構(gòu)造4個(gè)標(biāo)簽
#制作dataset
dataset=mydataset(datax,datay)

下面,可以對(duì)dataset進(jìn)行一系列操作,這些操作返回的結(jié)果和你之前那個(gè)class的三個(gè)函數(shù)定義都息息相關(guān)。我想說(shuō),那三個(gè)函數(shù)非常自由,你想怎么定義就怎么定義,上述只是一種常見(jiàn)的而已,你可以定制一個(gè)特色的。

len(dataset)#調(diào)用了你上面定義的def __len__()那個(gè)函數(shù)
#4
dataset[0]#調(diào)用了你上面定義的def __getitem__()那個(gè)函數(shù)
#(tensor([-1.1426, -1.3239,  1.8372]), tensor(0.))

所以我再三強(qiáng)調(diào)的是上面的輸出結(jié)果和你的定義有關(guān),比如你完全可以把def __getitem__()改成:

    def __getitem__(self, index):
        return self.inputs[index]#不輸出標(biāo)簽

那么,

dataset[0]#此時(shí)當(dāng)然變化。
#tensor([-1.1426, -1.3239,  1.8372])

可以看到,是非常隨便的,你隨便定制就好。

dataloader

torch.utils.data.DataLoader

dataloader=Data.DataLoader(dataset,batch_size=2)

4個(gè)數(shù)據(jù),batch_size=2,所以一共有2個(gè)batch。

collate_fn如果你不指定,會(huì)調(diào)用pytorch內(nèi)部的,也就是說(shuō)這個(gè)函數(shù)是一定會(huì)調(diào)用的,而且調(diào)用這個(gè)函數(shù)時(shí)pytorch會(huì)往這個(gè)函數(shù)里面?zhèn)魅胍粋€(gè)參數(shù)batch。

def my_collate(batch):
	return xxx

這個(gè)batch是什么?這個(gè)東西和你定義的dataset, batch_size息息相關(guān)。batch是一個(gè)列表[x,...,xx],長(zhǎng)度就是batch_size,里面每一個(gè)元素是dataset的某一個(gè)元素,即dataset[i](我在上一節(jié)展示過(guò)dataset[0])。

在我們的例子中,由于我們沒(méi)有對(duì)dataloader設(shè)置需要打亂數(shù)據(jù),即shuffle=True,那么第1個(gè)batch就是前兩個(gè)數(shù)據(jù),如下:

print(datax)
print(datay)
batch=[dataset[0],dataset[1]]#所以才說(shuō)和你dataset中g(shù)et_item的定義有關(guān)。
print(batch)

對(duì),你沒(méi)有看錯(cuò),上述代碼展示的batch就會(huì)傳入到pytorch默認(rèn)的collate_fn中,然后經(jīng)過(guò)默認(rèn)的處理,輸出如下:

it=iter(dataloader)
nex=next(it)#我們展示第一個(gè)batch經(jīng)過(guò)collate_fn之后的輸出結(jié)果
print(nex)

其實(shí),上面就是我們常用的,經(jīng)典的輸出結(jié)果,即輸入和標(biāo)簽是分開(kāi)的,第一項(xiàng)是輸入tensor,第二項(xiàng)是標(biāo)簽tensor,輸入的維度變成了(batch_size,input_size)。

但是我們乍一看,將第一個(gè)batch變成上述輸出結(jié)果很容易呀,我們也會(huì)!我們下面就來(lái)自己寫(xiě)一個(gè)collate_fn實(shí)現(xiàn)這個(gè)功能。

# a simple custom collate function, just to show the idea
# `batch` is a list of tuple where first element is input tensor and the second element is corresponding label
def my_collate(batch):
    inputs=[data[0].tolist() for data in batch]
    target = torch.tensor([data[1] for data in batch])
    return [data, target]

 

dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
print(datax)
print(datay)

it=iter(dataloader)
nex=next(it)
print(nex)

這不就和默認(rèn)的collate_fn的輸出結(jié)果一樣了嘛!無(wú)非就是默認(rèn)的還把輸入變成了tensor,標(biāo)簽變成了tensor,我上面是列表,我改就是了嘛!如下:

def my_collate(batch):
    inputs=[data[0].tolist() for data in batch]
    inputs=torch.tensor(inputs)
    target =[data[1].tolist() for data in batch]
    target=torch.tensor(target)
    return [inputs, target]
    
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader)
nex=next(it)
print(nex)

這下好了吧!

對(duì)了,作為彩蛋,告訴大家一個(gè)秘密:默認(rèn)的collate_fn函數(shù)中有一些語(yǔ)句是轉(zhuǎn)tensor以及tensor合并的操作,所以你的dataset如果沒(méi)有設(shè)計(jì)成經(jīng)典模式的話,使用默認(rèn)的就容易報(bào)錯(cuò),而我們自己會(huì)寫(xiě)collate_fn,當(dāng)然就不存在這個(gè)問(wèn)題啦。同時(shí),給大家的一個(gè)經(jīng)驗(yàn)就是,一般dataset是不會(huì)報(bào)錯(cuò)的,而是根據(jù)dataset制作dataloader的時(shí)候容易報(bào)錯(cuò),因?yàn)槟J(rèn)collate_fn把dataset的類型限制得比較死。

應(yīng)用情形

假設(shè)我們還是4個(gè)輸入,但是維度不固定的。

a=[[1,2],[3,4,5],[1],[3,4,9]]
b=[1,0,0,1]
dataset=mydataset(a,b)
dataloader=Data.DataLoader(dataset,batch_size=2)
it=iter(dataloader)
nex=next(it)
nex

使用默認(rèn)的collate_fn,直接報(bào)錯(cuò),要求相同維度。

這個(gè)時(shí)候,我們可以使用自己的collate_fn,避免報(bào)錯(cuò)。

不過(guò)話說(shuō)回來(lái),我個(gè)人感受是:

在這里避免報(bào)錯(cuò)好像也沒(méi)有什么用,因?yàn)榇蠖鄶?shù)的神經(jīng)網(wǎng)絡(luò)都是定長(zhǎng)輸入的,而且很多的操作也要求相同維度才能相加或相乘,所以:這里不報(bào)錯(cuò),后面還是報(bào)錯(cuò)。如果后面解決這個(gè)問(wèn)題的方法是:在不足維度上進(jìn)行補(bǔ)0操作,那么我們?yōu)槭裁床辉诮ataset之前先補(bǔ)好呢?所以,collate_fn這個(gè)東西的應(yīng)用場(chǎng)景還是有限的。

總結(jié)

到此這篇關(guān)于pytorch collate_fn的基礎(chǔ)與應(yīng)用的文章就介紹到這了,更多相關(guān)pytorch collate_fn應(yīng)用內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • 詳解如何用Python模擬登錄淘寶

    詳解如何用Python模擬登錄淘寶

    最近想爬取淘寶的一些商品,但是發(fā)現(xiàn)如果要使用搜索等一些功能時(shí)基本都需要登錄,所以就想出一篇模擬登錄淘寶的文章!本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-08-08
  • python3實(shí)現(xiàn)點(diǎn)餐系統(tǒng)

    python3實(shí)現(xiàn)點(diǎn)餐系統(tǒng)

    這篇文章主要為大家詳細(xì)介紹了python3實(shí)現(xiàn)點(diǎn)餐系統(tǒng),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-01-01
  • Numpy之reshape()使用詳解

    Numpy之reshape()使用詳解

    今天小編就為大家分享一篇Numpy之reshape()使用詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-12-12
  • PyTorch使用GPU訓(xùn)練的兩種方法實(shí)例

    PyTorch使用GPU訓(xùn)練的兩種方法實(shí)例

    pytorch是一個(gè)非常優(yōu)秀的深度學(xué)習(xí)的框架,具有速度快,代碼簡(jiǎn)潔,可讀性強(qiáng)的優(yōu)點(diǎn),下面這篇文章主要給大家介紹了關(guān)于PyTorch使用GPU訓(xùn)練的兩種方法,需要的朋友可以參考下
    2022-05-05
  • 表格梳理解析python內(nèi)置時(shí)間模塊看完就懂

    表格梳理解析python內(nèi)置時(shí)間模塊看完就懂

    這篇文章主要介紹了python內(nèi)置的時(shí)間模塊,本文用表格方式清晰的對(duì)Python內(nèi)置時(shí)間模塊進(jìn)行語(yǔ)法及用法的梳理解析,有需要的朋友建議收藏參考
    2021-10-10
  • 利用Python實(shí)現(xiàn)自定義連點(diǎn)器

    利用Python實(shí)現(xiàn)自定義連點(diǎn)器

    這篇文章主要介紹了如何利用Python實(shí)現(xiàn)自定義連點(diǎn)器,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-08-08
  • Python曲線擬合詳解

    Python曲線擬合詳解

    這篇文章主要介紹了關(guān)于python曲線擬合,scipy.optimize中,curve_fit函數(shù)可調(diào)用非線性最小二乘法進(jìn)行函數(shù)擬合,文中有詳細(xì)的代碼作為參考,需要的朋友可以閱讀參考
    2023-04-04
  • 基于tf.shape(tensor)和tensor.shape()的區(qū)別說(shuō)明

    基于tf.shape(tensor)和tensor.shape()的區(qū)別說(shuō)明

    這篇文章主要介紹了基于tf.shape(tensor)和tensor.shape()的區(qū)別說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2020-06-06
  • 對(duì)Keras中predict()方法和predict_classes()方法的區(qū)別說(shuō)明

    對(duì)Keras中predict()方法和predict_classes()方法的區(qū)別說(shuō)明

    這篇文章主要介紹了對(duì)Keras中predict()方法和predict_classes()方法的區(qū)別說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2020-06-06
  • 跟老齊學(xué)Python之大話題小函數(shù)(2)

    跟老齊學(xué)Python之大話題小函數(shù)(2)

    上篇文章我們講訴了map 和lambda函數(shù)的使用,本文我們繼續(xù)來(lái)看看reduce和filter函數(shù),有需要的朋友可以參考下
    2014-10-10

最新評(píng)論