pytorch如何自定義數(shù)據(jù)集
自定義數(shù)據(jù)
數(shù)據(jù)傳遞機制
我們首先回顧識別手寫數(shù)字的程序:
... Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,) dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True) ... for epoch in range(EPOCH): for i, (image, label) in enumerate(dataloader): ...
從上面的程序,我們可以知道,在PyTorch中,數(shù)據(jù)傳遞機制是這樣的:
1.創(chuàng)建Dataset
2.Dataset傳遞給DataLoader
3.DataLoader迭代產(chǎn)生訓練數(shù)據(jù)提供給模型。
總結這個數(shù)據(jù)傳遞機制就是,Dataset負責建立索引到樣本的映射,DataLoader負責以特定的方式從數(shù)據(jù)集中迭代的產(chǎn)生一個個batch的樣本集合。在enumerate過程中實際上是dataloader按照其參數(shù)sampler規(guī)定的策略調(diào)用了其dataset的getitem方法(下文中將介紹該方法)。
在上面的識別手寫數(shù)字的例子中,數(shù)據(jù)集是直接下載的,但如果我們自己收集了一些數(shù)據(jù),存在電腦文件夾里,我們該如何把這些數(shù)據(jù)變?yōu)榭梢栽赑yTorch框架下進行神經(jīng)網(wǎng)絡訓練的數(shù)據(jù)集呢,即如何自定義數(shù)據(jù)集呢?
PyTorch中Dataset,DataLoader,Sample的關系
PyTorch中Dataset,DataLoader,Sampler的關系可以用下圖概括:
用文字表達就是:Dataloader中包含Sampler和Dataset,Sampler產(chǎn)生索引,Dataset拿著這個索引在數(shù)據(jù)集文件夾中找到對應的樣本(每個樣本對應一個索引,就像列表中每個元素對應一個索引),并給該樣本配置上標簽,最后返回(樣本+標簽)給調(diào)用方。
在enumerate過程中,Dataloader按照其參數(shù)BatchSampler規(guī)定的策略調(diào)用其Dataset的getitem方法batchsize次,得到一個batch,該batch中既包含樣本,也包含相應的標簽。
自定義數(shù)據(jù)集
torch.utils.data.Dataset 是一個表示數(shù)據(jù)集的抽象類。任何自定義的數(shù)據(jù)集都需要繼承這個類并覆寫相關方法。所謂數(shù)據(jù)集,其實就是一個負責處理索引(index)到樣本(sample)映射的一個類(class)。Pytorch提供兩種數(shù)據(jù)集: Map式數(shù)據(jù)集 Iterable式數(shù)據(jù)集。這里我們只介紹前者。
一個Map式的數(shù)據(jù)集必須要重寫getitem(self, index)、 len(self) 兩個內(nèi)建方法,用來表示從索引到樣本的映射(Map)。這樣一個數(shù)據(jù)集dataset,舉個例子,當使用dataset[idx]命令時,可以在你的硬盤中讀取數(shù)據(jù)集中第idx張圖片以及其標簽(如果有的話); len(dataset)則會返回這個數(shù)據(jù)集的容量。
自定義數(shù)據(jù)集類的范式大致是這樣的:
class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset def __init__(self): # TODO # 1. Initialize file path or list of file names. pass def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #這里需要注意的是,第一步:read one data,是一個data point pass def __len__(self): # You should change 0 to the total size of your dataset. return 0
關于Dataset API的官網(wǎng)介紹https://pytorch.org/docs/stable/data.html#dataset-types:
Dataset類的使用:所有的類都應該是此類的子類(也就是說應該繼承該類)。所有的子類都要重寫(override) len(), getitem()。
Ø__len()__ : 此方法應該提供數(shù)據(jù)集的大小(容量)
Ø__getitem()__ : 此方法應該提供支持下標索引方式訪問數(shù)據(jù)集。
DataLoader類的使用如下:
根據(jù)這個方式,我們舉一個例子。
實例1
從kaggle官網(wǎng)下載dogsVScats的數(shù)據(jù)集(百度網(wǎng)盤下載鏈接見文末),該數(shù)據(jù)集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號:
sampleSubmission.csv中的內(nèi)容如下:
我們把其中前10000張貓的圖片和10000張狗的圖片作為訓練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗證集。貓的label記為0,狗的label記為1。因為圖片大小不一,所以,我們需要對圖像進行transform。
# -*- coding: UTF-8 -*- import matplotlib.pyplot as plt import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms from PIL import Image """ 如果代碼執(zhí)行的時候出現(xiàn): OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized. OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to continue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/. 解決辦法是加上: import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" """ import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" image_transform = transforms.Compose([ transforms.Resize(256), # 把圖片resize為256*256 transforms.RandomCrop(224), # 隨機裁剪224*224 transforms.RandomHorizontalFlip(), # 水平翻轉(zhuǎn) transforms.ToTensor(), # 將圖像轉(zhuǎn)為Tensor transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標準化 ]) # 創(chuàng)建一個叫做DogVsCatDataset的Dataset,繼承自父類torch.utils.data.Dataset class DogVsCatDataset(Dataset): def __init__(self, root_dir, train=True, transform=None): """ Args: root_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on a sample. """ self.root_dir = root_dir self.img_path = os.listdir(self.root_dir) if train: # 圖片數(shù)據(jù)中有類似:dog.12499.jpg的圖片共12499張。 # x.split('.')[1] 就是文件名dog.12473.jpg中的序號部分,也是圖片的編號 self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path)) # 劃分訓練集和驗證集 else: # 序號大于10000的編號 self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path)) self.transform = transform def __len__(self): return len(self.img_path) def __getitem__(self, idx): image = Image.open(os.path.join(self.root_dir, self.img_path[idx])) label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1 # label, 貓為0,狗為1 if self.transform: image = self.transform(image) label = torch.from_numpy(np.array([label])) return image, label # 來測試一下 if __name__ == '__main__': catanddog_dataset = DogVsCatDataset(root_dir='E:/BaiduNetdiskDownload/kaggle/train', train=False, transform=image_transform) # num_workers=4表示用4個線程讀取數(shù)據(jù) train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4) # iter()函數(shù)把train_loader變?yōu)榈?,然后調(diào)用迭代器的next()方法 image, label = iter(train_loader).next() sample = image[0].squeeze() sample = sample.permute((1, 2, 0)).numpy() sample *= [0.229, 0.224, 0.225] sample += [0.485, 0.456, 0.406] sample = np.clip(sample, 0, 1) plt.imshow(sample) plt.show() print('Label is: {}'.format(label[0].numpy()))
運行結果:
實例2
收集圖像樣本
以簡單的貓狗二分類為例,可以在網(wǎng)上下載一些貓狗圖片。創(chuàng)建以下目錄:
ldata -----------------根目錄
ldata/test -----------------測試集
ldata/train -----------------訓練集
ldata/val ------------------驗證集
在test/train/val之下在校分別創(chuàng)建2個文件夾,dog,cat
cat,dog文件夾下分別存放2類圖像:
之后寫一個簡單的python腳本,生成txt文件,用于指明每個圖像和標簽的對應關系。
格式:
/cat/1.jpg 0
/dog/1.jpg 1
…
如圖:
至此,樣本集的收集以及簡單歸類完成。
實現(xiàn)
使用到python package
python package | 目錄 |
---|---|
numpy | 矩陣操作,對圖像進行轉(zhuǎn)置 |
skimage | 圖像處理,圖像I/O,圖像變換 |
matplotlib | 圖像的顯示,可視化 |
os | 一些文件查找操作 |
torch | pytorch |
torchvision | pytorch |
代碼
# -*- coding: UTF-8 -*- """ 本案例來自:http://www.dbjr.com.cn/article/199360.htm """ import numpy as np from skimage import io from skimage import transform import matplotlib.pyplot as plt import os import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision.transforms import transforms from torchvision.utils import make_grid """ 第一步: 定義一個子類,繼承Dataset類,重寫__len()__,__getitem()__方法。 細節(jié): 1、數(shù)據(jù)集中一個一樣的表示:采用字典的形式sample = {'image': image, 'label': label}。 2、圖像的讀取:采用skimage.io進行讀取,讀取之后的結果為numpy.ndarray形式。 3、圖像變換:transform參數(shù) """ class MyDataset(Dataset): def __init__(self, root_dir, names_file, transform=None): self.root_dir = root_dir self.names_file = names_file self.transform = transform self.size = 0 self.names_list = [] if not os.path.isfile(self.names_file): print(self.names_file + 'does not exist!') file = open(self.names_file) for f in file: self.names_list.append(f) self.size += 1 def __len__(self): return self.size def __getitem__(self, idx): image_path = self.root_dir + self.names_list[idx].split(' ')[0] if not os.path.isfile(image_path): print(image_path + 'does not exists!') return None image = io.imread(image_path) # use skitimage label = int(self.names_list[idx].split(' ')[1]) sample = {'image': image, 'label': label} if self.transform: sample = self.transform(sample) return sample """ 第二步 實例化一個對象,并讀取和顯示數(shù)據(jù)集 """ train_dataset = MyDataset(root_dir='./data/train', names_file='./data/train/train.txt', transform=None) plt.figure() for (cnt, i) in enumerate(train_dataset): image = i['image'] label = i['label'] ax = plt.subplot(4, 4, cnt + 1) ax.axis('off') ax.imshow(image) ax.set_title('label {}'.format(label)) plt.pause(0.001) if cnt == 15: break """ 第三步(可選optional) 對數(shù)據(jù)集進行變換:一般收集到的圖像大小尺寸,亮度等存在差異,變換的目的就是使得數(shù)據(jù)歸一化。另一方面,可 以通過變換進行數(shù)據(jù)增加data argument 關于pytorch中的變換transforms,請參考該系列之前的文章。 由于數(shù)據(jù)集中樣本采用字典dicts形式表示。 因此不能直接調(diào)用torchvision.transofrms中的方法。 本實驗只進行尺寸歸一化Resize, 數(shù)據(jù)類型變換ToTensor操作。 Resize """ # 變換Resize class Resize(object): def __init__(self, output_size: tuple): self.output_size = output_size def __call__(self, sample): # 圖像 image = sample['image'] # 使用skitimage.transform對圖像進行縮放 image_new = transform.resize(image, self.output_size) return {'image': image_new, 'label': sample['label']} # ToTensor ## 變換ToTensor class ToTensor(object): def __call__(self, sample): image = sample['image'] image_new = np.transpose(image, (2, 0, 1)) return {'image': torch.from_numpy(image_new), 'label': sample['label']} """ 第四步:對整個數(shù)據(jù)集應用變換 細節(jié):transformers.Compose()將不同的幾個組合起來。先進行Resize,再進行ToTensor """ # 對原始的訓練數(shù)據(jù)集進行變換 transformed_trainset = MyDataset(root_dir='./data/train', names_file='./data/train/train.txt', transform=transforms.Compose([ Resize((224, 224)), ToTensor()])) """ 第五步:使用DataLoader進行包裝 為何要使用DataLoader? 1、深度學習的輸入是mini_batch形式 2、樣本加載時候可能需要隨機打亂順序,shuffle操作 3、樣本加載需要采用多線程 pytorch提供的DataLoader封裝了上述的功能,這樣使用起來更方便。 """ # 使用DataLoader可以利用多線程,batch,shuffle等 # 使用DataLoader可以利用多線程,batch,shuffle等 trainset_dataloader = DataLoader(dataset=transformed_trainset, batch_size=4, shuffle=True, num_workers=4) # 可視化 def show_images_batch(sample_batched): images_batch, labels_batch = \ sample_batched['image'], sample_batched['label'] grid = make_grid(images_batch) plt.imshow(grid.numpy().transpose(1, 2, 0)) # sample_batch: Tensor , NxCxHxW plt.figure() for i_batch, sample_batch in enumerate(trainset_dataloader): show_images_batch(sample_batch) plt.axis('off') plt.ioff() plt.show() plt.show() """ 通過DataLoader包裝之后,樣本以min_batch形式輸出,而且進行了隨機打亂順序。 至此,自定義數(shù)據(jù)集的完整流程已經(jīng)實現(xiàn),test, val集只需要改路徑即可。 """
輸出類似:
補充:
更簡單的方法
上述繼承Dataset,重寫__len()__,__getitem()是通用的方法,過程相對繁瑣。對于簡單的分類數(shù)據(jù)集,pytorch中提供了更簡便的方式----ImageFolder。
如果每種類別的樣本放在各自的文件夾中,則可以直接使用ImageFolder。仍然以cat, dog二分類數(shù)據(jù)集為例:
文件結構:
Code
import torch from torch.utils.data import DataLoader from torchvision import transforms, datasets import matplotlib.pyplot as plt import numpy as np # https://pytorch.org/tutorials/beginner/data_loading_tutorial.html # data_transform = transforms.Compose([ # transforms.RandomResizedCrop(224), # transforms.RandomHorizontalFlip(), # transforms.ToTensor(), # transforms.Normalize(mean=[0.485, 0.456, 0.406], # std=[0.229, 0.224, 0.225]) # ]) data_transform = transforms.Compose([ transforms.Resize((224,224)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform) train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4) def show_batch_images(sample_batch): labels_batch = sample_batch[1] images_batch = sample_batch[0] for i in range(4): label_ = labels_batch[i].item() image_ = np.transpose(images_batch[i], (1, 2, 0)) ax = plt.subplot(1, 4, i + 1) ax.imshow(image_) ax.set_title(str(label_)) ax.axis('off') plt.pause(0.01) plt.figure() for i_batch, sample_batch in enumerate(train_dataloader): show_batch_images(sample_batch) plt.show()
由于 train 目錄下只有2個文件夾,分別為cat, dog, 因此ImageFolder安裝順序?qū)at使用標簽0, dog使用標簽1。(輸出類似:)
參考文章
https://www.cnblogs.com/picassooo/p/12846617.html
http://www.dbjr.com.cn/article/199360.htm
到此這篇關于pytorch自定義數(shù)據(jù)集的文章就介紹到這了,更多相關pytorch自定義數(shù)據(jù)集內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
基于Python-turtle庫繪制路飛的草帽骷髏旗、美國隊長的盾牌、高達的源碼
這篇文章主要介紹了基于Python-turtle庫繪制路飛的草帽骷髏旗、美國隊長的盾牌、高達的源碼,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-02-02