欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關(guān)系

 更新時間:2020年07月03日 09:43:43   作者:marsggbo  
這篇文章主要介紹了一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關(guān)系,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧

以下內(nèi)容都是針對Pytorch 1.0-1.1介紹。

很多文章都是從Dataset等對象自下往上進行介紹,但是對于初學(xué)者而言,其實這并不好理解,因為有的時候會不自覺地陷入到一些細枝末節(jié)中去,而不能把握重點,所以本文將會自上而下地對Pytorch數(shù)據(jù)讀取方法進行介紹。

自上而下理解三者關(guān)系

首先我們看一下DataLoader.next的源代碼長什么樣,為方便理解我只選取了num_works為0的情況(num_works簡單理解就是能夠并行化地讀取數(shù)據(jù))。

class DataLoader(object):
	...
	
 def __next__(self):
  if self.num_workers == 0: 
   indices = next(self.sample_iter) # Sampler
   batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
   if self.pin_memory:
    batch = _utils.pin_memory.pin_memory_batch(batch)
   return batch

在閱讀上面代碼前,我們可以假設(shè)我們的數(shù)據(jù)是一組圖像,每一張圖像對應(yīng)一個index,那么如果我們要讀取數(shù)據(jù)就只需要對應(yīng)的index即可,即上面代碼中的indices,而選取index的方式有多種,有按順序的,也有亂序的,所以這個工作需要Sampler完成,現(xiàn)在你不需要具體的細節(jié),后面會介紹,你只需要知道DataLoader和Sampler在這里產(chǎn)生關(guān)系。

那么Dataset和DataLoader在什么時候產(chǎn)生關(guān)系呢?沒錯就是下面一行。我們已經(jīng)拿到了indices,那么下一步我們只需要根據(jù)index對數(shù)據(jù)進行讀取即可了。

再下面的if語句的作用簡單理解就是,如果pin_memory=True,那么Pytorch會采取一系列操作把數(shù)據(jù)拷貝到GPU,總之就是為了加速。

綜上可以知道DataLoader,Sampler和Dataset三者關(guān)系如下:

在閱讀后文的過程中,你始終需要將上面的關(guān)系記在心里,這樣能幫助你更好地理解。

Sampler

參數(shù)傳遞

要更加細致地理解Sampler原理,我們需要先閱讀一下DataLoader 的源代碼,如下:

class DataLoader(object):
 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
     batch_sampler=None, num_workers=0, collate_fn=default_collate,
     pin_memory=False, drop_last=False, timeout=0,
     worker_init_fn=None)

可以看到初始化參數(shù)里有兩種sampler:samplerbatch_sampler,都默認為None。前者的作用是生成一系列的index,而batch_sampler則是將sampler生成的indices打包分組,得到一個又一個batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分組。

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

Pytorch中已經(jīng)實現(xiàn)的Sampler有如下幾種:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化參數(shù)之間存在互斥關(guān)系,這個你可以通過閱讀源碼更深地理解,這里只做總結(jié):

  • 如果你自定義了batch_sampler,那么這些參數(shù)都必須使用默認值:batch_size, shuffle,sampler,drop_last.
  • 如果你自定義了sampler,那么shuffle需要設(shè)置為False
  • 如果sampler和batch_sampler都為None,那么batch_sampler使用Pytorch已經(jīng)實現(xiàn)好的BatchSampler,而sampler分兩種情況:
    • 若shuffle=True,則sampler=RandomSampler(dataset)
    • 若shuffle=False,則sampler=SequentialSampler(dataset)

如何自定義Sampler和BatchSampler?

仔細查看源代碼其實可以發(fā)現(xiàn),所有采樣器其實都繼承自同一個父類,即Sampler,其代碼定義如下:

class Sampler(object):
 r"""Base class for all Samplers.
 Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
 way to iterate over indices of dataset elements, and a :meth:`__len__` method
 that returns the length of the returned iterators.
 .. note:: The :meth:`__len__` method isn't strictly required by
    :class:`~torch.utils.data.DataLoader`, but is expected in any
    calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
 """

 def __init__(self, data_source):
  pass

 def __iter__(self):
  raise NotImplementedError
		
 def __len__(self):
  return len(self.data_source)

所以你要做的就是定義好__iter__(self)函數(shù),不過要注意的是該函數(shù)的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))。

另外BatchSampler與其他Sampler的主要區(qū)別是它需要將Sampler作為參數(shù)進行打包,進而每次迭代返回以batch size為大小的index列表。也就是說在后面的讀取數(shù)據(jù)過程中使用的都是batch sampler。

Dataset

Dataset定義方式如下:

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

