欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

在Pytorch中自定義dataset讀取數(shù)據(jù)的實(shí)現(xiàn)代碼

 更新時(shí)間:2023年12月20日 09:44:33   作者:Kelly_Ai_Bai  
這篇文章給大家介紹了如何在Pytorch中自定義dataset讀取數(shù)據(jù),文中給出了詳細(xì)的圖文介紹和代碼講解,對(duì)大家的學(xué)習(xí)或工作有一定的幫助,需要的朋友可以參考下

這里使用的是經(jīng)典的花分類數(shù)據(jù)集

下載地址:

https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

下載結(jié)束后進(jìn)行解壓,可以得到五種不同種類花的圖片,如上圖所示

主函數(shù) main

 
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
 
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)
 
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
 
    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])
 
    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)
 
    # plot_data_loader_image(train_loader)
 
    for step, data in enumerate(train_loader):
        images, labels = data

其中,

train_images_path, train_images_label, val_images_path, val_images_label  = read_split_data ( root )

傳入?yún)?shù) root(就是該數(shù)據(jù)集所在的路徑),沒有傳入?yún)?shù)val_rate就取其默認(rèn)值0.2( 即驗(yàn)證集占整個(gè)數(shù)據(jù)集的 20% ), 調(diào)用函數(shù) read_split_data

def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保證隨機(jī)結(jié)果可復(fù)現(xiàn)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
 
    # 遍歷文件夾,一個(gè)文件夾對(duì)應(yīng)一個(gè)類別
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保證順序一致
    flower_class.sort()
    # 生成類別名稱以及對(duì)應(yīng)的數(shù)字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
 
    train_images_path = []  # 存儲(chǔ)訓(xùn)練集的所有圖片路徑
    train_images_label = []  # 存儲(chǔ)訓(xùn)練集圖片對(duì)應(yīng)索引信息
    val_images_path = []  # 存儲(chǔ)驗(yàn)證集的所有圖片路徑
    val_images_label = []  # 存儲(chǔ)驗(yàn)證集圖片對(duì)應(yīng)索引信息
    every_class_num = []  # 存儲(chǔ)每個(gè)類別的樣本總數(shù)
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后綴類型
    # 遍歷每個(gè)文件夾下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍歷獲取supported支持的所有文件路徑
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 獲取該類別對(duì)應(yīng)的索引
        image_class = class_indices[cla]
        # 記錄該類別的樣本數(shù)量
        every_class_num.append(len(images))
        # 按比例隨機(jī)采樣驗(yàn)證樣本
        val_path = random.sample(images, k=int(len(images) * val_rate))
 
        for img_path in images:
            if img_path in val_path:  # 如果該路徑在采樣的驗(yàn)證集樣本中則存入驗(yàn)證集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否則存入訓(xùn)練集
                train_images_path.append(img_path)
                train_images_label.append(image_class)
 
    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
 
    plot_image = False
    if plot_image:
        # 繪制每種類別個(gè)數(shù)柱狀圖
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 將橫坐標(biāo)0,1,2,3,4替換為相應(yīng)的類別名稱
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱狀圖上添加數(shù)值標(biāo)簽
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 設(shè)置x坐標(biāo)
        plt.xlabel('image class')
        # 設(shè)置y坐標(biāo)
        plt.ylabel('number of images')
        # 設(shè)置柱狀圖的標(biāo)題
        plt.title('flower class distribution')
        plt.show()
 
    return train_images_path, train_images_label, val_images_path, val_images_label

運(yùn)行上述代碼, 得到 class_indices.json 文件,該文件存儲(chǔ)了類別名稱以及每個(gè)類別對(duì)應(yīng)的索引

設(shè)置變量 plot_image 為True,可以將每個(gè)類別的樣本數(shù)以柱狀圖的形式可視化出來

函數(shù) read_split_data 執(zhí)行結(jié)束后,返回四個(gè)列表  : train_images_path 、train_images_label 、val_images_path 和 val_images_label,分別表示訓(xùn)練集的圖像和標(biāo)簽路徑以及驗(yàn)證集的圖像和標(biāo)簽路徑,對(duì)數(shù)據(jù)集完成了訓(xùn)練集和驗(yàn)證集的劃分!

