欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch加載數(shù)據(jù)集的方式總結(jié)及補(bǔ)充

 更新時(shí)間:2022年11月18日 08:57:14   作者:咕嚕咕嚕冰闊落  
Pytorch自定義數(shù)據(jù)集方法,應(yīng)該是用pytorch做算法的最基本的東西,下面這篇文章主要給大家介紹了關(guān)于Pytorch加載數(shù)據(jù)集的方式總結(jié)及補(bǔ)充,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下

前言

在用Pytorch加載數(shù)據(jù)集時(shí),看GitHub上的代碼經(jīng)常會(huì)用到ImageFolder、DataLoader等一系列方法,而這些方法又是來自于torchvision、torch.utils.data。除加載數(shù)據(jù)集外,還有torchvision中的transforms對數(shù)據(jù)集預(yù)處理…等等等等。這個(gè)data,那個(gè)dataset…這一系列下來,不加注意的話實(shí)在有點(diǎn)打腦殼??磩e人的代碼加載數(shù)據(jù)集挺簡單,但是自己用的時(shí)候,尤其是加載自己所制作的數(shù)據(jù)集的時(shí)候,就會(huì)茫然無措。別無他法,抱著硬啃的心態(tài),查閱了其他博文,通過代碼實(shí)驗(yàn),終于是理清楚了思路。

Pytorch加載數(shù)據(jù)集可以分兩種大的情況:一、自己重寫定義; 二、用Pytorch自帶的類。第二種里面又有多種不同的方法(datasets、 ImageFolder等),但這些方法都有相同的處理規(guī)律。我理解的,無論是哪種情況,加載數(shù)據(jù)集都需要構(gòu)造數(shù)據(jù)加載器數(shù)據(jù)裝載器(后者生成的是可迭代的數(shù)據(jù))。現(xiàn)將這兩種情況一一說明。

一、自己重寫定義(Dataset、DataLoader)

目前我們有自己制作的數(shù)據(jù)以及數(shù)據(jù)標(biāo)簽,但是有時(shí)候感覺不太適合直接用Pytorch自帶加載數(shù)據(jù)集的方法。我們可以自己來重寫定義一個(gè)類,這個(gè)類繼承于 torch.utils.data.Dataset,同時(shí)我們需要重寫這個(gè)類里面的兩個(gè)方法 _ getitem__ () 和__ len()__函數(shù)。

如下所示。這兩種方法如何構(gòu)造以及具體的細(xì)節(jié)可以查看其他的博客。len方法必須返回?cái)?shù)據(jù)的長度,getitem方法必須返回?cái)?shù)據(jù)以及標(biāo)簽。

import torch
import numpy as np

# 定義GetLoader類,繼承Dataset方法,并重寫__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
	# 初始化函數(shù),得到數(shù)據(jù)
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index是根據(jù)batchsize劃分?jǐn)?shù)據(jù)后得到的索引,最后將data和對應(yīng)的labels進(jìn)行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # 該函數(shù)返回?cái)?shù)據(jù)大小長度,目的是DataLoader方便劃分,如果不知道大小,DataLoader會(huì)一臉懵逼
    def __len__(self):
        return len(self.data)

# 隨機(jī)生成數(shù)據(jù),大小為10 * 20列
source_data = np.random.rand(10, 20)
# 隨機(jī)生成標(biāo)簽,大小為10 * 1列
source_label = np.random.randint(0,2,(10, 1))
# 通過GetLoader將數(shù)據(jù)進(jìn)行加載,返回Dataset對象,包含data和labels
torch_data = GetLoader(source_data, source_label)

通過上述的程序,我們構(gòu)造了一個(gè)數(shù)據(jù)加載器torch_data,但是還是不能直接傳入網(wǎng)絡(luò)中。接下來需要構(gòu)造數(shù)據(jù)裝載器,產(chǎn)生可迭代的數(shù)據(jù),再傳入網(wǎng)絡(luò)中。DataLoader類完成這個(gè)工作。

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)

參數(shù)解釋:

