欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch如何自定義數(shù)據(jù)集

 更新時(shí)間:2024年01月30日 10:03:44   作者:涂作權(quán)  
這篇文章主要介紹了pytorch自定義數(shù)據(jù)集,在識(shí)別手寫數(shù)字的例子中,數(shù)據(jù)集是直接下載的,但如果我們自己收集了一些數(shù)據(jù),存在電腦文件夾里,我們?cè)撊绾伟堰@些數(shù)據(jù)變?yōu)榭梢栽赑yTorch框架下進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練的數(shù)據(jù)集呢,即如何自定義數(shù)據(jù)集呢,需要的朋友可以參考下

自定義數(shù)據(jù)

數(shù)據(jù)傳遞機(jī)制

我們首先回顧識(shí)別手寫數(shù)字的程序:

...
Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,)
dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True)
...
for epoch in range(EPOCH):
    for i, (image, label) in enumerate(dataloader):
        ...

從上面的程序,我們可以知道,在PyTorch中,數(shù)據(jù)傳遞機(jī)制是這樣的:
1.創(chuàng)建Dataset
2.Dataset傳遞給DataLoader
3.DataLoader迭代產(chǎn)生訓(xùn)練數(shù)據(jù)提供給模型。
總結(jié)這個(gè)數(shù)據(jù)傳遞機(jī)制就是,Dataset負(fù)責(zé)建立索引到樣本的映射,DataLoader負(fù)責(zé)以特定的方式從數(shù)據(jù)集中迭代的產(chǎn)生一個(gè)個(gè)batch的樣本集合。在enumerate過程中實(shí)際上是dataloader按照其參數(shù)sampler規(guī)定的策略調(diào)用了其dataset的getitem方法(下文中將介紹該方法)。

在上面的識(shí)別手寫數(shù)字的例子中,數(shù)據(jù)集是直接下載的,但如果我們自己收集了一些數(shù)據(jù),存在電腦文件夾里,我們?cè)撊绾伟堰@些數(shù)據(jù)變?yōu)榭梢栽赑yTorch框架下進(jìn)行神經(jīng)網(wǎng)絡(luò)訓(xùn)練的數(shù)據(jù)集呢,即如何自定義數(shù)據(jù)集呢?

PyTorch中Dataset,DataLoader,Sample的關(guān)系

PyTorch中Dataset,DataLoader,Sampler的關(guān)系可以用下圖概括:

在這里插入圖片描述

用文字表達(dá)就是:Dataloader中包含Sampler和Dataset,Sampler產(chǎn)生索引,Dataset拿著這個(gè)索引在數(shù)據(jù)集文件夾中找到對(duì)應(yīng)的樣本(每個(gè)樣本對(duì)應(yīng)一個(gè)索引,就像列表中每個(gè)元素對(duì)應(yīng)一個(gè)索引),并給該樣本配置上標(biāo)簽,最后返回(樣本+標(biāo)簽)給調(diào)用方。

在enumerate過程中,Dataloader按照其參數(shù)BatchSampler規(guī)定的策略調(diào)用其Dataset的getitem方法batchsize次,得到一個(gè)batch,該batch中既包含樣本,也包含相應(yīng)的標(biāo)簽。

自定義數(shù)據(jù)集

torch.utils.data.Dataset 是一個(gè)表示數(shù)據(jù)集的抽象類。任何自定義的數(shù)據(jù)集都需要繼承這個(gè)類并覆寫相關(guān)方法。所謂數(shù)據(jù)集,其實(shí)就是一個(gè)負(fù)責(zé)處理索引(index)到樣本(sample)映射的一個(gè)類(class)。Pytorch提供兩種數(shù)據(jù)集: Map式數(shù)據(jù)集 Iterable式數(shù)據(jù)集。這里我們只介紹前者。

一個(gè)Map式的數(shù)據(jù)集必須要重寫getitem(self, index)、 len(self) 兩個(gè)內(nèi)建方法,用來表示從索引到樣本的映射(Map)。這樣一個(gè)數(shù)據(jù)集dataset,舉個(gè)例子,當(dāng)使用dataset[idx]命令時(shí),可以在你的硬盤中讀取數(shù)據(jù)集中第idx張圖片以及其標(biāo)簽(如果有的話); len(dataset)則會(huì)返回這個(gè)數(shù)據(jù)集的容量。

