PyTorch如何創(chuàng)建自己的數(shù)據(jù)集
PyTorch創(chuàng)建自己的數(shù)據(jù)集
圖片文件在同一的文件夾下
思路是繼承 torch.utils.data.Dataset,并重點(diǎn)重寫其 __getitem__方法,示例代碼如下:
class ImageFolder(Dataset): ? ? def __init__(self, folder_path): ? ? ? ? self.files = sorted(glob.glob('%s/*.*' % folder_path)) ? ? def __getitem__(self, index): ? ? ? ? path = self.files[index % len(self.files)] ? ? ? ? img = np.array(Image.open(path)) ? ? ? ? h, w, c = img.shape ? ? ? ? pad = ((40, 40), (4, 4), (0, 0)) ? ? ? ? # img = np.pad(img, pad, 'constant', constant_values=0) / 255 ? ? ? ? img = np.pad(img, pad, mode='edge') / 255.0 ? ? ? ? img = torch.from_numpy(img).float() ? ? ? ? patches = np.reshape(img, (3, 10, 128, 11, 128)) ? ? ? ? patches = np.transpose(patches, (0, 1, 3, 2, 4)) ? ? ? ? return img, patches, path ? ? def __len__(self): ? ? ? ? return len(self.files)
圖片文件在不同的文件夾下
比如我們有數(shù)據(jù)如下:
─── data
├── train
│ ├── 0.jpg
│ └── 1.jpg
├── test
│ ├── 0.jpg
│ └── 1.jpg
└── val
├── 1.jpg
└── 2.jpg
此時(shí)我們只需要將以上代碼稍作修改即可,修改的代碼如下:
self.files = sorted(glob.glob('%s/**/*.*' % folder_path, recursive=True))
其他代碼不變。
pytorch常用數(shù)據(jù)集的使用
對(duì)于pytorch數(shù)據(jù)集的使用,示例代碼如下:
from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from torchvision import transforms import torchvision import ssl ssl._create_default_https_context = ssl._create_unverified_context dataset_transform = Compose([transforms.ToTensor()]) # 關(guān)于官方數(shù)據(jù)集的使用還是關(guān)鍵要看pytorch的官方文檔 train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True) test_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=False,transform=dataset_transform,download=True) # 查看測試數(shù)據(jù)集中的第一個(gè)數(shù)據(jù) # print(test_set[0]) # 查看測試數(shù)據(jù)集中的分類情況 # print(test_set.classes) # # 取出第一個(gè)數(shù)據(jù)中的圖片(img)和分類結(jié)果(target) # img,target = test_set[0] # 查看圖片數(shù)據(jù)的類型 # print(img) # print(target) # 輸出類別 # print(test_set.classes[target]) # 查看圖片 # img.show() # 使用tensorboard顯示tensor數(shù)據(jù)類型的圖片 writer = SummaryWriter("logs") for i in range(10): # 取出數(shù)據(jù)中的圖片(img)和分類結(jié)果(target) img,target = test_set[i] writer.add_image("test_set",img,i) writer.close()
上述代碼運(yùn)行結(jié)果在tensorboard可視化:
代碼
train_set = torchvision.datasets.CIFAR10(root="./CIFAR10",train=True,transform=dataset_transform,download=True)
常用參數(shù)講解
root
:根目錄,存放數(shù)據(jù)集的位置train
:若為True,則劃分為訓(xùn)練數(shù)據(jù)集,若為False,則劃分為測試數(shù)據(jù)集transform
:指定輸入數(shù)據(jù)集處理方式download
:若為True,則會(huì)將數(shù)據(jù)集下載到root指定的目錄下,否則不會(huì)下載
官方文檔對(duì)參數(shù)的解釋:
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
注意:
- 關(guān)于官方數(shù)據(jù)集的使用還是關(guān)鍵要看pytorch的官方文檔
- 下載數(shù)據(jù)集的細(xì)節(jié)之處:知道下載鏈接(下載鏈接可以在源碼中查看)之后可以不用使用代碼下載了,使用迅雷來下載可能會(huì)更快。
- 要學(xué)會(huì)使用Pycharm中的ctrl+p和ctrl+alt這兩個(gè)快捷鍵
- pytorch官網(wǎng)
- pytorch官方數(shù)據(jù)集(下載數(shù)據(jù)集方法)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python+OpenCV實(shí)現(xiàn)角度測量的示例代碼
本文介紹如何使用python語言實(shí)現(xiàn)角度測量,程序包括鼠標(biāo)選點(diǎn)、直線斜率計(jì)算、角度計(jì)算三個(gè)子程序和一個(gè)主程序,感興趣的可以了解一下2022-03-03Python的string模塊中的Template類字符串模板用法
通過string.Template我們可以為Python定制字符串的替換標(biāo)準(zhǔn),這里我們就來通過示例解析Python的string模塊中的Template類字符串模板用法:2016-06-06python制作可視化GUI界面自動(dòng)分類管理文件
這篇文章主要為大家介紹了python制作可視化GUI界面實(shí)現(xiàn)自動(dòng)分類管理文件,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05基于Tensorflow批量數(shù)據(jù)的輸入實(shí)現(xiàn)方式
今天小編就為大家分享一篇基于Tensorflow批量數(shù)據(jù)的輸入實(shí)現(xiàn)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-02-02Python中用append()連接后多出一列Unnamed的解決
Python中用append()連接后多出一列Unnamed的解決方案,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-01-01Python基礎(chǔ)教程之Pandas數(shù)據(jù)分析庫詳解
Pandas是一個(gè)基于 NumPy 的非常強(qiáng)大的開源數(shù)據(jù)處理庫,它提供了高效、靈活和豐富的數(shù)據(jù)結(jié)構(gòu)和數(shù)據(jù)分析工具,本文中,我們將學(xué)習(xí)如何使用Pandas來處理和分析數(shù)據(jù),感興趣的小伙伴跟著小編一起來看看吧2023-07-07pyqt 實(shí)現(xiàn)為長內(nèi)容添加滑輪 scrollArea
今天小編就為大家分享一篇pyqt 實(shí)現(xiàn)為長內(nèi)容添加滑輪 scrollArea,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-06-06關(guān)于Python時(shí)間日期常見的一些操作方法
Python的datetime模塊是處理日期和時(shí)間的強(qiáng)大工具,datetime類可以獲取當(dāng)前時(shí)間、指定日期、計(jì)算時(shí)間差、訪問時(shí)間屬性及格式化時(shí)間,這些功能使得在Python中進(jìn)行時(shí)間日期處理變得簡單高效,需要的朋友可以參考下2024-09-09