Python深度學(xué)習(xí)pytorch實現(xiàn)圖像分類數(shù)據(jù)集
目前廣泛使用的圖像分類數(shù)據(jù)集之一是MNIST數(shù)據(jù)集。如今,MNIST數(shù)據(jù)集更像是一個健全的檢查,而不是一個基準(zhǔn)。
為了提高難度,我們將在接下來的章節(jié)中討論在2017年發(fā)布的性質(zhì)相似但相對復(fù)雜的Fashion-MNIST數(shù)據(jù)集。
import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display()
讀取數(shù)據(jù)集
我們可以通過框架中的內(nèi)置函數(shù)將Fashion-MNIST數(shù)據(jù)集下載并讀取到內(nèi)存中。
# 通過ToTensor實例將圖像數(shù)據(jù)從PIL類型變換成32位浮點數(shù)格式 # 并除以255使得所有像素的數(shù)值均在0到1之間 trans = transforms.ToTensor() mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) mnist_test = torchvisino.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
Fashion-MNIST由10個類別的圖像組成,每個類別由訓(xùn)練集中的6000張圖像和測試集中的1000張圖像組成。
測試數(shù)據(jù)集(test dataset)不會用于訓(xùn)練,只用于評估模型性能。訓(xùn)練集和測試集分別包含60000和10000張圖像。
len(mnist_train), len(mnist_test)
(60000, 10000)
每個輸入圖像的高度和寬度均為28像素。數(shù)據(jù)集由灰度圖像組成,其通道數(shù)為1。
為了簡潔起見,本篇中,我們將高度h像素,寬度w像素圖像的形狀即為 h×w或 (h,w)。
mnist_train[0][0].shape
torch.size([1, 28, 28])
Fashion-MNIST中包含10個類別分別是
t-shirt(T恤)、trouser(褲⼦)、pullover(套衫)、dress(連⾐裙)、coat(外套)、
sandal(涼鞋)、shirt(襯衫)、sneaker(運動鞋)、bag(包)和ankle boot(短靴)
以下函數(shù)用于在數(shù)字標(biāo)簽索引及其文本名稱之間進(jìn)行轉(zhuǎn)換。
def get_fashion_mnist_labels(labels): """返回Fashion-MNIST數(shù)據(jù)集的本文標(biāo)簽。""" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
我們現(xiàn)在可以創(chuàng)建一個函數(shù)來可視化這些樣本。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): """Plot a list of images.""" figsize = (num_cols * scale, num_rows * scale) _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): if torch.is_tensor(img): # 圖片張量 ax.imshow(img.numpy()) else: # PIL圖片 ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) return axes
以下是訓(xùn)練數(shù)據(jù)集中前幾個樣本的圖像及其相應(yīng)的標(biāo)簽(文本形式)。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
讀取小批量
為了使我們在讀取訓(xùn)練集和測試集時更容易,我們使用內(nèi)置的數(shù)據(jù)迭代器,而不是從零開始創(chuàng)建一個。回顧一下,在每次迭代中,數(shù)據(jù)加載器每次都會讀取一小批量數(shù)據(jù),大小為batch_size。我們在訓(xùn)練數(shù)據(jù)迭代其中還隨機(jī)打亂了所有樣本
batch_size = 256 def get_dataloader_workers(): """使用4個進(jìn)程來讀取數(shù)據(jù)。""" return 4 train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())
整合所有組件
現(xiàn)在我們定義load_data_fashion_mnist函數(shù),用于獲取和讀取Fashion-MNIST數(shù)據(jù)集。它返回訓(xùn)練集和驗證集的數(shù)據(jù)迭代器。此外,它還接受一個可選參數(shù),用來將圖像大小調(diào)整為另一種形狀。
def load_data_fashion_mnist(batch_size, resize=None): """下載Fashion-MNIST數(shù)據(jù)集,然后將其加載到內(nèi)存中。""" trans = [transforms.ToTensor()] if resize: trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transforms=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transforms=trans, download=True) return(data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()), data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
下面,我們通過指定resize參數(shù)來測試load_data_fashion_mnist函數(shù)的圖像大小調(diào)整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize=64) for X, y in train_iter: print(X.shape, X.dtype, y.shape, y.dtype) break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
以上就是Python深度學(xué)習(xí)pytorch實現(xiàn)圖像分類數(shù)據(jù)集的詳細(xì)內(nèi)容,更多關(guān)于pytorch圖像分類數(shù)據(jù)集的資料請關(guān)注腳本之家其它相關(guān)文章!
- Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)
- python神經(jīng)網(wǎng)絡(luò)AlexNet分類模型訓(xùn)練貓狗數(shù)據(jù)集
- PyTorch手寫數(shù)字?jǐn)?shù)據(jù)集進(jìn)行多分類
- Python機(jī)器學(xué)習(xí)應(yīng)用之基于天氣數(shù)據(jù)集的XGBoost分類篇解讀
- 總結(jié)近幾年P(guān)ytorch基于Imgagenet數(shù)據(jù)集圖像分類模型
- 詳解PyTorch預(yù)定義數(shù)據(jù)集類datasets.ImageFolder使用方法
相關(guān)文章
Jupyter Notebook 遠(yuǎn)程訪問配置詳解
這篇文章主要介紹了Jupyter Notebook 遠(yuǎn)程訪問配置詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-01-01pandas去重復(fù)行并分類匯總的實現(xiàn)方法
這篇文章主要介紹了pandas去重復(fù)行并分類匯總的實現(xiàn)方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-01-01利用scrapy將爬到的數(shù)據(jù)保存到mysql(防止重復(fù))
這篇文章主要給大家介紹了關(guān)于利用scrapy將爬到的數(shù)據(jù)保存到mysql(防止重復(fù))的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面來一起看看吧。2018-03-03python常用函數(shù)random()函數(shù)詳解
這篇文章主要介紹了python常用函數(shù)random()函數(shù),本文通過實例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-02-02