自定義數(shù)據(jù)集類的范式大致是這樣的:

class CustomDataset(torch.utils.data.Dataset):#需要繼承torch.utils.data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #這里需要注意的是,第一步:read one data,是一個(gè)data point
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

關(guān)于Dataset API的官網(wǎng)介紹https://pytorch.org/docs/stable/data.html#dataset-types:

在這里插入圖片描述

Dataset類的使用:所有的類都應(yīng)該是此類的子類(也就是說應(yīng)該繼承該類)。所有的子類都要重寫(override) len(), getitem()。
Ø__len()__ : 此方法應(yīng)該提供數(shù)據(jù)集的大?。ㄈ萘浚?br />Ø__getitem()__ : 此方法應(yīng)該提供支持下標(biāo)索引方式訪問數(shù)據(jù)集。

DataLoader類的使用如下:

在這里插入圖片描述

在這里插入圖片描述

在這里插入圖片描述

根據(jù)這個(gè)方式,我們舉一個(gè)例子。

實(shí)例1

從kaggle官網(wǎng)下載dogsVScats的數(shù)據(jù)集(百度網(wǎng)盤下載鏈接見文末),該數(shù)據(jù)集包含test1文件夾和train文件夾,train文件夾中包含12500張貓的圖片和12500張狗的圖片,圖片的文件名中帶序號(hào):

在這里插入圖片描述

sampleSubmission.csv中的內(nèi)容如下:

在這里插入圖片描述

在這里插入圖片描述

我們把其中前10000張貓的圖片和10000張狗的圖片作為訓(xùn)練集,把后面的2500張貓的圖片和2500張狗的圖片作為驗(yàn)證集。貓的label記為0,狗的label記為1。因?yàn)閳D片大小不一,所以,我們需要對(duì)圖像進(jìn)行transform。

# -*- coding: UTF-8 -*-
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
"""
如果代碼執(zhí)行的時(shí)候出現(xiàn):
OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. 
That is dangerous, since it can degrade performance or cause incorrect results. The best thing 
to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. 
by avoiding static linking of the OpenMP runtime in any library. As an unsafe, unsupported, 
undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow 
the program to continue to execute, but that may cause crashes or silently produce incorrect results. 
For more information, please see http://www.intel.com/software/products/support/.
解決辦法是加上:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
"""
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
image_transform = transforms.Compose([
    transforms.Resize(256),                              # 把圖片resize為256*256
    transforms.RandomCrop(224),                          # 隨機(jī)裁剪224*224
    transforms.RandomHorizontalFlip(),                   # 水平翻轉(zhuǎn)
    transforms.ToTensor(),                               # 將圖像轉(zhuǎn)為Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 標(biāo)準(zhǔn)化
])
# 創(chuàng)建一個(gè)叫做DogVsCatDataset的Dataset,繼承自父類torch.utils.data.Dataset
class DogVsCatDataset(Dataset):
    def __init__(self, root_dir, train=True, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.img_path = os.listdir(self.root_dir)
        if train:
            # 圖片數(shù)據(jù)中有類似:dog.12499.jpg的圖片共12499張。
            # x.split('.')[1] 就是文件名dog.12473.jpg中的序號(hào)部分,也是圖片的編號(hào)
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))  # 劃分訓(xùn)練集和驗(yàn)證集
        else:
            # 序號(hào)大于10000的編號(hào)
            self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
        self.transform = transform
    def __len__(self):
        return len(self.img_path)
    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
        label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1  # label, 貓為0,狗為1
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array([label]))
        return image, label
# 來測(cè)試一下
if __name__ == '__main__':
    catanddog_dataset = DogVsCatDataset(root_dir='E:/BaiduNetdiskDownload/kaggle/train',
                                        train=False,
                                        transform=image_transform)
    # num_workers=4表示用4個(gè)線程讀取數(shù)據(jù)
    train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)
    # iter()函數(shù)把train_loader變?yōu)榈?,然后調(diào)用迭代器的next()方法
    image, label = iter(train_loader).next()
    sample = image[0].squeeze()
    sample = sample.permute((1, 2, 0)).numpy()
    sample *= [0.229, 0.224, 0.225]
    sample += [0.485, 0.456, 0.406]
    sample = np.clip(sample, 0, 1)
    plt.imshow(sample)
    plt.show()
    print('Label is: {}'.format(label[0].numpy()))

