欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch學(xué)習(xí)教程之自定義數(shù)據(jù)集

 更新時(shí)間:2020年11月10日 11:52:40   作者:俠之大者_(dá)7d3f  
這篇文章主要給大家介紹了關(guān)于pytorch學(xué)習(xí)教程之自定義數(shù)據(jù)集的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧

自定義數(shù)據(jù)集

在訓(xùn)練深度學(xué)習(xí)模型之前,樣本集的制作非常重要。在pytorch中,提供了一些接口和類,方便我們定義自己的數(shù)據(jù)集合,下面完整的試驗(yàn)自定義樣本集的整個(gè)流程。

開發(fā)環(huán)境

  • Ubuntu 18.04
  • pytorch 1.0
  • pycharm

實(shí)驗(yàn)?zāi)康?/strong>

  1. 掌握pytorch中數(shù)據(jù)集相關(guān)的API接口和類
  2. 熟悉數(shù)據(jù)集制作的整個(gè)流程

實(shí)驗(yàn)過程

1.收集圖像樣本

以簡(jiǎn)單的貓狗二分類為例,可以在網(wǎng)上下載一些貓狗圖片。創(chuàng)建以下目錄:

  • data-------------根目錄
  • data/test-------測(cè)試集
  • data/train------訓(xùn)練集
  • data/val--------驗(yàn)證集

在test/train/val之下在校分別創(chuàng)建2個(gè)文件夾,dog, cat

cat, dog文件夾下分別存放2類圖像:

標(biāo)簽

種類 標(biāo)簽
cat 0
dog 1

之后寫一個(gè)簡(jiǎn)單的python腳本,生成txt文件,用于指明每個(gè)圖像和標(biāo)簽的對(duì)應(yīng)關(guān)系。

格式: /cat/1.jpg 0 \n dog/1.jpg 1 \n .....

如圖:

至此,樣本集的收集以及簡(jiǎn)單歸類完成,下面將開始采用pytorch的數(shù)據(jù)集相關(guān)API和類。

2. 使用pytorch相關(guān)類,API對(duì)數(shù)據(jù)集進(jìn)行封裝

2.1 pytorch中數(shù)據(jù)集相關(guān)的類,接口

pytorch中數(shù)據(jù)集相關(guān)的類位于torch.utils.data package中。

https://pytorch.org/docs/stable/data.html

本次實(shí)驗(yàn),主要使用以下類:

torch.utils.data.Dataset
torch.utils.data.DataLoader

Dataset類的使用: 所有的類都應(yīng)該是此類的子類(也就是說應(yīng)該繼承該類)。 所有的子類都要重寫(override) __len()__, __getitem()__ 這兩個(gè)方法。

方法 作用
__len()__ 此方法應(yīng)該提供數(shù)據(jù)集的大小(容量)
__getitem()__ 此方法應(yīng)該提供支持下標(biāo)索方式引訪問數(shù)據(jù)集

這里和Java抽象類很相似,在抽象類abstract class中,一般會(huì)定義一些抽象方法abstract method,抽象方法:只有方法名沒有方法的具體實(shí)現(xiàn)。如果一個(gè)子類繼承于該抽象類,要重寫(overrode)父類的抽象方法。

DataLoader類的使用:

2.2 實(shí)現(xiàn)

使用到的python package

python package 目的
numpy 矩陣操作,對(duì)圖像進(jìn)行轉(zhuǎn)置
skimage 圖像處理,圖像I/O,圖像變換
matplotlib 圖像的顯示,可視化
os 一些文件查找操作
torch pytorch
torvision pytorch

源碼

導(dǎo)入python包

import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid

第一步:

定義一個(gè)子類,繼承Dataset類, 重寫 __len()__, __getitem()__ 方法。

細(xì)節(jié):

1.數(shù)據(jù)集中一個(gè)一樣的表示:采用字典的形式sample = {'image': image, 'label': label}。

2.圖像的讀取:采用skimage.io進(jìn)行讀取,讀取之后的結(jié)果為numpy.ndarray形式。

3.圖像變換:transform參數(shù)

# step1: 定義MyDataset類, 繼承Dataset, 重寫抽象方法:__len()__, __getitem()__
class MyDataset(Dataset):

 def __init__(self, root_dir, names_file, transform=None):
 self.root_dir = root_dir
 self.names_file = names_file
 self.transform = transform
 self.size = 0
 self.names_list = []

 if not os.path.isfile(self.names_file):
  print(self.names_file + 'does not exist!')
 file = open(self.names_file)
 for f in file:
  self.names_list.append(f)
  self.size += 1

 def __len__(self):
 return self.size

 def __getitem__(self, idx):
 image_path = self.root_dir + self.names_list[idx].split(' ')[0]
 if not os.path.isfile(image_path):
  print(image_path + 'does not exist!')
  return None
 image = io.imread(image_path) # use skitimage
 label = int(self.names_list[idx].split(' ')[1])

 sample = {'image': image, 'label': label}
 if self.transform:
  sample = self.transform(sample)

 return sample

第二步