上面三個方法是最基本的,其中__getitem__是最主要的方法,它規(guī)定了如何讀取數(shù)據(jù)。但是它又不同于一般的方法,因為它是python built-in方法,其主要作用是能讓該類可以像list一樣通過索引值對數(shù)據(jù)進行訪問。假如你定義好了一個dataset,那么你可以直接通過dataset[0]來訪問第一個數(shù)據(jù)。在此之前我一直沒弄清楚__getitem__是什么作用,所以一直不知道該怎么進入到這個函數(shù)進行調(diào)試。現(xiàn)在如果你想對__getitem__方法進行調(diào)試,你可以寫一個for循環(huán)遍歷dataset來進行調(diào)試了,而不用構(gòu)建dataloader等一大堆東西了,建議學(xué)會使用ipdb這個庫,非常實用?。。∫院笥袝r間再寫一篇ipdb的使用教程。另外,其實我們通過最前面的Dataloader的__next__函數(shù)可以看到DataLoader對數(shù)據(jù)的讀取其實就是用了for循環(huán)來遍歷數(shù)據(jù),不用往上翻了,我直接復(fù)制了一遍,如下:

class DataLoader(object): 
 ... 
  
 def __next__(self): 
  if self.num_workers == 0: 
   indices = next(self.sample_iter) 
   batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
   if self.pin_memory: 
    batch = _utils.pin_memory.pin_memory_batch(batch) 
   return batch

我們仔細看可以發(fā)現(xiàn),前面還有一個self.collate_fn方法,這個是干嘛用的呢?在介紹前我們需要知道每個參數(shù)的意義:

  • indices: 表示每一個iteration,sampler返回的indices,即一個batch size大小的索引列表
  • self.dataset[i]: 前面已經(jīng)介紹了,這里就是對第i個數(shù)據(jù)進行讀取操作,一般來說self.dataset[i]=(img, label)

看到這不難猜出collate_fn的作用就是將一個batch的數(shù)據(jù)進行合并操作。默認的collate_fn是將img和label分別合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默認的collate_fn方法,但是如果你每次讀取的數(shù)據(jù)有img, box, label等等,那么你就需要自定義collate_fn來將對應(yīng)的數(shù)據(jù)合并成一個batch數(shù)據(jù),這樣方便后續(xù)的訓(xùn)練步驟。

到此這篇關(guān)于一文弄懂Pytorch的DataLoader, DataSet, Sampler之間的關(guān)系的文章就介紹到這了,更多相關(guān)Pytorch DataLoader DataSet Sampler內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Flask-Vue前后端分離的全過程講解

    Flask-Vue前后端分離的全過程講解

    這篇文章主要介紹了Flask-Vue前后端分離的全過程,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-07-07
  • Python Word文件自動化實戰(zhàn)之簡歷篩選

    Python Word文件自動化實戰(zhàn)之簡歷篩選

    本文將利用Python自動化做一個具有實操性的小練習(xí),即通過讀取簡歷來篩選出符合招聘條件的簡歷。文中的示例代碼講解詳細,感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下
    2022-05-05
  • Python 轉(zhuǎn)換文本編碼實現(xiàn)解析

    Python 轉(zhuǎn)換文本編碼實現(xiàn)解析

    這篇文章主要介紹了Python 轉(zhuǎn)換文本編碼實現(xiàn)解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值
    2019-08-08
  • Python 繪圖庫 Matplotlib 入門教程

    Python 繪圖庫 Matplotlib 入門教程

    Matplotlib是一個Python語言的2D繪圖庫,它支持各種平臺,并且功能強大,能夠輕易繪制出各種專業(yè)的圖像。本文是對Python 繪圖庫 Matplotlib 入門教程,感興趣的朋友跟隨腳本之家小編一起學(xué)習(xí)吧
    2018-04-04
  • Django基于Models定制Admin后臺實現(xiàn)過程解析

    Django基于Models定制Admin后臺實現(xiàn)過程解析

    這篇文章主要介紹了Django基于Models定制Admin后臺實現(xiàn)過程解析,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2020-11-11
  • PyTorch中常用的激活函數(shù)的方法示例

    PyTorch中常用的激活函數(shù)的方法示例

    這篇文章主要介紹了PyTorch中常用的激活函數(shù)的方法示例,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-08-08
  • OpenCV圖像輪廓的繪制方法

    OpenCV圖像輪廓的繪制方法

    這篇文章主要為大家詳細介紹了OpenCV圖像輪廓的繪制方法,以及測試幾何圖形、花朵圖形輪廓,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2021-08-08
  • 用Python編寫簡單的微博爬蟲

    用Python編寫簡單的微博爬蟲

    這篇文章主要介紹了如何利用Python編寫一個簡單的微博爬蟲,感興趣的小伙伴們可以參考一下
    2016-03-03
  • 通過Python編程將CSV文件導(dǎo)出為PDF文件的方法

    通過Python編程將CSV文件導(dǎo)出為PDF文件的方法

    CSV文件通常用于存儲大量的數(shù)據(jù),而PDF文件則是一種通用的文檔格式,便于與他人共享和打印,將CSV文件轉(zhuǎn)換成PDF文件可以幫助我們更好地管理和展示數(shù)據(jù),本文將介紹如何通過Python編程將CSV文件導(dǎo)出為PDF文件,需要的朋友可以參考下
    2024-06-06
  • Pygame Display顯示模塊的使用方法

    Pygame Display顯示模塊的使用方法

    本文主要介紹了Pygame Display顯示模塊的使用方法,文中通過示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2021-11-11

最新評論