運(yùn)行結(jié)果:

在這里插入圖片描述

實(shí)例2

收集圖像樣本

以簡(jiǎn)單的貓狗二分類為例,可以在網(wǎng)上下載一些貓狗圖片。創(chuàng)建以下目錄:
ldata -----------------根目錄
ldata/test -----------------測(cè)試集
ldata/train -----------------訓(xùn)練集
ldata/val ------------------驗(yàn)證集

在這里插入圖片描述

在test/train/val之下在校分別創(chuàng)建2個(gè)文件夾,dog,cat

在這里插入圖片描述

cat,dog文件夾下分別存放2類圖像:

在這里插入圖片描述

之后寫一個(gè)簡(jiǎn)單的python腳本,生成txt文件,用于指明每個(gè)圖像和標(biāo)簽的對(duì)應(yīng)關(guān)系。
格式:
/cat/1.jpg 0
/dog/1.jpg 1

如圖:

在這里插入圖片描述

至此,樣本集的收集以及簡(jiǎn)單歸類完成。

實(shí)現(xiàn)

使用到python package

python package目錄
numpy矩陣操作,對(duì)圖像進(jìn)行轉(zhuǎn)置
skimage圖像處理,圖像I/O,圖像變換
matplotlib圖像的顯示,可視化
os一些文件查找操作
torchpytorch
torchvisionpytorch

代碼

# -*- coding: UTF-8 -*-
"""
本案例來自:http://www.dbjr.com.cn/article/199360.htm
"""
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
"""
第一步:
定義一個(gè)子類,繼承Dataset類,重寫__len()__,__getitem()__方法。
細(xì)節(jié):
1、數(shù)據(jù)集中一個(gè)一樣的表示:采用字典的形式sample = {'image': image, 'label': label}。
2、圖像的讀?。翰捎胹kimage.io進(jìn)行讀取,讀取之后的結(jié)果為numpy.ndarray形式。
3、圖像變換:transform參數(shù)
"""
class MyDataset(Dataset):
    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []
        if not os.path.isfile(self.names_file):
            print(self.names_file + 'does not exist!')
        file = open(self.names_file)
        for f in file:
            self.names_list.append(f)
            self.size += 1
    def __len__(self):
        return self.size
    def __getitem__(self, idx):
        image_path = self.root_dir + self.names_list[idx].split(' ')[0]
        if not os.path.isfile(image_path):
            print(image_path + 'does not exists!')
            return None
        image = io.imread(image_path)  # use skitimage
        label = int(self.names_list[idx].split(' ')[1])
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
"""
第二步
實(shí)例化一個(gè)對(duì)象,并讀取和顯示數(shù)據(jù)集
"""
train_dataset = MyDataset(root_dir='./data/train',
                          names_file='./data/train/train.txt',
                          transform=None)
plt.figure()
for (cnt, i) in enumerate(train_dataset):
    image = i['image']
    label = i['label']
    ax = plt.subplot(4, 4, cnt + 1)
    ax.axis('off')
    ax.imshow(image)
    ax.set_title('label {}'.format(label))
    plt.pause(0.001)
    if cnt == 15:
        break
"""
第三步(可選optional)
對(duì)數(shù)據(jù)集進(jìn)行變換:一般收集到的圖像大小尺寸,亮度等存在差異,變換的目的就是使得數(shù)據(jù)歸一化。另一方面,可
以通過變換進(jìn)行數(shù)據(jù)增加data argument
關(guān)于pytorch中的變換transforms,請(qǐng)參考該系列之前的文章。
由于數(shù)據(jù)集中樣本采用字典dicts形式表示。 因此不能直接調(diào)用torchvision.transofrms中的方法。
本實(shí)驗(yàn)只進(jìn)行尺寸歸一化Resize, 數(shù)據(jù)類型變換ToTensor操作。
Resize
"""
# 變換Resize
class Resize(object):
    def __init__(self, output_size: tuple):
        self.output_size = output_size
    def __call__(self, sample):
        # 圖像
        image = sample['image']
        # 使用skitimage.transform對(duì)圖像進(jìn)行縮放
        image_new = transform.resize(image, self.output_size)
        return {'image': image_new, 'label': sample['label']}
