Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取
更新時間:2020年01月18日 17:20:27 作者:_寒潭雁影
今天小編就為大家分享一篇Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
以讀取VOC2012語義分割數(shù)據(jù)集為例,具體見代碼注釋:
VocDataset.py
from PIL import Image import torch import torch.utils.data as data import numpy as np import os import torchvision import torchvision.transforms as transforms import time #VOC數(shù)據(jù)集分類對應(yīng)顏色標(biāo)簽 VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] #顏色標(biāo)簽空間轉(zhuǎn)到序號標(biāo)簽空間,就他媽這里浪費巨量的時間,這里還他媽的有問題 def voc_label_indices(colormap, colormap2label): """Assign label indices for Pascal VOC2012 Dataset.""" idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0]) #out = np.empty(idx.shape, dtype = np.int64) out = colormap2label[idx] out=out.astype(np.int64)#數(shù)據(jù)類型轉(zhuǎn)換 end = time.time() return out class MyDataset(data.Dataset):#創(chuàng)建自定義的數(shù)據(jù)讀取類 def __init__(self, root, is_train, crop_size=(320,480)): self.rgb_mean =(0.485, 0.456, 0.406) self.rgb_std = (0.229, 0.224, 0.225) self.root=root self.crop_size=crop_size images = []#創(chuàng)建空列表存文件名稱 txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt') with open(txt_fname, 'r') as f: self.images = f.read().split() #數(shù)據(jù)名稱整理 self.files = [] for name in self.images: img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name) label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name) self.files.append({ "img": img_file, "label": label_file, "name": name }) self.colormap2label = np.zeros(256**3) #整個循環(huán)的意思就是將顏色標(biāo)簽映射為單通道的數(shù)組索引 for i, cm in enumerate(VOC_COLORMAP): self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i #按照索引讀取每個元素的具體內(nèi)容 def __getitem__(self, index): datafiles = self.files[index] name = datafiles["name"] image = Image.open(datafiles["img"]) label = Image.open(datafiles["label"]).convert('RGB')#打開的是PNG格式的圖片要轉(zhuǎn)到rgb的格式下,不然結(jié)果會比較要命 #以圖像中心為中心截取固定大小圖像,小于固定大小的圖像則自動填0 imgCenterCrop = transforms.Compose([ transforms.CenterCrop(self.crop_size), transforms.ToTensor(), transforms.Normalize(self.rgb_mean, self.rgb_std),#圖像數(shù)據(jù)正則化 ]) labelCenterCrop = transforms.CenterCrop(self.crop_size) cropImage=imgCenterCrop(image) croplabel=labelCenterCrop(label) croplabel=torch.from_numpy(np.array(croplabel)).long()#把標(biāo)簽數(shù)據(jù)類型轉(zhuǎn)為torch #將顏色標(biāo)簽圖轉(zhuǎn)為序號標(biāo)簽圖 mylabel=voc_label_indices(croplabel, self.colormap2label) return cropImage,mylabel #返回圖像數(shù)據(jù)長度 def __len__(self): return len(self.files)
Train.py
import matplotlib.pyplot as plt import torch.utils.data as data import torchvision.transforms as transforms import numpy as np from PIL import Image from VocDataset import MyDataset #VOC數(shù)據(jù)集分類對應(yīng)顏色標(biāo)簽 VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] root='../data/VOCdevkit/VOC2012' train_data=MyDataset(root,True) trainloader = data.DataLoader(train_data, 4) #從數(shù)據(jù)集中拿出一個批次的數(shù)據(jù) for i, data in enumerate(trainloader): getimgs, labels= data img = transforms.ToPILImage()(getimgs[0]) labels = labels.numpy()#tensor轉(zhuǎn)numpy labels=labels[0]#獲得批次標(biāo)簽集中的一張標(biāo)簽圖像 labels = labels.transpose((1,0))#數(shù)組維度切換,將第1維換到第0維,第0維換到第1維 ##將單通道索引標(biāo)簽圖片映射回顏色標(biāo)簽圖片 newIm= Image.new('RGB', (480, 320))#創(chuàng)建一張與標(biāo)簽大小相同的圖片,用以顯示標(biāo)簽所對應(yīng)的顏色 for i in range(0, 480): for j in range(0, 320): sele=labels[i][j]#取得坐標(biāo)點對應(yīng)像素的值 newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2]))) #顯示圖像和標(biāo)簽 plt.figure("image") ax1 = plt.subplot(1,2,1) ax2 = plt.subplot(1,2,2) plt.sca(ax1) plt.imshow(img) plt.sca(ax2) plt.imshow(newIm) plt.show()
以上這篇Pytorch 實現(xiàn)數(shù)據(jù)集自定義讀取就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Scrapy-Redis結(jié)合POST請求獲取數(shù)據(jù)的方法示例
這篇文章主要給大家介紹了關(guān)于Scrapy-Redis結(jié)合POST請求獲取數(shù)據(jù)的相關(guān)資料,文中通過示例代碼介紹的非常詳細,對大家學(xué)習(xí)或者使用Scrapy-Redis具有一定的參考學(xué)習(xí)價值,需要的朋友們下面來一起學(xué)習(xí)學(xué)習(xí)吧2019-05-05python超詳細實現(xiàn)完整學(xué)生成績管理系統(tǒng)
讀萬卷書不如行萬里路,只學(xué)書上的理論是遠遠不夠的,只有在實戰(zhàn)中才能獲得能力的提升,本篇文章手把手帶你用Java實現(xiàn)一個完整版學(xué)生成績管理系統(tǒng),大家可以在過程中查缺補漏,提升水平2022-03-03Python實現(xiàn)批量繪制遙感影像數(shù)據(jù)的直方圖
這篇文章主要為大家詳細介紹了如何基于Python中g(shù)dal模塊,實現(xiàn)對大量柵格圖像批量繪制直方圖,文中的示例代碼講解詳細,感興趣的小伙伴可以了解一下2023-02-02