深度學(xué)習(xí)入門之Pytorch 數(shù)據(jù)增強(qiáng)的實(shí)現(xiàn)
數(shù)據(jù)增強(qiáng)
卷積神經(jīng)網(wǎng)絡(luò)非常容易出現(xiàn)過擬合的問題,而數(shù)據(jù)增強(qiáng)的方法是對(duì)抗過擬合問題的一個(gè)重要方法。
2012 年 AlexNet 在 ImageNet 上大獲全勝,圖片增強(qiáng)方法功不可沒,因?yàn)橛辛藞D片增強(qiáng),使得訓(xùn)練的數(shù)據(jù)集比實(shí)際數(shù)據(jù)集多了很多'新'樣本,減少了過擬合的問題,下面我們來具體解釋一下。
常用的數(shù)據(jù)增強(qiáng)方法
常用的數(shù)據(jù)增強(qiáng)方法如下:
1.對(duì)圖片進(jìn)行一定比例縮放
2.對(duì)圖片進(jìn)行隨機(jī)位置的截取
3.對(duì)圖片進(jìn)行隨機(jī)的水平和豎直翻轉(zhuǎn)
4.對(duì)圖片進(jìn)行隨機(jī)角度的旋轉(zhuǎn)
5.對(duì)圖片進(jìn)行亮度、對(duì)比度和顏色的隨機(jī)變化
這些方法 pytorch 都已經(jīng)為我們內(nèi)置在了 torchvision 里面,我們?cè)诎惭b pytorch 的時(shí)候也安裝了 torchvision,下面我們來依次展示一下這些數(shù)據(jù)增強(qiáng)方法。
import sys sys.path.append('..') from PIL import Image from torchvision import transforms as tfs # 讀入一張圖片 im = Image.open('./cat.png') im
隨機(jī)比例放縮
隨機(jī)比例縮放主要使用的是 torchvision.transforms.Resize()
這個(gè)函數(shù),第一個(gè)參數(shù)可以是一個(gè)整數(shù),那么圖片會(huì)保存現(xiàn)在的寬和高的比例,并將更短的邊縮放到這個(gè)整數(shù)的大小,第一個(gè)參數(shù)也可以是一個(gè) tuple,那么圖片會(huì)直接把寬和高縮放到這個(gè)大??;第二個(gè)參數(shù)表示放縮圖片使用的方法,比如最鄰近法,或者雙線性差值等,一般雙線性差值能夠保留圖片更多的信息,所以 pytorch 默認(rèn)使用的是雙線性差值,你可以手動(dòng)去改這個(gè)參數(shù),更多的信息可以看看文檔
# 比例縮放 print('before scale, shape: {}'.format(im.size)) new_im = tfs.Resize((100, 200))(im) print('after scale, shape: {}'.format(new_im.size)) new_im
隨機(jī)位置截取
隨機(jī)位置截取能夠提取出圖片中局部的信息,使得網(wǎng)絡(luò)接受的輸入具有多尺度的特征,所以能夠有較好的效果。在 torchvision 中主要有下面兩種方式,一個(gè)是 torchvision.transforms.RandomCrop()
,傳入的參數(shù)就是截取出的圖片的長(zhǎng)和寬,對(duì)圖片在隨機(jī)位置進(jìn)行截??;第二個(gè)是 torchvision.transforms.CenterCrop()
,同樣傳入介曲初的圖片的大小作為參數(shù),會(huì)在圖片的中心進(jìn)行截取
# 隨機(jī)裁剪出 100 x 100 的區(qū)域 random_im1 = tfs.RandomCrop(100)(im) random_im1
# 中心裁剪出 100 x 100 的區(qū)域 center_im = tfs.CenterCrop(100)(im) center_im
隨機(jī)的水平和豎直方向翻轉(zhuǎn)
對(duì)于上面這一張貓的圖片,如果我們將它翻轉(zhuǎn)一下,它仍然是一張貓,但是圖片就有了更多的多樣性,所以隨機(jī)翻轉(zhuǎn)也是一種非常有效的手段。在 torchvision 中,隨機(jī)翻轉(zhuǎn)使用的是 torchvision.transforms.RandomHorizontalFlip()
和 torchvision.transforms.RandomVerticalFlip()
# 隨機(jī)水平翻轉(zhuǎn) h_filp = tfs.RandomHorizontalFlip()(im) h_filp
# 隨機(jī)豎直翻轉(zhuǎn) v_flip = tfs.RandomVerticalFlip()(im) v_flip
隨機(jī)角度旋轉(zhuǎn)
一些角度的旋轉(zhuǎn)仍然是非常有用的數(shù)據(jù)增強(qiáng)方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation()
來實(shí)現(xiàn),其中第一個(gè)參數(shù)就是隨機(jī)旋轉(zhuǎn)的角度,比如填入 10,那么每次圖片就會(huì)在 -10 ~ 10 度之間隨機(jī)旋轉(zhuǎn)
rot_im = tfs.RandomRotation(45)(im) rot_im
亮度、對(duì)比度和顏色的變化
除了形狀變化外,顏色變化又是另外一種增強(qiáng)方式,其中可以設(shè)置亮度變化,對(duì)比度變化和顏色變化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 來實(shí)現(xiàn)的,第一個(gè)參數(shù)就是亮度的比例,第二個(gè)是對(duì)比度,第三個(gè)是飽和度,第四個(gè)是顏色
# 亮度 bright_im = tfs.ColorJitter(brightness=1)(im) # 隨機(jī)從 0 ~ 2 之間亮度變化,1 表示原圖 bright_im
# 對(duì)比度 contrast_im = tfs.ColorJitter(contrast=1)(im) # 隨機(jī)從 0 ~ 2 之間對(duì)比度變化,1 表示原圖 contrast_im
# 顏色 color_im = tfs.ColorJitter(hue=0.5)(im) # 隨機(jī)從 -0.5 ~ 0.5 之間對(duì)顏色變化 color_im
上面我們講了這么圖片增強(qiáng)的方法,其實(shí)這些方法都不是孤立起來用的,可以聯(lián)合起來用,比如先做隨機(jī)翻轉(zhuǎn),然后隨機(jī)截取,再做對(duì)比度增強(qiáng)等等,torchvision 里面有個(gè)非常方便的函數(shù)能夠?qū)⑦@些變化合起來,就是 torchvision.transforms.Compose(),下面我們舉個(gè)例子
im_aug = tfs.Compose([ tfs.Resize(120), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5) ])
import matplotlib.pyplot as plt %matplotlib inline nrows = 3 ncols = 3 figsize = (8, 8) _, figs = plt.subplots(nrows, ncols, figsize=figsize) for i in range(nrows): for j in range(ncols): figs[i][j].imshow(im_aug(im)) figs[i][j].axes.get_xaxis().set_visible(False) figs[i][j].axes.get_yaxis().set_visible(False) plt.show()
可以看到每次做完增強(qiáng)之后的圖片都有一些變化,所以這就是我們前面講的,增加了一些'新'數(shù)據(jù)
下面我們使用圖像增強(qiáng)進(jìn)行訓(xùn)練網(wǎng)絡(luò),看看具體的提升究竟在什么地方,使用 ResNet 進(jìn)行訓(xùn)練
使用數(shù)據(jù)增強(qiáng)
import numpy as np import torch from torch import nn import torch.nn.functional as F from torch.autograd import Variable from torchvision.datasets import CIFAR10 from utils import train, resnet from torchvision import transforms as tfs # 使用數(shù)據(jù)增強(qiáng) def train_tf(x): im_aug = tfs.Compose([ tfs.Resize(120), tfs.RandomHorizontalFlip(), tfs.RandomCrop(96), tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x def test_tf(x): im_aug = tfs.Compose([ tfs.Resize(96), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x train_set = CIFAR10('./data', train=True, transform=train_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=test_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = resnet(3, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() train(net, train_data, test_data, 10, optimizer, criterion)
不使用數(shù)據(jù)增強(qiáng)
# 不使用數(shù)據(jù)增強(qiáng) def data_tf(x): im_aug = tfs.Compose([ tfs.Resize(96), tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) x = im_aug(x) return x train_set = CIFAR10('./data', train=True, transform=data_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=data_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = resnet(3, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.01) criterion = nn.CrossEntropyLoss() train(net, train_data, test_data, 10, optimizer, criterion)
從上面可以看出,對(duì)于訓(xùn)練集,不做數(shù)據(jù)增強(qiáng)跑 10 次,準(zhǔn)確率已經(jīng)到了 95%,而使用了數(shù)據(jù)增強(qiáng),跑 10 次準(zhǔn)確率只有 75%,說明數(shù)據(jù)增強(qiáng)之后變得更難了。
而對(duì)于測(cè)試集,使用數(shù)據(jù)增強(qiáng)進(jìn)行訓(xùn)練的時(shí)候,準(zhǔn)確率會(huì)比不使用更高,因?yàn)閿?shù)據(jù)增強(qiáng)提高了模型應(yīng)對(duì)于更多的不同數(shù)據(jù)集的泛化能力,所以有更好的效果。
以上就是深度學(xué)習(xí)入門之Pytorch 數(shù)據(jù)增強(qiáng)的實(shí)現(xiàn)的詳細(xì)內(nèi)容,更多關(guān)于Pytorch 數(shù)據(jù)增強(qiáng)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
用Python將GIF動(dòng)圖分解成多張靜態(tài)圖片
今天給大家?guī)淼氖顷P(guān)于Python的相關(guān)知識(shí),文章圍繞著如何用Python將GIF動(dòng)圖分解成多張靜態(tài)圖片展開,文中有非常詳細(xì)的介紹,需要的朋友可以參考下2021-06-06Python獲取本機(jī)IP/MAC多網(wǎng)卡方法示例
這篇文章主要為大家介紹了Python獲取本機(jī)IP/MAC多網(wǎng)卡方法示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-08-08python爬蟲之你好,李煥英電影票房數(shù)據(jù)分析
這篇文章主要介紹了python爬蟲之你好,李煥英電影票房數(shù)據(jù)分析,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python爬蟲的小伙伴們有一定的幫助,需要的朋友可以參考下2021-04-04python應(yīng)用之如何使用Python發(fā)送通知到微信
現(xiàn)在通過發(fā)微信信息來做消息通知和告警已經(jīng)很普遍了,下面這篇文章主要給大家介紹了關(guān)于python應(yīng)用之如何使用Python發(fā)送通知到微信的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-03-0310個(gè)Python實(shí)現(xiàn)的最頻繁使用的聚類算法
聚類或聚類分析是無監(jiān)督學(xué)習(xí)問題。它通常被用作數(shù)據(jù)分析技術(shù),用于發(fā)現(xiàn)數(shù)據(jù)中的有趣模式。本文為大家介紹了10個(gè)最頻繁使用的聚類算法,感興趣的可以了解一下2022-12-12Python的flask接收前臺(tái)的ajax的post數(shù)據(jù)和get數(shù)據(jù)的方法
這篇文章主要介紹了Python的flask接收前臺(tái)的ajax的post數(shù)據(jù)和get數(shù)據(jù)的方法,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-04-04