pytorch中dataloader 的sampler 參數(shù)詳解
1. dataloader() 初始化函數(shù)
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None):
其中幾個(gè)常用的參數(shù):
- dataset 數(shù)據(jù)集,map-style and iterable-style 可以用index取值的對象、
- batch_size 大小
- shuffle 取batch是否隨機(jī)取, 默認(rèn)為False
- sampler 定義取batch的方法,是一個(gè)迭代器, 每次生成一個(gè)key 用于讀取dataset中的值
- batch_sampler 也是一個(gè)迭代器, 每次生次一個(gè)batch_size的key
- num_workers 參與工作的線程數(shù)collate_fn 對取出的batch進(jìn)行處理
- drop_last 對最后不足batchsize的數(shù)據(jù)的處理方法
下面看兩段取自DataLoader中的__init__代碼, 幫助我們理解幾個(gè)常用參數(shù)之間的關(guān)系
2. shuffle 與sample 之間的關(guān)系
當(dāng)我們sampler有輸入時(shí),shuffle的值就沒有意義,
if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: # See NOTE [ Custom Samplers and IterableDataset ] sampler = _InfiniteConstantSampler() else: # map-style if shuffle: sampler = RandomSampler(dataset) else: sampler = SequentialSampler(dataset)
當(dāng)dataset類型是map style時(shí), shuffle其實(shí)就是改變sampler的取值
- shuffle為默認(rèn)值 False時(shí),sampler是SequentialSampler,就是按順序取樣,
- shuffle為True時(shí),sampler是RandomSampler, 就是按隨機(jī)取樣
3. sample 的定義方法
3.1 sampler 參數(shù)的使用
sampler 是用來定義取batch方法的一個(gè)函數(shù)或者類,返回的是一個(gè)迭代器。
我們可以看下自帶的RandomSampler類中最重要的iter函數(shù)
def __iter__(self): n = len(self.data_source) # dataset的長度, 按順序索引 if self.replacement:# 對應(yīng)的replace參數(shù) return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist()) return iter(torch.randperm(n).tolist())
可以看出,其實(shí)就是生成索引,然后隨機(jī)的取值, 然后再迭代。
其實(shí)還有一些細(xì)節(jié)需要注意理解:
比如__len__函數(shù),包括DataLoader的len和sample的len, 兩者區(qū)別, 這部分代碼比較簡單,可以自行閱讀,其實(shí)參考著RandomSampler寫也不會出現(xiàn)問題。
比如,迭代器和生成器的使用, 以及區(qū)別
if batch_size is not None and batch_sampler is None: # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.sampler = sampler self.batch_sampler = batch_sampler
BatchSampler的生成過程:
# 略去類的初始化 def __iter__(self): batch = [] for idx in self.sampler: batch.append(idx) if len(batch) == self.batch_size: yield batch batch = [] if len(batch) > 0 and not self.drop_last: yield batch
就是按batch_size從sampler中讀取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之間的關(guān)系
- 如果batch_sampler沒有定義的話且batch_size有定義, 會根據(jù)sampler, batch_size, drop_last生成一個(gè)batch_sampler
- 自帶的注釋中對batch_sampler有一句話: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
- 意思就是b
- atch_sampler 與這些參數(shù)沖突 ,即 如果你定義了batch_sampler, 其他參數(shù)都不需要有
4. batch 生成過程
每個(gè)batch都是由迭代器產(chǎn)生的:
# DataLoader中iter的部分 def __iter__(self): if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: return _MultiProcessingDataLoaderIter(self) # 再看調(diào)用的另一個(gè)類 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_SingleProcessDataLoaderIter, self).__init__(loader) assert self._timeout == 0 assert self._num_workers == 0 self._dataset_fetcher = _DatasetKind.create_fetcher( self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def __next__(self): index = self._next_index() data = self._dataset_fetcher.fetch(index) if self._pin_memory: data = _utils.pin_memory.pin_memory(data) return data
到此這篇關(guān)于pytorch中dataloader 的sampler 參數(shù)詳解的文章就介紹到這了,更多相關(guān)pytorch sampler 內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python3數(shù)據(jù)庫操作包pymysql的操作方法
這篇文章主要介紹了Python3數(shù)據(jù)庫操作包pymysql的操作方法,文章通過實(shí)例代碼相結(jié)合給大家介紹的非常詳細(xì),需要的朋友可以參考下2018-07-07對python3 Serial 串口助手的接收讀取數(shù)據(jù)方法詳解
今天小編就為大家分享一篇對python3 Serial 串口助手的接收讀取數(shù)據(jù)方法詳解,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06解讀pandas交叉表與透視表pd.crosstab()和pd.pivot_table()函數(shù)
這篇文章主要介紹了pandas交叉表與透視表pd.crosstab()和pd.pivot_table()函數(shù)的用法,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-09-09Django 解決由save方法引發(fā)的錯(cuò)誤
這篇文章主要介紹了Django 解決由save方法引發(fā)的錯(cuò)誤,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05Python FastAPI 多參數(shù)傳遞的示例詳解
這篇文章主要介紹了Python FastAPI 多參數(shù)傳遞,FastAPI通過模板來匹配URL中的參數(shù)列表,大概分為三類方式傳遞參數(shù),每種方式結(jié)合示例代碼給大家介紹的非常詳細(xì),需要的朋友可以參考下2022-12-12python爬蟲模擬瀏覽器訪問-User-Agent過程解析
這篇文章主要介紹了python爬蟲模擬瀏覽器訪問-User-Agent過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-12-12對python中Matplotlib的坐標(biāo)軸的坐標(biāo)區(qū)間的設(shè)定實(shí)例講解
今天小編就為大家分享一篇對python中Matplotlib的坐標(biāo)軸的坐標(biāo)區(qū)間的設(shè)定實(shí)例講解,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05