# ToTensor
## 變換ToTensor
class ToTensor(object):
    def __call__(self, sample):
        image = sample['image']
        image_new = np.transpose(image, (2, 0, 1))
        return {'image': torch.from_numpy(image_new), 'label': sample['label']}
"""
第四步:對(duì)整個(gè)數(shù)據(jù)集應(yīng)用變換
細(xì)節(jié):transformers.Compose()將不同的幾個(gè)組合起來。先進(jìn)行Resize,再進(jìn)行ToTensor
"""
# 對(duì)原始的訓(xùn)練數(shù)據(jù)集進(jìn)行變換
transformed_trainset = MyDataset(root_dir='./data/train',
                                 names_file='./data/train/train.txt',
                                 transform=transforms.Compose([
                                     Resize((224, 224)),
                                     ToTensor()]))
"""
第五步:使用DataLoader進(jìn)行包裝
為何要使用DataLoader?
1、深度學(xué)習(xí)的輸入是mini_batch形式
2、樣本加載時(shí)候可能需要隨機(jī)打亂順序,shuffle操作
3、樣本加載需要采用多線程
pytorch提供的DataLoader封裝了上述的功能,這樣使用起來更方便。
"""
# 使用DataLoader可以利用多線程,batch,shuffle等
# 使用DataLoader可以利用多線程,batch,shuffle等
trainset_dataloader = DataLoader(dataset=transformed_trainset,
                                 batch_size=4,
                                 shuffle=True,
                                 num_workers=4)
# 可視化
def show_images_batch(sample_batched):
    images_batch, labels_batch = \
        sample_batched['image'], sample_batched['label']
    grid = make_grid(images_batch)
    plt.imshow(grid.numpy().transpose(1, 2, 0))
# sample_batch: Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
    show_images_batch(sample_batch)
    plt.axis('off')
    plt.ioff()
    plt.show()
plt.show()
"""
通過DataLoader包裝之后,樣本以min_batch形式輸出,而且進(jìn)行了隨機(jī)打亂順序。
至此,自定義數(shù)據(jù)集的完整流程已經(jīng)實(shí)現(xiàn),test, val集只需要改路徑即可。
"""

輸出類似:

在這里插入圖片描述


在這里插入圖片描述
在這里插入圖片描述

在這里插入圖片描述

補(bǔ)充:
更簡(jiǎn)單的方法

上述繼承Dataset,重寫__len()__,__getitem()是通用的方法,過程相對(duì)繁瑣。對(duì)于簡(jiǎn)單的分類數(shù)據(jù)集,pytorch中提供了更簡(jiǎn)便的方式----ImageFolder。

如果每種類別的樣本放在各自的文件夾中,則可以直接使用ImageFolder。仍然以cat, dog二分類數(shù)據(jù)集為例:
文件結(jié)構(gòu):

在這里插入圖片描述

在這里插入圖片描述

在這里插入圖片描述

Code

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# data_transform = transforms.Compose([
#  transforms.RandomResizedCrop(224),
#  transforms.RandomHorizontalFlip(),
#  transforms.ToTensor(),
#  transforms.Normalize(mean=[0.485, 0.456, 0.406],
#       std=[0.229, 0.224, 0.225])
# ])
data_transform = transforms.Compose([
 transforms.Resize((224,224)),
 transforms.RandomHorizontalFlip(),
 transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform)
train_dataloader = DataLoader(dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4)
def show_batch_images(sample_batch):
 labels_batch = sample_batch[1]
 images_batch = sample_batch[0]
 for i in range(4):
  label_ = labels_batch[i].item()
  image_ = np.transpose(images_batch[i], (1, 2, 0))
  ax = plt.subplot(1, 4, i + 1)
  ax.imshow(image_)
  ax.set_title(str(label_))
  ax.axis('off')
  plt.pause(0.01)
plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
 show_batch_images(sample_batch)
 plt.show()

由于 train 目錄下只有2個(gè)文件夾,分別為cat, dog, 因此ImageFolder安裝順序?qū)at使用標(biāo)簽0, dog使用標(biāo)簽1。(輸出類似:)

