PyTorch: removing missing and empty samples with a custom Dataset and DataLoader
1. Problem description
I previously wrote a post, "pytorch Dataset, DataLoader產(chǎn)生自定義的訓練數(shù)據(jù)" (generating custom training data with PyTorch's Dataset and DataLoader), but it left one problem open: no data cleaning happens inside the Dataset, so if the data passed to it is itself faulty, iteration is bound to fail.
For example, suppose I hand a long list of image paths to the Dataset. If every path is valid and every image exists undamaged, everything runs fine.
But if some of the paths point to images that do not exist, then reading those images through the Dataset and iterating over the result raises an error like the following:
File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 68, in <listcomp> return [default_collate(samples) for samples in transposed]
File "D:\ProgramData\Anaconda3\envs\pytorch-py36\lib\site-packages\torch\utils\data\_utils\collate.py", line 70, in default_collate
raise TypeError((error_msg_fmt.format(type(batch[0])))) TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'NoneType'>
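The failure is easy to reproduce. Below is a minimal sketch (ToyDataset and the file names are made up for illustration) in which the Dataset returns None in place of an unreadable image; iterating the DataLoader with the stock collate then raises a TypeError like the one above (the exact wording varies across PyTorch versions):
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Mimics an image dataset whose first file is missing/unreadable."""
    def __init__(self):
        self.samples = [(None, "missing.jpg"),       # the "broken" sample
                        (torch.zeros(3), "a.jpg"),
                        (torch.ones(3), "b.jpg")]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]  # (image, image_id); image may be None

loader = DataLoader(ToyDataset(), batch_size=3)
for batch in loader:  # never reached: the default collate chokes on None
    print(batch)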
2. The usual fix
The usual fix is simple and blunt: clean the data before handing it to the Dataset, removing missing images and corrupted files up front.
Yes, this is the crudest approach.
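For instance, here is a minimal pre-cleaning sketch (the helper name clean_image_list is made up) that drops every id whose file is absent before the Dataset ever sees it; a thorough cleanup would additionally try to open each image to catch corrupted files:
import os

def clean_image_list(image_dir, image_id_list):
    # keep only the ids whose image file actually exists on disk
    return [image_id for image_id in image_id_list
            if os.path.exists(os.path.join(image_dir, image_id))]

image_id_list = clean_image_list("../dataset/test_images/images",
                                 ["1.jpg", "ddd.jpg", "111.jpg", "3.jpg"])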
3. Another fix: customizing the batching rule with a collate_fn()
What we want instead is for the Dataset to handle whatever it is given: if a sample is missing or broken, it returns None, and the DataLoader then filters out every entry whose value is None.
This guarantees that every batch the DataLoader yields during iteration is valid.
For example, when reading a batch of batch_size=5 images, if 1 of them (or more) does not exist, the returned batch should drop the missing ones, i.e. come back with 5-1=4 samples.
That is exactly the feature implemented below: the returned batch automatically discards invalid data.
3.1 PyTorch's data-handling classes: Dataset and DataLoader
PyTorch provides two data-handling classes, Dataset and DataLoader:
from torch.utils.data import Dataset, DataLoader
Dataset defines how individual samples are read and preprocessed, while DataLoader loads them and produces batches of training data.
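In outline, the division of labor looks like this (a minimal sketch; SquareDataset is a made-up example):
from torch.utils.data import Dataset, DataLoader

class SquareDataset(Dataset):
    def __init__(self, n):
        self.n = n

    def __len__(self):               # Dataset: how many samples there are
        return self.n

    def __getitem__(self, index):    # Dataset: read + preprocess ONE sample
        return index, index ** 2

loader = DataLoader(SquareDataset(10), batch_size=4, shuffle=True)
for xs, ys in loader:                # DataLoader: shuffling + batching
    print(xs.shape, ys.shape)        # batches of 4, 4, then 2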
torch.utils.data.DataLoader takes the following parameters (a combined usage sketch follows the list):
1. dataset (Dataset)
The dataset to load samples from.
2. batch_size (int, optional)
How many samples per batch.
3. shuffle (bool, optional)
Whether to reshuffle the data at the start of every epoch.
4. sampler (Sampler, optional)
A custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False.
5. batch_sampler (Sampler, optional)
Like sampler, but returns the indices of a whole batch at a time. Note that once this is specified, batch_size, shuffle, sampler and drop_last may no longer be set (they are mutually exclusive).
6. num_workers (int, optional)
How many subprocesses handle data loading; 0 means all data is loaded in the main process. (default: 0)
7. collate_fn (callable, optional)
The function that merges a list of samples into a mini-batch.
8. pin_memory (bool, optional)
If True, the data loader copies tensors into CUDA pinned memory before returning them.
9. drop_last (bool, optional)
Governs the final, incomplete batch. If True, it is dropped: e.g. with batch_size=64 and only 100 samples per epoch, the trailing 36 samples are thrown away during training. If False (the default), loading proceeds normally and the last batch is simply smaller.
10. timeout (numeric, optional)
If positive, how long to wait for a worker process to deliver a batch; once the time runs out, that batch is given up on. This value must always be >= 0. (default: 0)
11. worker_init_fn (callable, optional)
A per-worker initialization function. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
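Putting a few of these arguments together, here is a self-contained sketch using a TensorDataset of random data (all the values are arbitrary):
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
loader = DataLoader(dataset=dataset,
                    batch_size=8,       # 8 samples per batch
                    shuffle=True,       # reshuffle at the start of each epoch
                    num_workers=2,      # load data in 2 worker subprocesses
                    pin_memory=True,    # copy batches into pinned (page-locked) memory
                    drop_last=False)    # keep the final, smaller batch of 4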
The one we will use is the collate_fn callback.
3.2 A custom collate_fn()
The collate_fn of torch.utils.data.DataLoader controls how samples are stitched into a batch. By default it is the default_collate function, but default_collate errors out when the batch contains entries such as None. We therefore define our own collate_fn():
The method is simple: take the original default_collate function and add the few lines below, which check whether image is None and, if so, drop that sample from the original batch, so that iteration never trips over it.
# Added here: if image is None, drop that sample from the batch so iteration does not fail
if isinstance(batch, list):
    batch = [(image, image_id) for (image, image_id) in batch if image is not None]
if batch == []:
    return (None, None)
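If you would rather not copy the whole of default_collate, an equivalent and much shorter wrapper is to filter the batch first and then delegate to the stock implementation (a sketch; recent PyTorch versions also expose the same function as torch.utils.data.default_collate):
from torch.utils.data.dataloader import default_collate

def collate_fn(batch):
    # drop every sample whose image is None, then let the stock collate do the rest
    batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if not batch:
        return None, None
    return default_collate(batch)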
dataset_collate.py:
# -*-coding: utf-8 -*-
"""
    @Project: pytorch-learning-tutorials
    @File   : dataset_collate.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2019-06-07 17:09:13
"""
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""
import torch
import re
from torch._six import container_abcs, string_classes, int_classes

_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

np_str_obj_array_pattern = re.compile(r'[SaUO]')

error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def collate_fn(batch):
    '''
    collate_fn (callable, optional): merges a list of samples to form a mini-batch.
    Based on torch's default_collate, the default collate used by DataLoader.
    The stock default_collate errors out when the batch contains entries such as
    None. One fix: check whether image is None and, if so, remove that sample
    from the batch, so that iteration never fails.
    :param batch:
    :return:
    '''
    r"""Puts each data field into a tensor with outer dimension batch size"""
    # Added here: if image is None, drop that sample from the batch so iteration does not fail
    if isinstance(batch, list):
        batch = [(image, image_id) for (image, image_id) in batch if image is not None]
    if batch == []:
        return (None, None)
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))
            return collate_fn([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(collate_fn(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [collate_fn(samples) for samples in transposed]

    raise TypeError((error_msg_fmt.format(type(batch[0]))))
Test script:
# -*-coding: utf-8 -*-
"""
    @Project: pytorch-learning-tutorials
    @File   : dataset.py
    @Author : panjq
    @E-mail : pan_jinquan@163.com
    @Date   : 2019-03-07 18:45:06
"""
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from utils import dataset_collate
import os
import cv2
from PIL import Image


def read_image(path, mode='RGB'):
    '''
    :param path:
    :param mode: RGB or L
    :return:
    '''
    return Image.open(path).convert(mode)


class TorchDataset(Dataset):
    def __init__(self, image_id_list, image_dir, resize_height=256, resize_width=256,
                 repeat=1, transform=None):
        '''
        :param image_id_list: list of image file names
        :param image_dir: image directory; image_dir + image_id forms the full image path
        :param resize_height: no scaling when None
        :param resize_width: no scaling when None
            (when exactly one of resize_height/resize_width is None, the image is scaled proportionally)
        :param repeat: how many times to repeat all samples; defaults to a single pass;
            None means loop (almost) indefinitely, up to sys.maxsize
        :param transform: preprocessing transform
        '''
        self.image_dir = image_dir
        self.image_id_list = image_id_list
        self.len = len(image_id_list)
        self.repeat = repeat
        self.resize_height = resize_height
        self.resize_width = resize_width
        self.transform = transform

    def __getitem__(self, i):
        index = i % self.len
        # print("i={},index={}".format(i, index))
        image_id = self.image_id_list[index]
        image_path = os.path.join(self.image_dir, image_id)
        img = self.load_data(image_path)
        if img is None:
            return None, image_id
        img = self.data_preproccess(img)
        return img, image_id

    def __len__(self):
        if self.repeat == None:
            data_len = 10000000
        else:
            data_len = len(self.image_id_list) * self.repeat
        return data_len

    def load_data(self, path):
        '''
        Load one image; returns None when the file is missing or unreadable.
        :param path:
        :return:
        '''
        try:
            image = read_image(path)
        except Exception as e:
            image = None
            print(e)
        # image = image_processing.read_image(path)  # read the image with OpenCV instead
        return image

    def data_preproccess(self, data):
        '''
        Preprocess one sample.
        :param data:
        :return:
        '''
        if self.transform is not None:
            data = self.transform(data)
        return data


if __name__ == '__main__':
    resize_height = 224
    resize_width = 224
    image_id_list = ["1.jpg", "ddd.jpg", "111.jpg", "3.jpg", "4.jpg",
                     "5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg"]
    image_dir = "../dataset/test_images/images"
    # Preprocessing setup.
    # torchvision.transforms.ToTensor converts a PIL.Image or numpy.ndarray with
    # shape=(H,W,C) and pixel values in [0, 255] into a torch.FloatTensor with
    # shape=(C,H,W), normalized to [0.0, 1.0].
    train_transform = transforms.Compose([
        transforms.Resize(size=(resize_height, resize_width)),
        # transforms.RandomHorizontalFlip(),  # random horizontal flip
        transforms.RandomCrop(size=(resize_height, resize_width), padding=4),  # random crop
        transforms.ToTensor(),  # (H,W,C) -> (C,H,W), normalized to [0.0, 1.0]
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize with per-channel mean/std
    ])
    epoch_num = 2    # number of passes over all samples
    batch_size = 5   # samples per training batch
    train_data_nums = 10
    max_iterate = int((train_data_nums + batch_size - 1) / batch_size * epoch_num)  # total iterations
    train_data = TorchDataset(image_id_list=image_id_list,
                              image_dir=image_dir,
                              resize_height=resize_height,
                              resize_width=resize_width,
                              repeat=1,
                              transform=train_transform)
    # Using the stock default_collate would raise an error:
    # train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False)
    # Use the custom collate_fn instead:
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False,
                              collate_fn=dataset_collate.collate_fn)
[1] Iterate by epoch, with TorchDataset created with repeat=1:
    for epoch in range(epoch_num):
        for step, (batch_image, batch_label) in enumerate(train_loader):
            if batch_image is None and batch_label is None:
                print("batch_image:{},batch_label:{}".format(batch_image, batch_label))
                continue
            image = batch_image[0, :]
            image = image.numpy()  # image = np.array(image)
            image = image.transpose(1, 2, 0)  # channels: [c,h,w] -> [h,w,c]
            cv2.imshow("image", image)
            cv2.waitKey(2000)
            print("batch_image.shape:{},batch_label:{}".format(batch_image.shape, batch_label))
            # batch_x, batch_y = Variable(batch_x), Variable(batch_y)
Notes on the output:
With batch_size=5 and the input list image_id_list=["1.jpg","ddd.jpg","111.jpg","3.jpg","4.jpg","5.jpg","6.jpg","7.jpg","8.jpg","9.jpg"], of which "ddd.jpg" and "111.jpg" do not exist, and resize_width=224, a normal batch would have shape torch.Size([5, 3, 224, 224]). But because "ddd.jpg" and "111.jpg" are missing and get filtered out, the first batch shrinks to torch.Size([3, 3, 224, 224]):
[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'
[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'
batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')
batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')
[Errno 2] No such file or directory: '../dataset/test_images/images\\ddd.jpg'
[Errno 2] No such file or directory: '../dataset/test_images/images\\111.jpg'
batch_image.shape:torch.Size([3, 3, 224, 224]),batch_label:('1.jpg', '3.jpg', '4.jpg')
batch_image.shape:torch.Size([5, 3, 224, 224]),batch_label:('5.jpg', '6.jpg', '7.jpg', '8.jpg', '9.jpg')
The above is my personal experience; I hope it serves as a useful reference, and I hope you will continue to support 腳本之家. If there are mistakes or anything I have not fully considered, corrections are welcome.