Pytorch數(shù)據(jù)讀取與預(yù)處理該如何實(shí)現(xiàn)
在煉丹時,數(shù)據(jù)的讀取與預(yù)處理是關(guān)鍵一步。不同的模型所需要的數(shù)據(jù)以及預(yù)處理方式各不相同,如果每個輪子都我們自己寫的話,是很浪費(fèi)時間和精力的。Pytorch幫我們實(shí)現(xiàn)了方便的數(shù)據(jù)讀取與預(yù)處理方法,下面記錄兩個DEMO,便于加快以后的代碼效率。
根據(jù)數(shù)據(jù)是否一次性讀取完,將DEMO分為:
1、串行式讀取。也就是一次性讀取完所有需要的數(shù)據(jù)到內(nèi)存,模型訓(xùn)練時不會再訪問外存。通常用在內(nèi)存足夠的情況下使用,速度更快。
2、并行式讀取。也就是邊訓(xùn)練邊讀取數(shù)據(jù)。通常用在內(nèi)存不夠的情況下使用,會占用計算資源,如果分配的好的話,幾乎不損失速度。
Pytorch官方的數(shù)據(jù)提取方式盡管方便編碼,但由于它提取數(shù)據(jù)方式比較死板,會浪費(fèi)資源,下面對其進(jìn)行分析。
1 串行式讀取
1.1 DEMO代碼
import torch from torch.utils.data import Dataset,DataLoader class MyDataSet(Dataset):# ————1———— def __init__(self): self.data = torch.tensor(range(10)).reshape([5,2]) self.label = torch.tensor(range(5)) def __getitem__(self, index): return self.data[index], self.label[index] def __len__(self): return len(self.data) my_data_set = MyDataSet()# ————2———— my_data_loader = DataLoader( dataset=my_data_set, # ————3———— batch_size=2, # ————4———— shuffle=True, # ————5———— sampler=None, # ————6———— batch_sampler=None, # ————7———— num_workers=0 , # ————8———— collate_fn=None, # ————9———— pin_memory=True, # ————10———— drop_last=True # ————11———— ) for i in my_data_loader: # ————12———— print(i)
注釋處解釋如下:
1、重寫數(shù)據(jù)集類,用于保存數(shù)據(jù)。除了 __init__() 外,必須實(shí)現(xiàn) __getitem__() 和 __len__() 兩個方法。前一個方法用于輸出索引對應(yīng)的數(shù)據(jù)。后一個方法用于獲取數(shù)據(jù)集的長度。
2~5、 2準(zhǔn)備好數(shù)據(jù)集后,傳入DataLoader來迭代生成數(shù)據(jù)。前三個參數(shù)分別是傳入的數(shù)據(jù)集對象、每次獲取的批量大小、是否打亂數(shù)據(jù)集輸出。
6、采樣器,如果定義這個,shuffle只能設(shè)置為False。所謂采樣器就是用于生成數(shù)據(jù)索引的可迭代對象,比如列表。因此,定義了采樣器,采樣都按它來,shuffle再打亂就沒意義了。
7、批量采樣器,如果定義這個,batch_size、shuffle、sampler、drop_last都不能定義。實(shí)際上,如果沒有特殊的數(shù)據(jù)生成順序的要求,采樣器并沒有必要定義。torch.utils.data 中的各種 Sampler 就是采樣器類,如果需要,可以使用它們來定義。
8、用于生成數(shù)據(jù)的子進(jìn)程數(shù)。默認(rèn)為0,不并行。
9、拼接多個樣本的方法,默認(rèn)是將每個batch的數(shù)據(jù)在第一維上進(jìn)行拼接。這樣可能說不清楚,并且由于這里可以探究一下獲取數(shù)據(jù)的速度,后面再詳細(xì)說明。
10、是否使用鎖頁內(nèi)存。用的話會更快,內(nèi)存不充足最好別用。
11、是否把最后小于batch的數(shù)據(jù)丟掉。
12、迭代獲取數(shù)據(jù)并輸出。
1.2 速度探索
首先看一下DEMO的輸出:
輸出了兩個batch的數(shù)據(jù),每組數(shù)據(jù)中data和label都正確排列,符合我們的預(yù)期。那么DataLoader是怎么把數(shù)據(jù)整合起來的呢?首先,我們把collate_fn定義為直接映射(不用它默認(rèn)的方法),來查看看每次DataLoader從MyDataSet中讀取了什么,將上面部分代碼修改如下:
my_data_loader = DataLoader( dataset=my_data_set, batch_size=2, shuffle=True, sampler=None, batch_sampler=None, num_workers=0 , collate_fn=lambda x:x, #修改處 pin_memory=True, drop_last=True )
結(jié)果如下:
輸出還是兩個batch,然而每個batch中,單個的data和label是在一個list中的。似乎可以看出,DataLoader是一個一個讀取MyDataSet中的數(shù)據(jù)的,然后再進(jìn)行相應(yīng)數(shù)據(jù)的拼接。為了驗(yàn)證這點(diǎn),代碼修改如下:
import torch from torch.utils.data import Dataset,DataLoader class MyDataSet(Dataset): def __init__(self): self.data = torch.tensor(range(10)).reshape([5,2]) self.label = torch.tensor(range(5)) def __getitem__(self, index): print(index) #修改處2 return self.data[index], self.label[index] def __len__(self): return len(self.data) my_data_set = MyDataSet() my_data_loader = DataLoader( dataset=my_data_set, batch_size=2, shuffle=True, sampler=None, batch_sampler=None, num_workers=0 , collate_fn=lambda x:x, #修改處1 pin_memory=True, drop_last=True ) for i in my_data_loader: print(i)
輸出如下:
驗(yàn)證了前面的猜想,的確是一個一個讀取的。如果數(shù)據(jù)集定義的不是格式化的數(shù)據(jù),那還好,但是我這里定義的是tensor,是可以直接通過列表來索引對應(yīng)的tensor的。因此,DataLoader的操作比直接索引多了拼接這一步,肯定是會慢很多的。一兩次的讀取還好,但在訓(xùn)練中,大量的讀取累加起來,就會浪費(fèi)很多時間了。
自定義一個DataLoader可以證明這一點(diǎn),代碼如下:
import torch from torch.utils.data import Dataset,DataLoader from time import time class MyDataSet(Dataset): def __init__(self): self.data = torch.tensor(range(100000)).reshape([50000,2]) self.label = torch.tensor(range(50000)) def __getitem__(self, index): return self.data[index], self.label[index] def __len__(self): return len(self.data) # 自定義DataLoader class MyDataLoader(): def __init__(self, dataset,batch_size): self.dataset = dataset self.batch_size = batch_size def __iter__(self): self.now = 0 self.shuffle_i = np.array(range(self.dataset.__len__())) np.random.shuffle(self.shuffle_i) return self def __next__(self): self.now += self.batch_size if self.now <= len(self.shuffle_i): indexes = self.shuffle_i[self.now-self.batch_size:self.now] return self.dataset.__getitem__(indexes) else: raise StopIteration # 使用官方DataLoader my_data_set = MyDataSet() my_data_loader = DataLoader( dataset=my_data_set, batch_size=256, shuffle=True, sampler=None, batch_sampler=None, num_workers=0 , collate_fn=None, pin_memory=True, drop_last=True ) start_t = time() for t in range(10): for i in my_data_loader: pass print("官方:", time() - start_t) #自定義DataLoader my_data_set = MyDataSet() my_data_loader = MyDataLoader(my_data_set,256) start_t = time() for t in range(10): for i in my_data_loader: pass print("自定義:", time() - start_t)
運(yùn)行結(jié)果如下:
以上使用batch大小為256,僅各讀取10 epoch的數(shù)據(jù),都有30多倍的時間上的差距,更大的batch差距會更明顯。另外,這里用于測試的每個數(shù)據(jù)只有兩個浮點(diǎn)數(shù),如果是圖像,所需的時間可能會增加幾百倍。因此,如果數(shù)據(jù)量和batch都比較大,并且數(shù)據(jù)是格式化的,最好自己寫數(shù)據(jù)生成器。
2 并行式讀取
2.1 DEMO代碼
import matplotlib.pyplot as plt from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import ImageFolder path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128' my_data_set = ImageFolder( #————1———— root = path, #————2———— transform = transforms.Compose([ #————3———— transforms.ToTensor(), transforms.CenterCrop(64) ]), loader = plt.imread #————4———— ) my_data_loader = DataLoader( dataset=my_data_set, batch_size=128, shuffle=True, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=True, drop_last=True ) for i in my_data_loader: print(i)
注釋處解釋如下:
1/2、ImageFolder類繼承自DataSet類,因此可以按索引讀取圖像。路徑必須包含文件夾,ImageFolder會給每個文件夾中的圖像添加索引,并且每張圖像會給予其所在文件夾的標(biāo)簽。舉個例子,代碼中my_data_set[0] 輸出的是圖像對象和它對應(yīng)的標(biāo)簽組成的列表。
3、圖像到格式化數(shù)據(jù)的轉(zhuǎn)換組合。更多的轉(zhuǎn)換方法可以看 transform 模塊。
4、圖像法的讀取方式,默認(rèn)是PIL.Image.open(),但我發(fā)現(xiàn)plt.imread()更快一些。
由于是邊訓(xùn)練邊讀取,transform會占用很多時間,因此可以先將圖像轉(zhuǎn)換為需要的形式存入外存再讀取,從而避免重復(fù)操作。
其中transform.ToTensor()會把正常讀取的圖像轉(zhuǎn)換為torch.tensor,并且像素值會映射至[0,1][0,1]。由于plt.imread()讀取png圖像時,像素值在[0,1][0,1],而讀取jpg圖像時,像素值卻在[0,255][0,255],因此使用transform.ToTensor()能將圖像像素區(qū)間統(tǒng)一化。
以上就是Pytorch數(shù)據(jù)讀取與預(yù)處理該如何實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于Pytorch數(shù)據(jù)讀取與預(yù)處理的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
使用Python簡單的實(shí)現(xiàn)樹莓派的WEB控制
這篇文章主要介紹了使用Python簡單的實(shí)現(xiàn)樹莓派的WEB控制的相關(guān)資料,需要的朋友可以參考下2016-02-02PyTorch一小時掌握之a(chǎn)utograd機(jī)制篇
這篇文章主要介紹了PyTorch一小時掌握之a(chǎn)utograd機(jī)制篇,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-09-09Python操作SQLite/MySQL/LMDB數(shù)據(jù)庫的方法
這篇文章主要介紹了Python操作SQLite/MySQL/LMDB數(shù)據(jù)庫的方法,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下2019-11-11Python計算一個給定時間點(diǎn)前一個月和后一個月第一天的方法
這篇文章主要介紹了Python計算一個給定時間點(diǎn)前一個月和后一個月第一天的方法,涉及Python使用datetime模塊計算日期時間的相關(guān)操作技巧,需要的朋友可以參考下2018-05-05python?selenium實(shí)現(xiàn)登錄豆瓣示例詳解
大家好,本篇文章主要講的是python?selenium登錄豆瓣示例詳解,感興趣的同學(xué)趕快來看一看吧,對你有幫助的話記得收藏一下2022-01-01