然后對(duì)訓(xùn)練集和驗(yàn)證集中的數(shù)據(jù)進(jìn)行數(shù)據(jù)預(yù)處理,比如裁剪、翻轉(zhuǎn)、歸一化等等操作

接下來,重點(diǎn)來了!

train_data_set = MyDataSet(images_path=train_images_path,
                           images_class=train_images_label,
                           transform=data_transform["train"])

傳入訓(xùn)練集圖像的路徑列表、標(biāo)簽列表以及數(shù)據(jù)預(yù)處理的方法,對(duì)類 MyDataSet 進(jìn)行初始化,得到類 MyDataSet 的實(shí)例對(duì)象 train_data_set

MyDataSet 是一個(gè)自定義的數(shù)據(jù)類,代碼如下:

from PIL import Image
import torch
from torch.utils.data import Dataset
 
 
class MyDataSet(Dataset):
    """自定義數(shù)據(jù)集"""
 
    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform
 
    def __len__(self):
        return len(self.images_path)
 
    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB為彩色圖片,L為灰度圖片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]
 
        if self.transform is not None:
            img = self.transform(img)
 
        return img, label
 
    @staticmethod
    def collate_fn(batch):
        # 官方實(shí)現(xiàn)的default_collate可以參考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))
 
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

該類繼承類Dataset,主要實(shí)現(xiàn)初始化函數(shù)__init__( )、計(jì)算數(shù)據(jù)集中樣本數(shù)量的函數(shù)__len__( )、根據(jù)索引返回相應(yīng)的圖片和標(biāo)簽的函數(shù)__getitem__( ) 以及 collate_fn( ) 函數(shù)

我想要重點(diǎn)闡述一下關(guān)于函數(shù) collate_fn( ) 函數(shù)的作用

collate_fn( ) 函數(shù)決定了如何將數(shù)據(jù)進(jìn)行打包處理

@staticmethod
def collate_fn(batch):
    # 官方實(shí)現(xiàn)的default_collate可以參考
    # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
    images, labels = tuple(zip(*batch))
    images = torch.stack(images, dim=0)
    labels = torch.as_tensor(labels)
    return images, labels

傳入函數(shù)的參數(shù) batch 是由 (images,labels) 組成的一個(gè)個(gè)的元組

如果在此處設(shè)置batch_size的值為8,那么這個(gè)函數(shù)就從數(shù)據(jù)集中獲取8張圖片以及這8張圖片所對(duì)應(yīng)的標(biāo)簽

可以設(shè)置斷點(diǎn)來看一下:

因?yàn)?batch_size 取 8,所以可以看到 batch 是一個(gè)長(zhǎng)度為8的列表,列表是由8個(gè)元組元素組成的,每個(gè)元組是由圖像和其所對(duì)應(yīng)的標(biāo)簽組成的

最后,通過 DataLoader 從實(shí)例化對(duì)象 train_data_set 中加載數(shù)據(jù),打包成一個(gè)一個(gè) batch 送入網(wǎng)絡(luò)中進(jìn)行訓(xùn)練

train_loader = torch.utils.data.DataLoader(train_data_set,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=nw,
                                           collate_fn=train_data_set.collate_fn)

這樣就可以得到用于加載訓(xùn)練數(shù)據(jù)的數(shù)據(jù)加載器 train_loader

可以將 數(shù)據(jù)加載器 train_loader 傳給函數(shù),通過調(diào)用函數(shù) plot_data_loader_image 后

plot_data_loader_image(train_loader)

這樣就能可視化出數(shù)據(jù)加載器  train_loader 中的內(nèi)容,如圖所示(此處需要將 num_workers 設(shè)置為0)

以上就是在Pytorch中自定義dataset讀取數(shù)據(jù)的實(shí)現(xiàn)代碼的詳細(xì)內(nèi)容,更多關(guān)于在Pytorch自定義dataset的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

最新評(píng)論