1.dataset     : 加載torch.utils.data.Dataset對象數(shù)據(jù)
2.batch_size  : 每個(gè)batch的大小,將我們的數(shù)據(jù)分批輸入到網(wǎng)絡(luò)中
3.shuffle     : 是否對數(shù)據(jù)進(jìn)行打亂
4.drop_last   : 是否對無法整除的最后一個(gè)datasize進(jìn)行丟棄
5.num_workers : 表示加載的時(shí)候子進(jìn)程數(shù)

結(jié)合我們自己定義的加載數(shù)據(jù)集類,可以如下使用。后面將data和label傳入我們定義的模型中。

...
torch_data = GetLoader(source_data, source_label)

from torch.utils.data import DataLoader
datas = DataLoader(torch_data, batch_size = 4, shuffle = True, drop_last = False, num_workers = 2)
for i, (data, label) in enumerate(datas):
	# i表示第幾個(gè)batch, data表示batch_size個(gè)原始的數(shù)據(jù),label代表batch_size個(gè)數(shù)據(jù)的標(biāo)簽
    print("第 {} 個(gè)Batch \n{}".format(i, data))

二、用Pytorch自帶的類(ImageFolder、datasets、DataLoader)

2.1 加載自己的數(shù)據(jù)集

2.1.1 ImageFolder介紹

和第一種情況不一樣,我們不需要在代碼上自己定義數(shù)據(jù)集類了,而是將數(shù)據(jù)集按照一定的格式擺放,調(diào)用ImageFolder類即可。這種是在調(diào)用Pytorch內(nèi)部的API,所以我們自己的數(shù)據(jù)集得需要按照API內(nèi)部所規(guī)定的存放格式。torchvision.datasets.ImageFolder 要求數(shù)據(jù)集按照如下方式組織。根目錄 root 下存儲(chǔ)的是類別文件夾(如cat,dog),每個(gè)類別文件夾下存儲(chǔ)相應(yīng)類別的圖像(如xxx.png)

A generic data loader where the images are arranged in this way:

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

torchvision.datasets.ImageFolder有以下參數(shù):

dataset=torchvision.datasets.ImageFolder(
                       root, transform=None, 
                       target_transform=None, 
                       loader=<function default_loader>, 
                       is_valid_file=None)

參數(shù)解釋:

1.root:根目錄,在root目錄下,應(yīng)該有不同類別的子文件夾;
    |--data(root)
        |--train
            |--cat
            |--dog
        |--valid
            |--cat
            |--dog        
2.transform:對圖片進(jìn)行預(yù)處理的操作,原始圖像作為一個(gè)輸入,返回的是transform變換后的圖片;
3.target_transform:對圖片類別進(jìn)行預(yù)處理的操作,輸入為 target,輸出對其的轉(zhuǎn)換。 如果不傳該參數(shù),即對target不做任何轉(zhuǎn)換,返回的順序索引 0,1, 2…
4.loader:表示數(shù)據(jù)集加載方式,通常默認(rèn)加載方式即可;
5.is_valid_file:獲取圖像文件的路徑并檢查該文件是否為有效文件的函數(shù)(用于檢查損壞文件)

作為torchvision.datasets.ImageFolder的返回,會(huì)有以下三種屬性:

(1)self.classes:用一個(gè) list 保存類別名稱

(2)self.class_to_idx:類別對應(yīng)的索引,與不做任何轉(zhuǎn)換返回的 target 對應(yīng)

(3)self.imgs:保存(img_path, class) tuple的list

以貓狗類別舉例,各屬性輸出如下所示:

print(dataset.classes)  #根據(jù)分的文件夾的名字來確定的類別
print(dataset.class_to_idx) #按順序?yàn)檫@些類別定義索引為0,1...
print(dataset.imgs) #返回從所有文件夾中得到的圖片的路徑以及其類別
'''
輸出:
['cat', 'dog']
{'cat': 0, 'dog': 1}
[('./data/train\\cat\\1.jpg', 0), 
 ('./data/train\\cat\\2.jpg', 0), 
 ('./data/train\\dog\\1.jpg', 1), 
 ('./data/train\\dog\\2.jpg', 1)]
'''