實(shí)例化一個(gè)對(duì)象,并讀取和顯示數(shù)據(jù)集

train_dataset = MyDataset(root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=None)

plt.figure()
for (cnt,i) in enumerate(train_dataset):
 image = i['image']
 label = i['label']

 ax = plt.subplot(4, 4, cnt+1)
 ax.axis('off')
 ax.imshow(image)
 ax.set_title('label {}'.format(label))
 plt.pause(0.001)

 if cnt == 15:
 break

只顯示了部分?jǐn)?shù)據(jù),前部分全是cat

第三步(可選 optional)

對(duì)數(shù)據(jù)集進(jìn)行變換:一般收集到的圖像大小尺寸,亮度等存在差異,變換的目的就是使得數(shù)據(jù)歸一化。另一方面,可以通過變換進(jìn)行數(shù)據(jù)增加data argument

關(guān)于pytorch中的變換transforms,請(qǐng)參考該系列之前的文章

由于數(shù)據(jù)集中樣本采用字典dicts形式表示。 因此不能直接調(diào)用torchvision.transofrms中的方法。

本實(shí)驗(yàn)只進(jìn)行尺寸歸一化Resize, 數(shù)據(jù)類型變換ToTensor操作。

Resize

# # 變換Resize
class Resize(object):

 def __init__(self, output_size: tuple):
 self.output_size = output_size

 def __call__(self, sample):
 # 圖像
 image = sample['image']
 # 使用skitimage.transform對(duì)圖像進(jìn)行縮放
 image_new = transform.resize(image, self.output_size)
 return {'image': image_new, 'label': sample['label']}

ToTensor

# # 變換ToTensor
class ToTensor(object):

 def __call__(self, sample):
 image = sample['image']
 image_new = np.transpose(image, (2, 0, 1))
 return {'image': torch.from_numpy(image_new),
  'label': sample['label']}

第四步: 對(duì)整個(gè)數(shù)據(jù)集應(yīng)用變換

細(xì)節(jié): transformers.Compose() 將不同的幾個(gè)組合起來。先進(jìn)行Resize, 再進(jìn)行ToTensor

# 對(duì)原始的訓(xùn)練數(shù)據(jù)集進(jìn)行變換
transformed_trainset = MyDataset(root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=transforms.Compose(
    [Resize((224,224)),
    ToTensor()]
    ))

第五步: 使用DataLoader進(jìn)行包裝

為何要使用DataLoader?

① 深度學(xué)習(xí)的輸入是mini_batch形式

② 樣本加載時(shí)候可能需要隨機(jī)打亂順序,shuffle操作

③ 樣本加載需要采用多線程

pytorch提供的DataLoader封裝了上述的功能,這樣使用起來更方便。

# 使用DataLoader可以利用多線程,batch,shuffle等
trainset_dataloader = DataLoader(dataset=transformed_trainset,
     batch_size=4,
     shuffle=True,
     num_workers=4)

可視化:

def show_images_batch(sample_batched):
 images_batch, labels_batch = \
 sample_batched['image'], sample_batched['label']
 grid = make_grid(images_batch)
 plt.imshow(grid.numpy().transpose(1, 2, 0))


# sample_batch: Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
 show_images_batch(sample_batch)
 plt.axis('off')
 plt.ioff()
 plt.show()


plt.show()

通過DataLoader包裝之后,樣本以min_batch形式輸出,而且進(jìn)行了隨機(jī)打亂順序。

至此,自定義數(shù)據(jù)集的完整流程已實(shí)現(xiàn),test, val集只需要改路徑即可。

補(bǔ)充

更簡(jiǎn)單的方法

上述繼承Dataset, 重寫 __len()__, __getitem() 是通用的方法,過程相對(duì)繁瑣。對(duì)于簡(jiǎn)單的分類數(shù)據(jù)集,pytorch中提供了更簡(jiǎn)便的方式——ImageFolder。

如果每種類別的樣本放在各自的文件夾中,則可以直接使用ImageFolder。

仍然以cat, dog 二分類數(shù)據(jù)集為例:

文件結(jié)構(gòu):



Code

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np


# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

# data_transform = transforms.Compose([
#  transforms.RandomResizedCrop(224),
#  transforms.RandomHorizontalFlip(),
#  transforms.ToTensor(),
#  transforms.Normalize(mean=[0.485, 0.456, 0.406],
#       std=[0.229, 0.224, 0.225])
# ])

data_transform = transforms.Compose([
 transforms.Resize((224,224)),
 transforms.RandomHorizontalFlip(),
 transforms.ToTensor(),

])

train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform)
train_dataloader = DataLoader(dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4)


def show_batch_images(sample_batch):
 labels_batch = sample_batch[1]
 images_batch = sample_batch[0]

 for i in range(4):
  label_ = labels_batch[i].item()
  image_ = np.transpose(images_batch[i], (1, 2, 0))
  ax = plt.subplot(1, 4, i + 1)
  ax.imshow(image_)
  ax.set_title(str(label_))
  ax.axis('off')
  plt.pause(0.01)


plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
 show_batch_images(sample_batch)

 plt.show()

由于 train 目錄下只有2個(gè)文件夾,分別為cat, dog, 因此ImageFolder安裝順序?qū)at使用標(biāo)簽0, dog使用標(biāo)簽1。

End

參考:

https://pytorch.org/docs/stable/data.html

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

到此這篇關(guān)于pytorch學(xué)習(xí)教程之自定義數(shù)據(jù)集的文章就介紹到這了,更多相關(guān)pytorch自定義數(shù)據(jù)集內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Python中文件路徑常用操作總結(jié)

    Python中文件路徑常用操作總結(jié)

    這篇文章主要為大家詳細(xì)介紹了Python中文件路徑常用操作的相關(guān)知識(shí),文中的示例代碼講解詳細(xì),具有一定的借鑒價(jià)值,感興趣的小伙伴可以學(xué)習(xí)一下
    2023-11-11
  • numpy中的nan和inf,及其批量判別、替換方式

    numpy中的nan和inf,及其批量判別、替換方式

    在Numpy中,NaN表示非數(shù)值,Inf表示無窮大,NaN與任何值計(jì)算都是NaN,Inf與0相乘是NaN,其余情況下與Inf運(yùn)算仍為Inf,可以使用np.isnan(), np.isinf(), np.isneginf(), np.isposinf(), np.isfinite()等函數(shù)進(jìn)行批量判別,返回布爾值數(shù)組
    2024-09-09
  • TensorFlow基于MNIST數(shù)據(jù)集實(shí)現(xiàn)車牌識(shí)別(初步演示版)

    TensorFlow基于MNIST數(shù)據(jù)集實(shí)現(xiàn)車牌識(shí)別(初步演示版)

    這篇文章主要介紹了TensorFlow基于MNIST數(shù)據(jù)集實(shí)現(xiàn)車牌識(shí)別(初步演示版),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-08-08
  • python日期相關(guān)操作實(shí)例小結(jié)

    python日期相關(guān)操作實(shí)例小結(jié)

    這篇文章主要介紹了python日期相關(guān)操作,結(jié)合實(shí)例形式總結(jié)分析了Python針對(duì)日期時(shí)間的獲取、轉(zhuǎn)換、運(yùn)算等相關(guān)操作技巧,需要的朋友可以參考下
    2019-06-06
  • Python遍歷zip文件輸出名稱時(shí)出現(xiàn)亂碼問題的解決方法

    Python遍歷zip文件輸出名稱時(shí)出現(xiàn)亂碼問題的解決方法

    這篇文章主要介紹了Python遍歷zip文件輸出名稱時(shí)出現(xiàn)亂碼問題的解決方法,實(shí)例分析了Python亂碼的出現(xiàn)的原因與相應(yīng)的解決方法,需要的朋友可以參考下
    2015-04-04
  • Python?Flask?上傳文件測(cè)試示例

    Python?Flask?上傳文件測(cè)試示例

    這篇文章主要為大家介紹了Python?Flask?上傳文件測(cè)試的方法示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-07-07
  • Python常用爬蟲代碼總結(jié)方便查詢

    Python常用爬蟲代碼總結(jié)方便查詢

    今天小編就為大家分享一篇關(guān)于Python常用爬蟲代碼總結(jié)方便查詢,小編覺得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來看看吧
    2019-02-02
  • python中實(shí)現(xiàn)json數(shù)據(jù)和類對(duì)象相互轉(zhuǎn)化的四種方式

    python中實(shí)現(xiàn)json數(shù)據(jù)和類對(duì)象相互轉(zhuǎn)化的四種方式

    在日常的軟件測(cè)試過程中,測(cè)試數(shù)據(jù)的構(gòu)造是一個(gè)占比非常大的活動(dòng),對(duì)于測(cè)試數(shù)據(jù)的構(gòu)造,分為結(jié)構(gòu)化的數(shù)據(jù)構(gòu)造方式和非結(jié)構(gòu)化的數(shù)據(jù)構(gòu)造方式,此篇文章,會(huì)通過4種方式來展示json數(shù)據(jù)與python的類對(duì)象相互轉(zhuǎn)化,需要的朋友可以參考下
    2024-07-07
  • python判斷變量是否是None的三種寫法總結(jié)

    python判斷變量是否是None的三種寫法總結(jié)

    代碼中經(jīng)常會(huì)有變量是否為None的判斷,這篇文章給大家總結(jié)了三種判斷變量是否是none的寫法,文中通過代碼示例介紹的非常詳細(xì),需要的朋友可以參考下
    2023-12-12
  • Python爬蟲獲取國(guó)外大橋排行榜數(shù)據(jù)清單

    Python爬蟲獲取國(guó)外大橋排行榜數(shù)據(jù)清單

    這篇文章主要介紹了Python爬蟲獲取國(guó)外大橋排行榜數(shù)據(jù)清單,文章通過PyQuery?解析框架展開全文詳細(xì)內(nèi)容,需要的小伙伴可以參考一下
    2022-05-05

最新評(píng)論