Some problems encountered when using DataLoader() in PyTorch
The error encountered looks like this:
RuntimeError: stack expects each tensor to be equal size, but got [3, 60, 32] at entry 0 and [3, 54, 32] at entry 2
The code that triggered it:

```python
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224))   ### <-- the problematic line
```
The cause is the argument passed to transforms.Resize(): given a single int, Resize only scales the shorter edge to that size and keeps the aspect ratio, so different images end up with different shapes and cannot be stacked into one batch. Changing it to the following fixes the problem:
```python
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.Resize((224, 224)),
```
Likewise, change the transform in val_dataset to transforms.Resize((224, 224)) as well.
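For reference, a minimal sketch of the corrected pipeline end to end; the directory path, batch size and the ToTensor step are illustrative assumptions rather than code from the original post:

```python
import torch
from torchvision import datasets, transforms

traindir = "./data/train"  # assumed ImageFolder-style directory

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixes both height and width, so every sample is [3, 224, 224]
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(traindir, train_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape)  # e.g. torch.Size([32, 3, 224, 224]), so stacking no longer fails
```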
Supplement: an in-depth look at PyTorch's DataLoader
- A DataLoader is essentially an iterable: you access it with iter(), not by calling next() on it directly;
- iter(dataloader) returns an iterator, which can then be consumed with next();
- you can also iterate over it directly, as in `for inputs, labels in dataloaders`;
- typically we implement a Dataset object and pass it to the DataLoader, which internally uses yield to return one batch of data at a time (see the sketch after this list);
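A minimal sketch of the three access patterns listed above, assuming a DataLoader called train_loader such as the one built earlier:

```python
# 1. A DataLoader is an iterable, so a for loop consumes it directly
for inputs, labels in train_loader:
    break

# 2. iter() returns an iterator, which next() can then consume one batch at a time
it = iter(train_loader)
first_batch = next(it)

# 3. Calling next() on the DataLoader itself fails, because DataLoader only
#    defines __iter__, not __next__
# next(train_loader)  # TypeError: 'DataLoader' object is not an iterator
```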
① DataLoader is essentially an iterable (just like Python built-in types such as list); it uses multiple worker processes to speed up the preparation of batch data and relies on yield so that only a bounded amount of memory is used.
② Characteristics of a Queue
When the queue is empty: queue.get() blocks; while it is blocked, if another process/thread performs queue.put(), the blocked process/thread is notified and the get() then succeeds.
When the queue is full: queue.put() blocks.
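A minimal sketch of this blocking behaviour with multiprocessing.Queue; the queue size and the values pushed through it are arbitrary assumptions:

```python
import multiprocessing as mp

def producer(q):
    for i in range(5):
        q.put(i)      # blocks whenever the queue already holds 2 items
    q.put(None)       # sentinel telling the consumer to stop

if __name__ == "__main__":
    q = mp.Queue(maxsize=2)
    p = mp.Process(target=producer, args=(q,))
    p.start()
    while True:
        item = q.get()  # blocks until the producer has put something
        if item is None:
            break
        print(item)
    p.join()
```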
③ DataLoader is an efficient, concise and intuitive structure for feeding input data to a network, easy to use and to extend.
The input data pipeline
In PyTorch, the sequence of steps for getting data into the model is:
① create a Dataset object
② create a DataLoader object
③ loop over the DataLoader object and feed img, label into the model for training
```python
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        ....
```
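MyDataset is left undefined in the snippet above; a minimal sketch of what such a Dataset could look like (the random tensors stand in for real images and are purely an assumption for illustration):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """A toy dataset that returns (image, label) pairs."""

    def __init__(self, num_samples=100):
        self.data = torch.randn(num_samples, 3, 224, 224)
        self.labels = torch.randint(0, 10, (num_samples,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # the DataLoader calls this once per index produced by its sampler
        return self.data[idx], self.labels[idx]

dataloader = DataLoader(MyDataset(), batch_size=4, shuffle=True)
for img, label in dataloader:
    print(img.shape, label.shape)  # torch.Size([4, 3, 224, 224]) torch.Size([4])
    break
```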
So, as the key step that feeds data directly into the model, the DataLoader is very important.
First, a brief introduction. DataLoader is an important interface for data reading in PyTorch, defined in dataloader.py. Almost any model trained with PyTorch uses it (unless the user rewrites it...). Its purpose is to wrap a custom Dataset, according to batch size, whether to shuffle, and so on, into batch-sized Tensors for the subsequent training.
The official description of DataLoader is: "Data loading combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset." For the concepts of iterator versus iterable please look them up yourself; the difference in implementation is that an iterator has both __iter__ and __next__ methods, while an iterable only has __iter__.
1.DataLoader
First, the parameters of DataLoader(object) (a usage sketch follows the list):
dataset (Dataset)
: the dataset to load the data from
batch_size (int, optional)
: how many samples per batch to load
shuffle (bool, optional)
: whether to reshuffle the data at the start of every epoch
sampler (Sampler, optional)
: a custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False
batch_sampler (Sampler, optional)
: like sampler, but returns a whole batch of indices at a time; note that once this is specified, batch_size, shuffle, sampler and drop_last can no longer be specified (they are mutually exclusive)
num_workers (int, optional)
: how many subprocesses are used for data loading; 0 means all data is loaded in the main process (default: 0)
collate_fn (callable, optional)
: the function that merges a list of samples into a mini-batch
pin_memory (bool, optional)
: if set to True, the data loader copies tensors into CUDA pinned memory before returning them
drop_last (bool, optional)
: this concerns the final incomplete batch; if True and, say, batch_size is 64 while an epoch has only 100 samples, the remaining 36 samples are simply thrown away...
If False (the default), loading continues as normal and the last batch is just smaller.
timeout (numeric, optional)
: if positive, the time to wait when collecting a batch from the worker processes; if no batch arrives within that time, a timeout error is raised. The value should always be non-negative. Default: 0
worker_init_fn (callable, optional)
: an initialization function for each worker. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
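A hedged sketch of how several of these parameters might be combined; train_dataset is assumed to be the ImageFolder dataset built earlier, and the NumPy reseeding function is an illustrative assumption, not part of the original article:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

def my_worker_init_fn(worker_id):
    # reseed NumPy inside each worker so that workers do not return identical random numbers
    np.random.seed(torch.initial_seed() % 2**32)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,        # mutually exclusive with a custom sampler
    num_workers=4,       # 4 subprocesses load data; 0 would load everything in the main process
    pin_memory=True,     # copy batches into CUDA pinned memory before returning them
    drop_last=True,      # e.g. with 100 samples the final 36 are discarded
    timeout=60,          # error out if a worker needs more than 60 s to deliver a batch
    worker_init_fn=my_worker_init_fn,
)
```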
- First, when the DataLoader is initialized, it builds the sampling (index) list over the dataset:
```python
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may use
              ``torch.initial_seed()`` to access the PyTorch seed for each worker
              in :attr:`worker_init_fn`, and use it to set other seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  # shuffles the index list
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))
        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
```
Here, RandomSampler and BatchSampler are what produce the index lists used to fetch each batch of data; the yield-a-batch mechanism already lives here!
```python
class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)
```
```python
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
```
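A small demonstration (not from the original article) of how these two samplers compose into batches of indices; the dataset length and batch size are arbitrary:

```python
import torch
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

data = torch.arange(10)  # stands in for a Dataset of length 10

# sequential indices grouped into batches of 3; the last, smaller batch is kept
print(list(BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False)))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

# shuffled indices grouped into batches of 3; the incomplete batch is dropped
print(list(BatchSampler(RandomSampler(data), batch_size=3, drop_last=True)))
```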
- _DataLoaderIter(self) takes a dataloader object as input; the num_workers=0 case is easy to follow, while num_workers != 0 brings in worker processes to speed up data loading;
- without workers, `batch = self.collate_fn([self.dataset[i] for i in indices])` turns the indices into actual data and returns (image, label); self.dataset[i] calls the dataset object's __getitem__() method;
- with workers, an index queue (index_queues) is created for each worker process, and all workers share a single worker_result_queue for the data; the actual loading happens in the _worker_loop function;
```python
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queues[i],
                          self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event,
                          self.pin_memory, maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(timeout=self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch
```
```python
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True

    # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
    # module's handlers are executed after Python returns from C low-level
    # handlers, likely when the same fatal signal happened again already.
    # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    random.seed(seed)
    torch.manual_seed(seed)

    if init_fn is not None:
        init_fn(worker_id)

    watchdog = ManagerWatchdog()

    while True:
        try:
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            if watchdog.is_alive():
                continue
            else:
                break
        if r is None:
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
            del samples
```
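Both the single-process path and _worker_loop reduce to collate_fn([dataset[i] for i in batch_indices]). As a minimal sketch (not from the original article) of what a custom collate_fn plugged in at that point might look like, here is one that pads images of different heights before stacking, assuming each dataset item is an (image tensor, int label) pair and all images share the same width:

```python
import torch
import torch.nn.functional as F

def pad_collate(samples):
    """Pad variable-height image tensors to a common height, then stack the batch."""
    images = [img for img, _ in samples]
    labels = torch.tensor([label for _, label in samples])
    max_h = max(img.shape[-2] for img in images)
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    padded = [F.pad(img, (0, 0, 0, max_h - img.shape[-2])) for img in images]
    return torch.stack(padded, dim=0), labels

# loader = DataLoader(train_dataset, batch_size=32, collate_fn=pad_collate)
```

Used this way, samples shaped [3, 60, 32] and [3, 54, 32], as in the error at the top of this post, would both be padded to [3, 60, 32] before stacking instead of raising a RuntimeError.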
- The queues are used to buffer (cache) data, which is what speeds up loading!
The above is based on my personal experience; I hope it can serve as a useful reference, and I hope everyone will keep supporting 腳本之家.