pytorch中dataloader 的sampler 參數(shù)詳解
1. dataloader() 初始化函數(shù)
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None):
其中幾個(gè)常用的參數(shù):
- dataset 數(shù)據(jù)集,map-style and iterable-style 可以用index取值的對(duì)象、
- batch_size 大小
- shuffle 取batch是否隨機(jī)取, 默認(rèn)為False
- sampler 定義取batch的方法,是一個(gè)迭代器, 每次生成一個(gè)key 用于讀取dataset中的值
- batch_sampler 也是一個(gè)迭代器, 每次生次一個(gè)batch_size的key
- num_workers 參與工作的線程數(shù)collate_fn 對(duì)取出的batch進(jìn)行處理
- drop_last 對(duì)最后不足batchsize的數(shù)據(jù)的處理方法
下面看兩段取自DataLoader中的__init__代碼, 幫助我們理解幾個(gè)常用參數(shù)之間的關(guān)系
2. shuffle 與sample 之間的關(guān)系
當(dāng)我們sampler有輸入時(shí),shuffle的值就沒(méi)有意義,
if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset)
當(dāng)dataset類型是map style時(shí), shuffle其實(shí)就是改變sampler的取值
- shuffle為默認(rèn)值 False時(shí),sampler是SequentialSampler,就是按順序取樣,
- shuffle為T(mén)rue時(shí),sampler是RandomSampler, 就是按隨機(jī)取樣
3. sample 的定義方法
3.1 sampler 參數(shù)的使用
sampler 是用來(lái)定義取batch方法的一個(gè)函數(shù)或者類,返回的是一個(gè)迭代器。
我們可以看下自帶的RandomSampler類中最重要的iter函數(shù)
def __iter__(self): n = len(self.data_source) # dataset的長(zhǎng)度, 按順序索引 if self.replacement:# 對(duì)應(yīng)的replace參數(shù) return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) return iter(torch.randperm(n).tolist())
可以看出,其實(shí)就是生成索引,然后隨機(jī)的取值, 然后再迭代。
其實(shí)還有一些細(xì)節(jié)需要注意理解:
比如__len__函數(shù),包括DataLoader的len和sample的len, 兩者區(qū)別, 這部分代碼比較簡(jiǎn)單,可以自行閱讀,其實(shí)參考著RandomSampler寫(xiě)也不會(huì)出現(xiàn)問(wèn)題。
比如,迭代器和生成器的使用, 以及區(qū)別
if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler
BatchSampler的生成過(guò)程:
# 略去類的初始化 def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch
就是按batch_size從sampler中讀取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之間的關(guān)系
- 如果batch_sampler沒(méi)有定義的話且batch_size有定義, 會(huì)根據(jù)sampler, batch_size, drop_last生成一個(gè)batch_sampler
- 自帶的注釋中對(duì)batch_sampler有一句話: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
- 意思就是b
- atch_sampler 與這些參數(shù)沖突 ,即 如果你定義了batch_sampler, 其他參數(shù)都不需要有
4. batch 生成過(guò)程
每個(gè)batch都是由迭代器產(chǎn)生的:
# DataLoader中iter的部分 def __iter__(self): if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: return _MultiProcessingDataLoaderIter(self) # 再看調(diào)用的另一個(gè)類 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert self._num_workers == 0 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def __next__(self): index = self._next_index() data = self._dataset_fetcher.fetch(index) if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data
到此這篇關(guān)于pytorch中dataloader 的sampler 參數(shù)詳解的文章就介紹到這了,更多相關(guān)pytorch sampler 內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python3數(shù)據(jù)庫(kù)操作包pymysql的操作方法
這篇文章主要介紹了Python3數(shù)據(jù)庫(kù)操作包pymysql的操作方法,文章通過(guò)實(shí)例代碼相結(jié)合給大家介紹的非常詳細(xì),需要的朋友可以參考下2018-07-07對(duì)python3 Serial 串口助手的接收讀取數(shù)據(jù)方法詳解
今天小編就為大家分享一篇對(duì)python3 Serial 串口助手的接收讀取數(shù)據(jù)方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-06-06解讀pandas交叉表與透視表pd.crosstab()和pd.pivot_table()函數(shù)
這篇文章主要介紹了pandas交叉表與透視表pd.crosstab()和pd.pivot_table()函數(shù)的用法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-09-09利用Anaconda簡(jiǎn)單安裝scrapy框架的方法
今天小編就為大家分享一篇利用Anaconda簡(jiǎn)單安裝scrapy框架的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-06-06Django 解決由save方法引發(fā)的錯(cuò)誤
這篇文章主要介紹了Django 解決由save方法引發(fā)的錯(cuò)誤,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05Python FastAPI 多參數(shù)傳遞的示例詳解
這篇文章主要介紹了Python FastAPI 多參數(shù)傳遞,FastAPI通過(guò)模板來(lái)匹配URL中的參數(shù)列表,大概分為三類方式傳遞參數(shù),每種方式結(jié)合示例代碼給大家介紹的非常詳細(xì),需要的朋友可以參考下2022-12-12python爬蟲(chóng)模擬瀏覽器訪問(wèn)-User-Agent過(guò)程解析
這篇文章主要介紹了python爬蟲(chóng)模擬瀏覽器訪問(wèn)-User-Agent過(guò)程解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-12-12對(duì)python中Matplotlib的坐標(biāo)軸的坐標(biāo)區(qū)間的設(shè)定實(shí)例講解
今天小編就為大家分享一篇對(duì)python中Matplotlib的坐標(biāo)軸的坐標(biāo)區(qū)間的設(shè)定實(shí)例講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-05-05