使用python/pytorch讀取數(shù)據(jù)集的示例代碼
MNIST數(shù)據(jù)集
MNIST數(shù)據(jù)集包含了6萬張手寫數(shù)字([1,28,28]尺寸),以特殊格式存儲。本文首先將MNIST數(shù)據(jù)集另存為png格式,然后再讀取png格式圖片,開展后續(xù)訓(xùn)練
另存為png格式
import torch from torch.utils.data import Dataset from torchvision.datasets import MNIST from torch.utils.data import DataLoader from tqdm import tqdm from torchvision import models, transforms from torchvision.utils import save_image from PIL import Image #將MNIST數(shù)據(jù)集轉(zhuǎn)換為圖片 tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1 datasetMNIST = MNIST("./data", train=True, download=True, transform=tf) pbar = tqdm(datasetMNIST) for index, (img,cl) in enumerate(pbar): save_image(img, f"./data/MNIST_PNG/x/{index}.png") # 以寫入模式打開文件 with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file: # 將字符串寫入文件 file.write(f"{cl}")
注意:MNIST源數(shù)據(jù)存放在./data文件下,如果沒有數(shù)據(jù)也沒關(guān)系,代碼會自動從網(wǎng)上下載。另存為png的數(shù)據(jù)放在了./data/MNIST_PNG/文件下。子文件夾x存放6萬張圖片,子文件夾c存放6萬個文本文件,每個文本文件內(nèi)有一行字符串,說明該對應(yīng)的手寫數(shù)字是幾(標(biāo)簽)。
讀取png格式數(shù)據(jù)集
class MyMNISTDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): x = self.data[idx][0] #圖像 y = self.data[idx][1] #標(biāo)簽 return x, y def load_data(dataNum=60000): data = [] pbar = tqdm(range(dataNum)) for i in pbar: # 指定圖片路徑 image_path = f'./data/MNIST_PNG/x/{i}.png' cond_path=f'./data/MNIST_PNG/c/{i}.txt' # 定義圖像預(yù)處理 preprocess = transforms.Compose([ transforms.Grayscale(num_output_channels=1), # 將圖像轉(zhuǎn)換為灰度圖像(單通道) transforms.ToTensor() ]) # 使用預(yù)處理加載圖像 image_tensor = preprocess(Image.open(image_path)) # 加載條件文檔(tag) with open(cond_path, 'r') as file: line = file.readline() number = int(line) # 將字符串轉(zhuǎn)換為整數(shù),圖像的類別 data.append((image_tensor, number)) return data data=load_data(60000) # 創(chuàng)建數(shù)據(jù)集實例 dataset = MyMNISTDataset(data) # 創(chuàng)建數(shù)據(jù)加載器 dataloader = DataLoader(dataset, batch_size=4, shuffle=True) pbar = tqdm(dataloader) for index, (img,cond) in enumerate(pbar): #這里對每一批進(jìn)行訓(xùn)練... print(f"Batch {index}: img = {img.shape}, cond = {cond}")
load_data函數(shù)用于讀取數(shù)據(jù)文件,返回一個data張量。data張量又被用于構(gòu)造MyMNISTDataset類的對象dataset,dataset對象又被DataLoader函數(shù)轉(zhuǎn)換為dataloader。
dataloader事實上按照batch將數(shù)據(jù)集進(jìn)行了分割,4張圖片一組進(jìn)行訓(xùn)練。上述代碼的輸出如下:
...... Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2]) Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0]) Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9]) Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5]) Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4]) Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6]) Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5]) Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1]) Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7]) Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7]) Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6]) Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5]) Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9]) Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6]) Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9]) Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8]) Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6]) ......
到此這篇關(guān)于使用python/pytorch讀取數(shù)據(jù)集的示例代碼的文章就介紹到這了,更多相關(guān)python/pytorch讀取數(shù)據(jù)集內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中NameError: name ‘Image‘ is not&nb
本文主要介紹了Python中NameError: name ‘Image‘ is not defined的問題解決,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2024-06-06