2.2.2 ImageFolder加載數(shù)據(jù)集完整例子

# 5. 將文件夾數(shù)據(jù)導(dǎo)入
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size = batch_size, shuffle=True,
                                           num_workers = 2)
# 6. 傳入網(wǎng)絡(luò)進(jìn)行訓(xùn)練
for epoch in range(epochs):
    train_bar = tqdm(train_loader, file = sys.stdout)
    for step, data in enumerate(train_bar):
    ...

和第一種情況自己重寫定義一樣,上述的代碼僅僅完成了數(shù)據(jù)加載器的定義。這樣是不能直接傳入網(wǎng)絡(luò)中進(jìn)行訓(xùn)練的,需要再構(gòu)造一個(gè)可迭代的數(shù)據(jù)裝載器。DataLoader類的使用方式上文中有詳細(xì)介紹。

# 5. 將文件夾數(shù)據(jù)導(dǎo)入
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size = batch_size, shuffle=True,
                                           num_workers = 2)
# 6. 傳入網(wǎng)絡(luò)進(jìn)行訓(xùn)練
for epoch in range(epochs):
    train_bar = tqdm(train_loader, file = sys.stdout)
    for step, data in enumerate(train_bar):
    ...

2.2 加載常見的數(shù)據(jù)集

有些數(shù)據(jù)集是公共的,比如常見的MNIST,CIFAR10,SVHN等等。這些數(shù)據(jù)集在Pytorch中可以通過代碼就可以下載、加載。如下代碼所示。用torchvision中的datasets類下載數(shù)據(jù)集,并還是結(jié)合DataLoader來構(gòu)建可直接傳入網(wǎng)絡(luò)的數(shù)據(jù)裝載器。

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def dataloader(dataset, input_size, batch_size, split='train'):
    transform = transforms.Compose([
        					transforms.Resize((input_size, input_size)), 
       					    transforms.ToTensor(), 
        					transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    if dataset == 'mnist':
        data_loader = DataLoader(
            datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'fashion-mnist':
        data_loader = DataLoader(
            datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'cifar10':
        data_loader = DataLoader(
            datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'svhn':
        data_loader = DataLoader(
            datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'stl10':
        data_loader = DataLoader(
            datasets.STL10('data/stl10', split=split, download=True, transform=transform),
            batch_size=batch_size, shuffle=True)
    elif dataset == 'lsun-bed':
        data_loader = DataLoader(
            datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
            batch_size=batch_size, shuffle=True)

    return data_loader

三、總結(jié)

至于覺得加載數(shù)據(jù)集比較難的很大的原因,我感覺是Dataset、datasets、DataLoader以及torch.utils.data、torchvision種類太多,有點(diǎn)混亂。上面的梳理,我的理解是無論是哪種方式,終端還是需要DataLoader整合。作為加載數(shù)據(jù)集的前端,用自己定義的、用ImageFolder的、還是用datasets加載常用數(shù)據(jù)集,都是在構(gòu)造數(shù)據(jù)加載器,而且構(gòu)造起來也并不復(fù)雜。梳理清晰后,相信對Pytorch加載數(shù)據(jù)集有了更進(jìn)一步的理解。

四、transforms變換講解

torchvision.transforms是Pytorch中的圖像預(yù)處理包。一般定義在加載數(shù)據(jù)集之前,用transforms中的Compose類把多個(gè)步驟整合到一起,而這些步驟是transforms中的函數(shù)。

transforms中的函數(shù)有這些:

函數(shù)含義
transforms.Resize把給定的圖片resize到given size
transforms.Normalize用均值和標(biāo)準(zhǔn)差歸一化張量圖像
transforms.Totensor可以將PIL和numpy格式的數(shù)據(jù)從[0,255]范圍轉(zhuǎn)換到[0,1] ; <br /另外原始數(shù)據(jù)的shape是(H x W x C),通過transforms.ToTensor()后shape會(huì)變?yōu)椋– x H x W)
transforms.RandomGrayscale將圖像以一定的概率轉(zhuǎn)換為灰度圖像
transforms.ColorJitter隨機(jī)改變圖像的亮度對比度和飽和度
transforms.Centercrop在圖片的中間區(qū)域進(jìn)行裁剪
transforms.RandomCrop在一個(gè)隨機(jī)的位置進(jìn)行裁剪
transforms.FiceCrop把圖像裁剪為四個(gè)角和一個(gè)中心
transforms.RandomResizedCrop將PIL圖像裁剪成任意大小和縱橫比
transforms.ToPILImageconvert a tensor to PIL image
transforms.RandomHorizontalFlip以0.5的概率水平翻轉(zhuǎn)給定的PIL圖像
transforms.RandomVerticalFlip以0.5的概率豎直翻轉(zhuǎn)給定的PIL圖像
transforms.Grayscale將圖像轉(zhuǎn)換為灰度圖像

不同函數(shù)對應(yīng)有不同的屬性,用transforms.Compose將不同的操作整合在一起,如下所示。

transforms.Compose([transforms.RandomResizedCrop(224),
 		    	   transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

五、DataLoader的補(bǔ)充

數(shù)據(jù)加載器,結(jié)合了數(shù)據(jù)集和取樣器,并且可以提供多個(gè)線程處理數(shù)據(jù)集。

在訓(xùn)練模型時(shí)使用到此函數(shù),用來把訓(xùn)練數(shù)據(jù)分成多個(gè)小組,此函數(shù)每次拋出一組數(shù)據(jù)。直至把所有的數(shù)據(jù)都拋出。就是做一個(gè)數(shù)據(jù)的初始化。

用下面的例子測試:

"""
    批訓(xùn)練,把數(shù)據(jù)變成一小批一小批數(shù)據(jù)進(jìn)行訓(xùn)練。
    DataLoader就是用來包裝所使用的數(shù)據(jù),每次拋出一批數(shù)據(jù)
"""
import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把數(shù)據(jù)放在數(shù)據(jù)庫中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    # 從數(shù)據(jù)庫中每次抽出batch size個(gè)樣本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

if __name__ == '__main__':
    show_batch()

結(jié)果如下所示。仔細(xì)觀察:

每一個(gè)step,batch_x是不會(huì)重合的,batch_y里面的值也是不會(huì)重合的(第一個(gè)step中,batch_x:tensor([ 3., 10., 6., 2., 8.]);第二個(gè)step中batch_x:tensor([5., 9., 7., 4., 1.])),說明DataLoader將數(shù)據(jù)打亂后,每次選用其中的Batch_size個(gè)數(shù)據(jù)且不會(huì)重復(fù);

其二,batch_x 和 batch_y對應(yīng)的索引之和相等,這說明DataLoader對圖像和標(biāo)簽打亂順序時(shí),同時(shí)按照某一規(guī)律打亂,并不會(huì)造成標(biāo)簽和圖像出現(xiàn)不對應(yīng)的情況。

其三,在不同的epoch之間,每次數(shù)據(jù)也是不同的,說明DataLoader每次被調(diào)用時(shí),都會(huì)重新打亂一次。

steop:0, batch_x:tensor([ 3., 10.,  6.,  2.,  8.]), batch_y:tensor([8., 1., 5., 9., 3.])
steop:1, batch_x:tensor([5., 9., 7., 4., 1.]), batch_y:tensor([ 6.,  2.,  4.,  7., 10.])
steop:0, batch_x:tensor([8., 3., 1., 2., 9.]), batch_y:tensor([ 3.,  8., 10.,  9.,  2.])
steop:1, batch_x:tensor([10., 5.,  4.,  7.,  6.]), batch_y:tensor([1., 6., 7., 4., 5.])
steop:0, batch_x:tensor([5., 8., 4., 3., 7.]), batch_y:tensor([6., 3., 7., 8., 4.])
steop:1, batch_x:tensor([ 2., 10.,  6.,  9.,  1.]), batch_y:tensor([ 9.,  1.,  5.,  2., 10.])

總結(jié)

到此這篇關(guān)于Pytorch加載數(shù)據(jù)集的方式總結(jié)及補(bǔ)充的文章就介紹到這了,更多相關(guān)Pytorch加載數(shù)據(jù)集內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python網(wǎng)絡(luò)編程之?dāng)?shù)據(jù)傳輸U(kuò)DP實(shí)例分析

    python網(wǎng)絡(luò)編程之?dāng)?shù)據(jù)傳輸U(kuò)DP實(shí)例分析

    這篇文章主要介紹了python網(wǎng)絡(luò)編程之?dāng)?shù)據(jù)傳輸U(kuò)DP實(shí)現(xiàn)方法,實(shí)例分析了Python基于UDP協(xié)議的數(shù)據(jù)傳輸實(shí)現(xiàn)方法,需要的朋友可以參考下
    2015-05-05
  • python中reader的next用法

    python中reader的next用法

    這篇文章主要介紹了python中reader的next用法,分別介紹了python3中的用法和python2中的用法,具體實(shí)例代碼大家參考下本文
    2018-07-07
  • Python破解BiliBili滑塊驗(yàn)證碼的思路詳解(完美避開人機(jī)識別)

    Python破解BiliBili滑塊驗(yàn)證碼的思路詳解(完美避開人機(jī)識別)

    這篇文章主要介紹了Python破解BiliBili滑塊驗(yàn)證碼的思路,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-02-02
  • python之singledispatch單分派問題

    python之singledispatch單分派問題

    這篇文章主要介紹了python之singledispatch單分派問題,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-08-08
  • Python的語言類型(詳解)

    Python的語言類型(詳解)

    下面小編就為大家?guī)硪黄狿ython的語言類型(詳解)。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2017-06-06
  • Python中的遠(yuǎn)程調(diào)試與性能優(yōu)化技巧分享

    Python中的遠(yuǎn)程調(diào)試與性能優(yōu)化技巧分享

    Python 是一種簡單易學(xué)、功能強(qiáng)大的編程語言,廣泛應(yīng)用于各種領(lǐng)域,包括網(wǎng)絡(luò)編程、數(shù)據(jù)分析、人工智能等,在開發(fā)過程中,我們經(jīng)常會(huì)遇到需要遠(yuǎn)程調(diào)試和性能優(yōu)化的情況,本文將介紹如何利用遠(yuǎn)程調(diào)試工具和性能優(yōu)化技巧來提高 Python 應(yīng)用程序的效率和性能
    2024-05-05
  • python中g(shù)etopt()函數(shù)用法詳解

    python中g(shù)etopt()函數(shù)用法詳解

    這篇文章主要介紹了python中g(shù)etopt()函數(shù)用法,通過getopt模塊中的getopt(?)方法,我們可以獲取和解析命令行傳入的參數(shù),需要的朋友可以參考下
    2022-12-12
  • python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲的示例代碼(可以自定義子彈速度)

    python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲的示例代碼(可以自定義子彈速度)

    這篇文章主要介紹了python+pygame實(shí)現(xiàn)坦克大戰(zhàn)小游戲---可以自定義子彈速度,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-08-08
  • python模塊之paramiko實(shí)例代碼

    python模塊之paramiko實(shí)例代碼

    這篇文章主要介紹了python模塊之paramiko,分享了相關(guān)代碼示例,小編覺得還是挺不錯(cuò)的,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-01-01
  • Python?Logistic邏輯回歸算法使用詳解

    Python?Logistic邏輯回歸算法使用詳解

    這篇文章主要介紹了Python?Logistic邏輯回歸算法使用的方法和原理,Logistic雖然不是十大經(jīng)典算法之一,但卻是數(shù)據(jù)挖掘中常用的有力算法,所以這里也專門進(jìn)行了學(xué)習(xí),需要的朋友可以參考下
    2021-06-06

最新評論