pytorch中的dataset用法詳解
1.torch.utils.data 里面的dataset使用方法
當(dāng)我們繼承了一個 Dataset類之后,我們需要重寫 len 方法,該方法提供了dataset的大??; getitem 方法, 該方法支持從 0 到 len(self)
的索引
from torch.utils.data import Dataset, DataLoader import torch class MyDataset(Dataset): ? ? """ ? ? ? ? 下載數(shù)據(jù)、初始化數(shù)據(jù),都可以在這里完成 ? ? """ ? ? def __init__(self): ? ? ? ? self.x = torch.linspace(11,20,10) ? ? ? ? self.y = torch.linspace(1,10,10) ? ? ? ? self.len = len(self.x) ? ? def __getitem__(self, index): ? ? ? ? return self.x[index], self.y[index] ? ? def __len__(self): ? ? ? ? return self.len # 實例化這個類,然后我們就得到了Dataset類型的數(shù)據(jù),記下來就將這個類傳給DataLoader,就可以了。 mydataset = MyDataset()#[return: # ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?# (tensor(x1),tensor(y1)); # ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?# (tensor(x2),tensor(y2)); # ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?# ...... train_loader2 = DataLoader(dataset=mydataset, ? ? ? ? ? ? ? ? ? ? ? ? ? ?batch_size=5, ? ? ? ? ? ? ? ? ? ? ? ? ? ?shuffle=False) for epoch in range(3): ?# 訓(xùn)練所有!整套!數(shù)據(jù) 3 次 ? ? for step,(batch_x,batch_y) in enumerate(train_loader2): ?# 每一步 loader 釋放一小批數(shù)據(jù)用來學(xué)習(xí) ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? #return: ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? #(tensor(x1,x2,x3,x4,x5),tensor(y1,y2,y3,y4,y5)) ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? #(tensor(x6,x7,x8,x9,x10),tensor(y6,y7,y8,y9,y10)) ? ? ? ? # 假設(shè)這里就是你訓(xùn)練的地方... ? ? ? ? # 打出來一些數(shù)據(jù) ? ? ? ? print('Epoch: ', epoch, '| Step:', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
2.torchvision.datasets的使用方法
torchvision
中datasets
中所有封裝的數(shù)據(jù)集都是torch.utils.data.Dataset
的子類,它們都實現(xiàn)了__getitem__和__len__方法。因此,它們都可以用torch.utils.data.DataLoader進行數(shù)據(jù)加載。
用法1:使用官方數(shù)據(jù)集
可選數(shù)據(jù)集參考:https://www.pianshen.com/article/9695297328/
代碼:
torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor()) root (string): 表示數(shù)據(jù)集的根目錄,其中根目錄存在CIFAR10/processed/training.pt和CIFAR10/processed/test.pt的子目錄 train (bool, optional): 如果為True,則從training.pt創(chuàng)建數(shù)據(jù)集,否則從test.pt創(chuàng)建數(shù)據(jù)集 download (bool, optional): 如果為True,則從internet下載數(shù)據(jù)集并將其放入根目錄。如果數(shù)據(jù)集已下載,則不會再次下載 transform (callable, optional): 接收PIL圖片并返回轉(zhuǎn)換后版本圖片的轉(zhuǎn)換函數(shù) target_transform (callable, optional): 接收PIL接收目標(biāo)并對其進行變換的轉(zhuǎn)換函數(shù)
import torchvision # 準(zhǔn)備的測試數(shù)據(jù)集 from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor()) test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True) # 測試數(shù)據(jù)集中第一張圖片及target img, target = test_data[0] print(img.shape) print(target) writer = SummaryWriter("dataloader") for epoch in range(2): ? ? step = 0 ? ? for data in test_loader: ? ? ? ? imgs, targets = data ? ? ? ? # print(imgs.shape) ? ? ? ? # print(targets) ? ? ? ? writer.add_images("Epoch: {}".format(epoch), imgs, step) ? ? ? ? step = step + 1 writer.close()
用法2:ImageFolder通用的自己數(shù)據(jù)集加載器
一個通用的數(shù)據(jù)加載器,數(shù)據(jù)集中的數(shù)據(jù)以以下方式組織
root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
torchvision.datasets.ImageFolder(root="root folder path", [transform, target_transform])
ImageFolder有以下成員變量:
- self.classes - 用一個list保存 類名
- self.class_to_idx - 類名對應(yīng)的 索引
- self.imgs - 保存(img-path, class) tuple的list
該方法可以結(jié)合torch.utils.data.Subset
使用 ,以根據(jù)示例索引將您的ImageFolder數(shù)據(jù)集分為訓(xùn)練和測試。
orig_set = torchvision.datasets.Imagefolder('dataset/') ?# your dataset n = len(orig_set) ?# total number of examples n_test = int(0.1 * n) ?# take ~10% for test test_set = torch.utils.data.Subset(orig_set, range(n_test)) ?# take first 10% train_set = torch.utils.data.Subset(orig_set, range(n_test, n)) ?# take the rest?
到此這篇關(guān)于pytorch
的dataset
用法詳解的文章就介紹到這了,更多相關(guān)pytorch的dataset用法內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!