Pytorch中TensorDataset與DataLoader的使用方式
TensorDataset與DataLoader的使用
TensorDataset
TensorDataset本質(zhì)上與python zip方法類似,對數(shù)據(jù)進行打包整合。
官方文檔說明:
**Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.*
Parameters:
tensors (Tensor) – tensors that have the same size of the first dimension.
該類通過每一個 tensor 的第一個維度進行索引。
因此,該類中的 tensor 第一維度必須相等。
import torch from torch.utils.data import TensorDataset # a的形狀為(4*3) a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) # b的第一維與a相同 b = torch.tensor([1,2,3,4]) train_data = TensorDataset(a,b) print(train_data[0:4])
輸出結(jié)果如下:
(tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]]), tensor([1, 2, 3, 4]))
DataLoader
DataLoader本質(zhì)上就是一個iterable(跟python的內(nèi)置類型list等一樣),并利用多進程來加速batch data的處理,使用yield來使用有限的內(nèi)存。
import torch from torch.utils.data import TensorDataset from torch.utils.data import DataLoader a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) b = torch.tensor([1,2,3,4]) train_data = TensorDataset(a,b) data = DataLoader(train_data, batch_size=2, shuffle=True) for i, j in enumerate(data): ? ? x, y = j ? ? print(' batch:{0} x:{1} ?y: {2}'.format(i, x, y))
輸出:
batch:0 x:tensor([[1, 1, 1],
[2, 2, 2]]) y: tensor([1, 2])
batch:1 x:tensor([[4, 4, 4],
[3, 3, 3]]) y: tensor([4, 3])
Pytorch Dataset,TensorDataset,Dataloader,Sampler關(guān)系
Dataloader
Dataloader是數(shù)據(jù)加載器,組合數(shù)據(jù)集和采樣器,并在數(shù)據(jù)集上提供單線程或多線程的迭代器。
所以Dataloader的參數(shù)必然需要指定數(shù)據(jù)集Dataset和采樣器Sampler。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
- dataset (Dataset) – 數(shù)據(jù)集。
- batch_size (int, optional) – 每個batch加載樣本數(shù)。
- shuffle (bool, optional) – True則打亂數(shù)據(jù).
- sampler (Sampler, optional) – 采樣器,如指定則忽略shuffle參數(shù)。
- num_workers (int, optional) – 用多少個子進程加載數(shù)據(jù)。0表示數(shù)據(jù)將在主進程中加載
- collate_fn (callable, optional) – 獲取batch數(shù)據(jù)的回調(diào)函數(shù),也就是說可以在這個函數(shù)中修改batch的形式
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果數(shù)據(jù)集大小不能被batch size整除,則設(shè)置為True后可刪除最后一個不完整的batch。如果設(shè)為False并且數(shù)據(jù)集的大小不能被batch size整除,則最后一個batch將更小。
Dataset和TensorDataset
所有其他數(shù)據(jù)集都應(yīng)該進行子類化。所有子類應(yīng)該override __len__
和 __getitem__
,前者提供了數(shù)據(jù)集的大小,后者支持整數(shù)索引,范圍從0到len(self)。
TensorDataset是Dataset的子類,已經(jīng)復(fù)寫了 __len__
和 __getitem__
方法,只要傳入張量即可,它通過第一個維度進行索引。
所以TensorDataset說白了就是將輸入的tensors捆綁在一起,然后 __len__
是任何一個tensor的維度, __getitem__
表示每個tensor取相同的索引,然后將這個結(jié)果組成一個元組,源碼如下,要好好理解它通過第一個維度進行索引的意思(針對tensors里面的每一個tensor而言)。
class TensorDataset(Dataset): def __init__(self,*tensors): assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors) self.tensors = tensors def __getitem__(self,index): return tuple(tensor[index] for tensor in self.tensors) def __len__(self): return self.tensors[0].size(0)
Sampler和RandomSampler
Sampler與Dataset類似,是采樣器的基礎(chǔ)類。
每個采樣器子類必須提供一個 __iter__
方法,提供一種迭代數(shù)據(jù)集元素的索引的方法,以及返回迭代器長度的 __len__
方法。
所以Sampler必然是關(guān)于索引的迭代器,也就是它的輸出是索引。
而RandomSampler與TensorDataset類似,RandomSamper已經(jīng)實現(xiàn)了 __iter__
和 __len__
方法,只需要傳入數(shù)據(jù)集即可。
猜想理解RandomSampler的實現(xiàn)方式,考慮到這個類實現(xiàn)需要傳入Dataset,所以 __len__
就是Dataset的 __len__
,然后 __iter__
就可以隨便搞一個隨機函數(shù)對range(length)隨機即可。
綜合示例
結(jié)合TensorDataset和RandomSampler使用Dataloader
這里即可理解Dataloader這個數(shù)據(jù)加載器其實就是組合數(shù)據(jù)集和采樣器的組合。所以那就是先根據(jù)Sampler隨機拿到一個索引,再用這個索引到Dataset中取tensors里每個tensor對應(yīng)索引的數(shù)據(jù)來組成一個元組。
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
pandas中的DataFrame數(shù)據(jù)遍歷解讀
這篇文章主要介紹了pandas中的DataFrame數(shù)據(jù)遍歷解讀,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-12-12Python數(shù)據(jù)分析之堆疊數(shù)組函數(shù)示例總結(jié)
這篇文章主要為大家介紹了Python數(shù)據(jù)分析之堆疊數(shù)組函數(shù)示例總結(jié),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-02-02Python基礎(chǔ)教程之if判斷,while循環(huán),循環(huán)嵌套
這篇文章主要介紹了Python基礎(chǔ)教程之if判斷,while循環(huán),循環(huán)嵌套 的相關(guān)知識,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下2019-04-04Python實現(xiàn)RabbitMQ6種消息模型的示例代碼
這篇文章主要介紹了Python實現(xiàn)RabbitMQ6種消息模型的示例代碼,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03python自動定時任務(wù)schedule庫的使用方法
當(dāng)你需要在 Python 中定期執(zhí)行任務(wù)時,schedule 庫是一個非常實用的工具,它可以幫助你自動化定時任務(wù),本文給大家介紹了python自動定時任務(wù)schedule庫的使用方法,需要的朋友可以參考下2024-02-02利用python微信庫itchat實現(xiàn)微信自動回復(fù)功能
最近發(fā)現(xiàn)了一個特別好玩的Python 微信庫itchat,可以實現(xiàn)自動回復(fù)等多種功能,下面這篇文章主要給大家介紹了利用python微信庫itchat實現(xiàn)微信自動回復(fù)功能的相關(guān)資料,需要的朋友可以參考學(xué)習(xí),下面來一起看看吧。2017-05-05Python 中數(shù)組和數(shù)字相乘時的注意事項說明
這篇文章主要介紹了Python 中數(shù)組和數(shù)字相乘時的注意事項說明,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-05-05詳解pandas數(shù)據(jù)合并與重塑(pd.concat篇)
這篇文章主要介紹了詳解pandas數(shù)據(jù)合并與重塑(pd.concat篇),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07