Pytorch使用技巧之Dataloader中的collate_fn參數(shù)詳析
以MNIST為例
from torchvision import datasets mnist = datasets.MNIST(root='./data/', train=True, download=True) print(mnist[0])
結果
(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)
MINIST數(shù)據(jù)集的dataset是由一張圖片和一個label組成的元組
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x) for each in dataloader: print(each) break
結果
[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]
collate_fn為lamda x:x時表示對傳入進來的數(shù)據(jù)不做處理
下面自定義collate_fn看看什么效果
def collate(data): img = [] label = [] for each in data: img.append(each[0]) label.append(each[1]) return img,label dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x)) for each in dataloader: print(each) break
結果
([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])
說明:若不設置collate_fn參數(shù)則會使用默認處理函數(shù)
但必須保證傳進來的數(shù)據(jù)都是tensor格式否則會報錯
附:DataLoader完整的參數(shù)表如下:
class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
DataLoader在數(shù)據(jù)集上提供單進程或多進程的迭代器
幾個關鍵的參數(shù)意思:
- shuffle:設置為True的時候,每個世代都會打亂數(shù)據(jù)集
- collate_fn:如何取樣本的,我們可以定義自己的函數(shù)來準確地實現(xiàn)想要的功能
- drop_last:告訴如何處理數(shù)據(jù)集長度除于batch_size余下的數(shù)據(jù)。True就拋棄,否則保留
總結
到此這篇關于Pytorch使用技巧之Dataloader中的collate_fn參數(shù)的文章就介紹到這了,更多相關Dataloader中的collate_fn參數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
pytorch通過自己的數(shù)據(jù)集訓練Unet網(wǎng)絡架構
Unet是一個最近比較火的網(wǎng)絡結構。它的理論已經(jīng)有很多大佬在討論了。本文主要從實際操作的層面,講解如何使用pytorch實現(xiàn)unet圖像分割2022-12-12在Python中通過threshold創(chuàng)建mask方式
今天小編就為大家分享一篇在Python中通過threshold創(chuàng)建mask方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02pandas基礎?Series與Dataframe與numpy對二進制文件輸入輸出
這篇文章主要介紹了pandas基礎Series與Dataframe與numpy對二進制文件輸入輸出,series是一種一維的數(shù)組型對象,它包含了一個值序列和一個數(shù)據(jù)標簽2022-07-07