Pytorch中TensorDataset與DataLoader的使用方式
TensorDataset與DataLoader的使用
TensorDataset
TensorDataset本質(zhì)上與python zip方法類似,對(duì)數(shù)據(jù)進(jìn)行打包整合。
官方文檔說明:
**Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.*
Parameters:
tensors (Tensor) – tensors that have the same size of the first dimension.
該類通過每一個(gè) tensor 的第一個(gè)維度進(jìn)行索引。
因此,該類中的 tensor 第一維度必須相等。
import torch from torch.utils.data import TensorDataset # a的形狀為(4*3) a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) # b的第一維與a相同 b = torch.tensor([1,2,3,4]) train_data = TensorDataset(a,b) print(train_data[0:4])
輸出結(jié)果如下:
(tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]]), tensor([1, 2, 3, 4]))
DataLoader
DataLoader本質(zhì)上就是一個(gè)iterable(跟python的內(nèi)置類型list等一樣),并利用多進(jìn)程來加速batch data的處理,使用yield來使用有限的內(nèi)存。
import torch from torch.utils.data import TensorDataset from torch.utils.data import DataLoader a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]]) b = torch.tensor([1,2,3,4]) train_data = TensorDataset(a,b) data = DataLoader(train_data, batch_size=2, shuffle=True) for i, j in enumerate(data): ? ? x, y = j ? ? print(' batch:{0} x:{1} ?y: {2}'.format(i, x, y))
輸出:
batch:0 x:tensor([[1, 1, 1],
[2, 2, 2]]) y: tensor([1, 2])
batch:1 x:tensor([[4, 4, 4],
[3, 3, 3]]) y: tensor([4, 3])
Pytorch Dataset,TensorDataset,Dataloader,Sampler關(guān)系
Dataloader
Dataloader是數(shù)據(jù)加載器,組合數(shù)據(jù)集和采樣器,并在數(shù)據(jù)集上提供單線程或多線程的迭代器。
所以Dataloader的參數(shù)必然需要指定數(shù)據(jù)集Dataset和采樣器Sampler。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
- dataset (Dataset) – 數(shù)據(jù)集。
- batch_size (int, optional) – 每個(gè)batch加載樣本數(shù)。
- shuffle (bool, optional) – True則打亂數(shù)據(jù).
- sampler (Sampler, optional) – 采樣器,如指定則忽略shuffle參數(shù)。
- num_workers (int, optional) – 用多少個(gè)子進(jìn)程加載數(shù)據(jù)。0表示數(shù)據(jù)將在主進(jìn)程中加載
- collate_fn (callable, optional) – 獲取batch數(shù)據(jù)的回調(diào)函數(shù),也就是說可以在這個(gè)函數(shù)中修改batch的形式
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果數(shù)據(jù)集大小不能被batch size整除,則設(shè)置為True后可刪除最后一個(gè)不完整的batch。如果設(shè)為False并且數(shù)據(jù)集的大小不能被batch size整除,則最后一個(gè)batch將更小。
Dataset和TensorDataset
所有其他數(shù)據(jù)集都應(yīng)該進(jìn)行子類化。所有子類應(yīng)該override __len__
和 __getitem__
,前者提供了數(shù)據(jù)集的大小,后者支持整數(shù)索引,范圍從0到len(self)。
TensorDataset是Dataset的子類,已經(jīng)復(fù)寫了 __len__
和 __getitem__
方法,只要傳入張量即可,它通過第一個(gè)維度進(jìn)行索引。
所以TensorDataset說白了就是將輸入的tensors捆綁在一起,然后 __len__
是任何一個(gè)tensor的維度, __getitem__
表示每個(gè)tensor取相同的索引,然后將這個(gè)結(jié)果組成一個(gè)元組,源碼如下,要好好理解它通過第一個(gè)維度進(jìn)行索引的意思(針對(duì)tensors里面的每一個(gè)tensor而言)。
class TensorDataset(Dataset): def __init__(self,*tensors): assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors) self.tensors = tensors def __getitem__(self,index): return tuple(tensor[index] for tensor in self.tensors) def __len__(self): return self.tensors[0].size(0)
Sampler和RandomSampler
Sampler與Dataset類似,是采樣器的基礎(chǔ)類。
每個(gè)采樣器子類必須提供一個(gè) __iter__
方法,提供一種迭代數(shù)據(jù)集元素的索引的方法,以及返回迭代器長度的 __len__
方法。
所以Sampler必然是關(guān)于索引的迭代器,也就是它的輸出是索引。
而RandomSampler與TensorDataset類似,RandomSamper已經(jīng)實(shí)現(xiàn)了 __iter__
和 __len__
方法,只需要傳入數(shù)據(jù)集即可。
猜想理解RandomSampler的實(shí)現(xiàn)方式,考慮到這個(gè)類實(shí)現(xiàn)需要傳入Dataset,所以 __len__
就是Dataset的 __len__
,然后 __iter__
就可以隨便搞一個(gè)隨機(jī)函數(shù)對(duì)range(length)隨機(jī)即可。
綜合示例
結(jié)合TensorDataset和RandomSampler使用Dataloader
這里即可理解Dataloader這個(gè)數(shù)據(jù)加載器其實(shí)就是組合數(shù)據(jù)集和采樣器的組合。所以那就是先根據(jù)Sampler隨機(jī)拿到一個(gè)索引,再用這個(gè)索引到Dataset中取tensors里每個(gè)tensor對(duì)應(yīng)索引的數(shù)據(jù)來組成一個(gè)元組。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
pandas中的DataFrame數(shù)據(jù)遍歷解讀
這篇文章主要介紹了pandas中的DataFrame數(shù)據(jù)遍歷解讀,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-12-12Python數(shù)據(jù)分析之堆疊數(shù)組函數(shù)示例總結(jié)
這篇文章主要為大家介紹了Python數(shù)據(jù)分析之堆疊數(shù)組函數(shù)示例總結(jié),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-02-02Python基礎(chǔ)教程之if判斷,while循環(huán),循環(huán)嵌套
這篇文章主要介紹了Python基礎(chǔ)教程之if判斷,while循環(huán),循環(huán)嵌套 的相關(guān)知識(shí),非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-04-04Python實(shí)現(xiàn)RabbitMQ6種消息模型的示例代碼
這篇文章主要介紹了Python實(shí)現(xiàn)RabbitMQ6種消息模型的示例代碼,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03python自動(dòng)定時(shí)任務(wù)schedule庫的使用方法
當(dāng)你需要在 Python 中定期執(zhí)行任務(wù)時(shí),schedule 庫是一個(gè)非常實(shí)用的工具,它可以幫助你自動(dòng)化定時(shí)任務(wù),本文給大家介紹了python自動(dòng)定時(shí)任務(wù)schedule庫的使用方法,需要的朋友可以參考下2024-02-02利用python微信庫itchat實(shí)現(xiàn)微信自動(dòng)回復(fù)功能
最近發(fā)現(xiàn)了一個(gè)特別好玩的Python 微信庫itchat,可以實(shí)現(xiàn)自動(dòng)回復(fù)等多種功能,下面這篇文章主要給大家介紹了利用python微信庫itchat實(shí)現(xiàn)微信自動(dòng)回復(fù)功能的相關(guān)資料,需要的朋友可以參考學(xué)習(xí),下面來一起看看吧。2017-05-05Python?中?Kwargs?解析的最佳實(shí)踐教程
這篇文章主要介紹了Python中Kwargs解析的最佳實(shí)踐,使用?kwargs,我們可以編寫帶有任意數(shù)量關(guān)鍵字參數(shù)的函數(shù),當(dāng)我們想為函數(shù)提供靈活的接口時(shí),這會(huì)很有用,需要的朋友可以參考下2023-06-06Python 中數(shù)組和數(shù)字相乘時(shí)的注意事項(xiàng)說明
這篇文章主要介紹了Python 中數(shù)組和數(shù)字相乘時(shí)的注意事項(xiàng)說明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-05-05詳解pandas數(shù)據(jù)合并與重塑(pd.concat篇)
這篇文章主要介紹了詳解pandas數(shù)據(jù)合并與重塑(pd.concat篇),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07