在Pytorch中自定義dataset讀取數(shù)據(jù)的實(shí)現(xiàn)代碼
這里使用的是經(jīng)典的花分類數(shù)據(jù)集
下載地址:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
下載結(jié)束后進(jìn)行解壓,可以得到五種不同種類花的圖片,如上圖所示
主函數(shù) main
def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("using {} device.".format(device)) train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root) data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), "val": transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} train_data_set = MyDataSet(images_path=train_images_path, images_class=train_images_label, transform=data_transform["train"]) batch_size = 8 nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataloader workers'.format(nw)) train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True, num_workers=nw, collate_fn=train_data_set.collate_fn) # plot_data_loader_image(train_loader) for step, data in enumerate(train_loader): images, labels = data
其中,
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data ( root )
傳入?yún)?shù) root(就是該數(shù)據(jù)集所在的路徑),沒有傳入?yún)?shù)val_rate就取其默認(rèn)值0.2( 即驗(yàn)證集占整個(gè)數(shù)據(jù)集的 20% ), 調(diào)用函數(shù) read_split_data
def read_split_data(root: str, val_rate: float = 0.2): random.seed(0) # 保證隨機(jī)結(jié)果可復(fù)現(xiàn) assert os.path.exists(root), "dataset root: {} does not exist.".format(root) # 遍歷文件夾,一個(gè)文件夾對(duì)應(yīng)一個(gè)類別 flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] # 排序,保證順序一致 flower_class.sort() # 生成類別名稱以及對(duì)應(yīng)的數(shù)字索引 class_indices = dict((k, v) for v, k in enumerate(flower_class)) json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) train_images_path = [] # 存儲(chǔ)訓(xùn)練集的所有圖片路徑 train_images_label = [] # 存儲(chǔ)訓(xùn)練集圖片對(duì)應(yīng)索引信息 val_images_path = [] # 存儲(chǔ)驗(yàn)證集的所有圖片路徑 val_images_label = [] # 存儲(chǔ)驗(yàn)證集圖片對(duì)應(yīng)索引信息 every_class_num = [] # 存儲(chǔ)每個(gè)類別的樣本總數(shù) supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后綴類型 # 遍歷每個(gè)文件夾下的文件 for cla in flower_class: cla_path = os.path.join(root, cla) # 遍歷獲取supported支持的所有文件路徑 images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported] # 獲取該類別對(duì)應(yīng)的索引 image_class = class_indices[cla] # 記錄該類別的樣本數(shù)量 every_class_num.append(len(images)) # 按比例隨機(jī)采樣驗(yàn)證樣本 val_path = random.sample(images, k=int(len(images) * val_rate)) for img_path in images: if img_path in val_path: # 如果該路徑在采樣的驗(yàn)證集樣本中則存入驗(yàn)證集 val_images_path.append(img_path) val_images_label.append(image_class) else: # 否則存入訓(xùn)練集 train_images_path.append(img_path) train_images_label.append(image_class) print("{} images were found in the dataset.".format(sum(every_class_num))) print("{} images for training.".format(len(train_images_path))) print("{} images for validation.".format(len(val_images_path))) plot_image = False if plot_image: # 繪制每種類別個(gè)數(shù)柱狀圖 plt.bar(range(len(flower_class)), every_class_num, align='center') # 將橫坐標(biāo)0,1,2,3,4替換為相應(yīng)的類別名稱 plt.xticks(range(len(flower_class)), flower_class) # 在柱狀圖上添加數(shù)值標(biāo)簽 for i, v in enumerate(every_class_num): plt.text(x=i, y=v + 5, s=str(v), ha='center') # 設(shè)置x坐標(biāo) plt.xlabel('image class') # 設(shè)置y坐標(biāo) plt.ylabel('number of images') # 設(shè)置柱狀圖的標(biāo)題 plt.title('flower class distribution') plt.show() return train_images_path, train_images_label, val_images_path, val_images_label
運(yùn)行上述代碼, 得到 class_indices.json 文件,該文件存儲(chǔ)了類別名稱以及每個(gè)類別對(duì)應(yīng)的索引
設(shè)置變量 plot_image 為True,可以將每個(gè)類別的樣本數(shù)以柱狀圖的形式可視化出來
函數(shù) read_split_data 執(zhí)行結(jié)束后,返回四個(gè)列表 : train_images_path 、train_images_label 、val_images_path 和 val_images_label,分別表示訓(xùn)練集的圖像和標(biāo)簽路徑以及驗(yàn)證集的圖像和標(biāo)簽路徑,對(duì)數(shù)據(jù)集完成了訓(xùn)練集和驗(yàn)證集的劃分!
然后對(duì)訓(xùn)練集和驗(yàn)證集中的數(shù)據(jù)進(jìn)行數(shù)據(jù)預(yù)處理,比如裁剪、翻轉(zhuǎn)、歸一化等等操作
接下來,重點(diǎn)來了!
train_data_set = MyDataSet(images_path=train_images_path, images_class=train_images_label, transform=data_transform["train"])
傳入訓(xùn)練集圖像的路徑列表、標(biāo)簽列表以及數(shù)據(jù)預(yù)處理的方法,對(duì)類 MyDataSet 進(jìn)行初始化,得到類 MyDataSet 的實(shí)例對(duì)象 train_data_set
MyDataSet 是一個(gè)自定義的數(shù)據(jù)類,代碼如下:
from PIL import Image import torch from torch.utils.data import Dataset class MyDataSet(Dataset): """自定義數(shù)據(jù)集""" def __init__(self, images_path: list, images_class: list, transform=None): self.images_path = images_path self.images_class = images_class self.transform = transform def __len__(self): return len(self.images_path) def __getitem__(self, item): img = Image.open(self.images_path[item]) # RGB為彩色圖片,L為灰度圖片 if img.mode != 'RGB': raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) label = self.images_class[item] if self.transform is not None: img = self.transform(img) return img, label @staticmethod def collate_fn(batch): # 官方實(shí)現(xiàn)的default_collate可以參考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) images = torch.stack(images, dim=0) labels = torch.as_tensor(labels) return images, labels
該類繼承類Dataset,主要實(shí)現(xiàn)初始化函數(shù)__init__( )、計(jì)算數(shù)據(jù)集中樣本數(shù)量的函數(shù)__len__( )、根據(jù)索引返回相應(yīng)的圖片和標(biāo)簽的函數(shù)__getitem__( ) 以及 collate_fn( ) 函數(shù)
我想要重點(diǎn)闡述一下關(guān)于函數(shù) collate_fn( ) 函數(shù)的作用
collate_fn( ) 函數(shù)決定了如何將數(shù)據(jù)進(jìn)行打包處理
@staticmethod def collate_fn(batch): # 官方實(shí)現(xiàn)的default_collate可以參考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) images = torch.stack(images, dim=0) labels = torch.as_tensor(labels) return images, labels
傳入函數(shù)的參數(shù) batch 是由 (images,labels) 組成的一個(gè)個(gè)的元組
如果在此處設(shè)置batch_size的值為8,那么這個(gè)函數(shù)就從數(shù)據(jù)集中獲取8張圖片以及這8張圖片所對(duì)應(yīng)的標(biāo)簽
可以設(shè)置斷點(diǎn)來看一下:
因?yàn)?batch_size 取 8,所以可以看到 batch 是一個(gè)長(zhǎng)度為8的列表,列表是由8個(gè)元組元素組成的,每個(gè)元組是由圖像和其所對(duì)應(yīng)的標(biāo)簽組成的
最后,通過 DataLoader 從實(shí)例化對(duì)象 train_data_set 中加載數(shù)據(jù),打包成一個(gè)一個(gè) batch 送入網(wǎng)絡(luò)中進(jìn)行訓(xùn)練
train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True, num_workers=nw, collate_fn=train_data_set.collate_fn)
這樣就可以得到用于加載訓(xùn)練數(shù)據(jù)的數(shù)據(jù)加載器 train_loader
可以將 數(shù)據(jù)加載器 train_loader 傳給函數(shù),通過調(diào)用函數(shù) plot_data_loader_image 后
plot_data_loader_image(train_loader)
這樣就能可視化出數(shù)據(jù)加載器 train_loader 中的內(nèi)容,如圖所示(此處需要將 num_workers 設(shè)置為0)
以上就是在Pytorch中自定義dataset讀取數(shù)據(jù)的實(shí)現(xiàn)代碼的詳細(xì)內(nèi)容,更多關(guān)于在Pytorch自定義dataset的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python中Pycharm 輸出中文或打印中文亂碼現(xiàn)象的解決辦法
本篇文章主要介紹了python中Pycharm 輸出中文或打印中文亂碼現(xiàn)象的解決辦法 ,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2017-06-06基于np.arange與np.linspace細(xì)微區(qū)別(數(shù)據(jù)溢出問題)
這篇文章主要介紹了基于np.arange與np.linspace細(xì)微區(qū)別(數(shù)據(jù)溢出問題),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-05-05Tortoise-orm信號(hào)實(shí)現(xiàn)及使用場(chǎng)景源碼詳解
這篇文章主要為大家介紹了Tortoise-orm信號(hào)實(shí)現(xiàn)及使用場(chǎng)景源碼詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-03-03使用Python FastAPI構(gòu)建Web服務(wù)的實(shí)現(xiàn)
這篇文章主要介紹了使用Python FastAPI構(gòu)建Web服務(wù)的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-06-06python 限制函數(shù)調(diào)用次數(shù)的實(shí)例講解
下面小編就為大家分享一篇python 限制函數(shù)調(diào)用次數(shù)的實(shí)例講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-04-04