在這里插入圖片描述

參考文章

https://www.cnblogs.com/picassooo/p/12846617.html

http://www.dbjr.com.cn/article/199360.htm

到此這篇關(guān)于pytorch自定義數(shù)據(jù)集的文章就介紹到這了,更多相關(guān)pytorch自定義數(shù)據(jù)集內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python中pycurl庫(kù)的用法實(shí)例

    python中pycurl庫(kù)的用法實(shí)例

    這篇文章主要介紹了python中pycurl庫(kù)的用法實(shí)例,可實(shí)現(xiàn)從指定網(wǎng)址讀取網(wǎng)頁(yè)的功能,需要的朋友可以參考下
    2014-09-09
  • python標(biāo)記語句塊使用方法總結(jié)

    python標(biāo)記語句塊使用方法總結(jié)

    在本篇文章里小編給大家整理了關(guān)于python標(biāo)記語句塊使用方法以及相關(guān)知識(shí)點(diǎn),需要的朋友們參考下。
    2019-08-08
  • Python Matplotlib繪制箱線圖的全過程

    Python Matplotlib繪制箱線圖的全過程

    又稱箱形圖(boxplot)或盒式圖,數(shù)據(jù)大小、占比、趨勢(shì)等等的呈現(xiàn)其包含一些統(tǒng)計(jì)學(xué)的均值、分位數(shù)、極值等等統(tǒng)計(jì)量,因此該圖信息量較大,下面這篇文章主要給大家介紹了關(guān)于Python Matplotlib繪制箱線圖的相關(guān)資料,需要的朋友可以參考下
    2021-09-09
  • 一篇文章帶你了解python異?;A(chǔ)

    一篇文章帶你了解python異?;A(chǔ)

    今天小編就為大家分享一篇關(guān)于Python中的異常介紹,小編覺得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧
    2021-08-08
  • 解決Python 遍歷字典時(shí)刪除元素報(bào)異常的問題

    解決Python 遍歷字典時(shí)刪除元素報(bào)異常的問題

    下面小編就為大家?guī)硪黄鉀QPython 遍歷字典時(shí)刪除元素報(bào)異常的問題。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2016-09-09
  • 使用Python獲取當(dāng)前工作目錄和執(zhí)行命令的位置

    使用Python獲取當(dāng)前工作目錄和執(zhí)行命令的位置

    這篇文章主要介紹了使用Python獲取當(dāng)前工作目錄和執(zhí)行命令的位置,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-03-03
  • 基于Python-turtle庫(kù)繪制路飛的草帽骷髏旗、美國(guó)隊(duì)長(zhǎng)的盾牌、高達(dá)的源碼

    基于Python-turtle庫(kù)繪制路飛的草帽骷髏旗、美國(guó)隊(duì)長(zhǎng)的盾牌、高達(dá)的源碼

    這篇文章主要介紹了基于Python-turtle庫(kù)繪制路飛的草帽骷髏旗、美國(guó)隊(duì)長(zhǎng)的盾牌、高達(dá)的源碼,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-02-02
  • Pytorch Tensor的索引與切片例子

    Pytorch Tensor的索引與切片例子

    今天小編就為大家分享一篇Pytorch Tensor的索引與切片例子,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2019-08-08
  • python中subplot大小的設(shè)置步驟

    python中subplot大小的設(shè)置步驟

    matploglib能夠繪制出精美的圖表,有時(shí)候我們希望把一組圖放在一起進(jìn)行比較,就需要用到matplotlib中提供的subplot了,這篇文章主要給大家介紹了關(guān)于python中subplot大小的設(shè)置方法,需要的朋友可以參考下
    2021-06-06
  • python爬蟲 urllib模塊反爬蟲機(jī)制UA詳解

    python爬蟲 urllib模塊反爬蟲機(jī)制UA詳解

    這篇文章主要介紹了python爬蟲 urllib模塊反爬蟲機(jī)制UA詳解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-08-08

最新評(píng)論