Pytorch DataLoader shuffle驗證方式
shuffle = False時,不打亂數(shù)據(jù)順序
shuffle = True,隨機打亂
import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset
h5f = h5py.File('train.h5', 'w');
data1 = np.array([[1,2,3],
[2,5,6],
[3,5,6],
[4,5,6]])
data2 = np.array([[1,1,1],
[1,2,6],
[1,3,6],
[1,4,6]])
h5f.create_dataset(str('data'), data=data1)
h5f.create_dataset(str('label'), data=data2)
class Dataset(Dataset):
def __init__(self):
h5f = h5py.File('train.h5', 'r')
self.data = h5f['data']
self.label = h5f['label']
def __getitem__(self, index):
data = torch.from_numpy(self.data[index])
label = torch.from_numpy(self.label[index])
return data, label
def __len__(self):
assert self.data.shape[0] == self.label.shape[0], "wrong data length"
return self.data.shape[0]
dataset_train = Dataset()
loader_train = DataLoader(dataset=dataset_train,
batch_size=2,
shuffle = True)
for i, data in enumerate(loader_train):
train_data, label = data
print(train_data)
pytorch DataLoader使用細節(jié)
背景:
我一開始是對數(shù)據(jù)擴增這一塊有疑問, 只看到了數(shù)據(jù)變換(torchvisiom.transforms),但是沒看到數(shù)據(jù)擴增, 后來搞明白了, 數(shù)據(jù)擴增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多個epoch共同作用下完成的,
數(shù)據(jù)變換共有以下內容
composed = transforms.Compose([transforms.Resize((448, 448)), # resize
transforms.RandomCrop(300), # random crop
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], # normalize
std=[0.5, 0.5, 0.5])])
簡單的數(shù)據(jù)讀取類, 進返回PIL格式的image:
class MyDataset(data.Dataset):
def __init__(self, labels_file, root_dir, transform=None):
with open(labels_file) as csvfile:
self.labels_file = list(csv.reader(csvfile))
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.labels_file)
def __getitem__(self, idx):
im_name = os.path.join(root_dir, self.labels_file[idx][0])
im = Image.open(im_name)
if self.transform:
im = self.transform(im)
return im
下面是主程序
labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)
"""原始數(shù)據(jù)集共3張圖片, 以batch_size=1, epoch為2 展示所有圖片(共6張) """
for eopch in range(2):
plt.figure(figsize=(6, 6))
for ind, i in enumerate(dataloader):
a = i[0, :, :, :].numpy().transpose((1, 2, 0))
plt.subplot(1, 3, ind+1)
plt.imshow(a)

從上述圖片總可以看到, 在每個eopch階段實際上是對原始圖片重新使用了transform, , 這就造就了數(shù)據(jù)的擴增
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Python使用 Beanstalkd 做異步任務處理的方法
這篇文章主要介紹了Python使用 Beanstalkd 做異步任務處理的方法,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2018-04-04
windows下安裝python的C擴展編譯環(huán)境(解決Unable to find vcvarsall.bat)
這篇文章主要介紹了windows下安裝python的C擴展編譯環(huán)境(解決Unable to find vcvarsall.bat),需要的朋友可以參考下2018-02-02
Pycharm中安裝Pygal并使用Pygal模擬擲骰子(推薦)
這篇文章主要介紹了Pycharm中安裝Pygal并使用Pygal模擬擲骰子,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04
Numpy將二維數(shù)組添加到空數(shù)組的實現(xiàn)
今天小編就為大家分享一篇Numpy將二維數(shù)組添加到空數(shù)組的實現(xiàn),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12
python GUI庫圖形界面開發(fā)之PyQt5打印控件QPrinter詳細使用方法與實例
這篇文章主要介紹了python GUI庫圖形界面開發(fā)之PyQt5打印控件QPrinter詳細使用方法與實例,需要的朋友可以參考下2020-02-02

