pytorch加載的cifar10數(shù)據(jù)集過(guò)程詳解
pytorch怎么加載cifar10數(shù)據(jù)集
torchvision.datasets.CIFAR10
pytorch里面的torchvision.datasets中提供了大多數(shù)計(jì)算機(jī)視覺(jué)領(lǐng)域相關(guān)任務(wù)的數(shù)據(jù)集,可以根據(jù)實(shí)際需要加載相關(guān)數(shù)據(jù)集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就調(diào)用torchvision.datasets.SVHN()。
針對(duì)cifar10數(shù)據(jù)集而言,調(diào)用torchvision.datasets.CIFAR10(),其中root是下載數(shù)據(jù)集后保存的位置;train是一個(gè)bool變量,為true就是訓(xùn)練數(shù)據(jù)集,false就是測(cè)試數(shù)據(jù)集;download也是一個(gè)bool變量,表示是否下載;transform是對(duì)數(shù)據(jù)集中的"image"進(jìn)行一些操作,比如歸一化、隨機(jī)裁剪、各種數(shù)據(jù)增強(qiáng)操作等;target_transform是針對(duì)數(shù)據(jù)集中的"label"進(jìn)行一些操作。
示例代碼如下:
# 加載訓(xùn)練數(shù)據(jù)集 train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 歸一化 ]) ) # 加載測(cè)試數(shù)據(jù)集 test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 歸一化 ]) )
transforms.Normalize()進(jìn)行歸一化到底在哪里起作用?【CIFAR10源碼分析】
上面的代碼中,我們用transforms.Compose([……])組合了一系列的對(duì)image的操作,其中trandforms.ToTensor()
和transforms.Normalize()
都涉及到歸一化操作:
- 原始的cifar10數(shù)據(jù)集是numpy array的形式,其中數(shù)據(jù)范圍是[0,255],pytorch加載時(shí),并沒(méi)有改變數(shù)據(jù)范圍,依舊是[0,255],加載后的數(shù)據(jù)維度是(H, W, C),源碼部分:
__getitem__()
函數(shù)中進(jìn)行transforms操作,進(jìn)行了歸一化:實(shí)際上傳入的transform在__getitem__()
函數(shù)中被調(diào)用,其中transforms.Totensor()
會(huì)將data(也就是image)的維度變成(C,H, W)的形式,并且歸一化到[0.0,1.0];
transforms.Normalize()
會(huì)根據(jù)z = (x-mean) / std 對(duì)數(shù)據(jù)進(jìn)行歸一化,上述代碼中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
是可以將3個(gè)通道單獨(dú)進(jìn)行歸一化,3個(gè)通道可以設(shè)置不同的mean和std,最終數(shù)據(jù)范圍變成[-0.5,+0.5] 。
所以如果通過(guò)pytorch的cifar10加載數(shù)據(jù)集后,針對(duì)traindataset.data,依舊是沒(méi)有進(jìn)行歸一化的;但是比如traindataset[index].data,其中[index]這樣的按下標(biāo)取元素的操作會(huì)直接調(diào)用的__getitem__()函數(shù),此時(shí)的data就是經(jīng)過(guò)了歸一化的。
除traindataset[index]會(huì)隱式自動(dòng)調(diào)用__getitem__()函數(shù)外,還有什么時(shí)候會(huì)調(diào)用這個(gè)函數(shù)呢?畢竟……只有調(diào)用了這個(gè)函數(shù)才會(huì)調(diào)用transforms中的歸一化處理。——答案是與dataloader搭配使用!
torchvision.datasets加載的數(shù)據(jù)集搭配Dataloader使用
torchvision.datasets實(shí)際上是torch.utils.data.Dataset的子類(lèi),那么就能傳入Dataloader中,迭代的按batch-size獲取批量數(shù)據(jù),用于訓(xùn)練或者測(cè)試。其中dataloader加載dataset中的數(shù)據(jù)時(shí),就是用到了其__getitem__()函數(shù),所以用dataloader加載數(shù)據(jù)集,得到的是經(jīng)過(guò)歸一化后的數(shù)據(jù)。
model.train()和model.eval()
我發(fā)現(xiàn)上面的問(wèn)題,是我用dataloader加載了訓(xùn)練數(shù)據(jù)集用于訓(xùn)練resnet18模型,訓(xùn)練過(guò)程中,我訓(xùn)練好并保存后,順便測(cè)試了一下在測(cè)試數(shù)據(jù)集上的準(zhǔn)確度。但是在測(cè)試的過(guò)程中,我沒(méi)有用dataloader加載測(cè)試數(shù)據(jù)集,而是直接用的dataset.data來(lái)進(jìn)行的測(cè)試。并且!由于是并沒(méi)有將model設(shè)置成model.eval()【其實(shí)我設(shè)置了,但是我對(duì)自己很無(wú)語(yǔ),我寫(xiě)的model.eval,忘記加括號(hào)了,無(wú)語(yǔ)嗚嗚】……也就是即便我的測(cè)試數(shù)據(jù)集沒(méi)有經(jīng)過(guò)歸一化,由于模型還是在model.train()模式下,因此模型的BN層會(huì)自己調(diào)整,使得模型性能不受影響,因此在測(cè)試數(shù)據(jù)集上的accuracy達(dá)到了0.86,我就沒(méi)有多想。
后來(lái)我用模型的時(shí)候,設(shè)置了model.eval()后,依舊是直接用的dataset.data(也就是沒(méi)有歸一化),不管是在測(cè)試數(shù)據(jù)集上還是在訓(xùn)練數(shù)據(jù)集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才發(fā)現(xiàn)是歸一化的問(wèn)題。
- 在
model.train()
模式下進(jìn)行預(yù)測(cè)時(shí),PyTorch會(huì)默認(rèn)啟用一些訓(xùn)練相關(guān)的操作,例如Batch Normalization和Dropout,并且模型的參數(shù)是可變的,能夠根據(jù)輸入進(jìn)行調(diào)整。這些操作在訓(xùn)練模式下可以幫助模型更好地適應(yīng)訓(xùn)練數(shù)據(jù),并產(chǎn)生較高的準(zhǔn)確度。 - 在
model.eval()
模式下進(jìn)行預(yù)測(cè)時(shí),PyTorch會(huì)將模型切換到評(píng)估模式,這會(huì)導(dǎo)致一些訓(xùn)練相關(guān)的操作行為發(fā)生變化。具體而言,Batch Normalization層會(huì)使用訓(xùn)練集上的統(tǒng)計(jì)信息進(jìn)行歸一化,而不是使用當(dāng)前批次的統(tǒng)計(jì)信息。因此,如果輸入數(shù)據(jù)沒(méi)有進(jìn)行歸一化,模型在評(píng)估模式下的準(zhǔn)確度可能會(huì)顯著下降。
以下是我沒(méi)有用dataloader加載數(shù)據(jù)集,進(jìn)行預(yù)測(cè)的代碼:
def correctness(model,data,target, device): batchsize = 1000 batch_num = int(len(data) / batchsize) # 對(duì)原始的數(shù)據(jù)進(jìn)行操作 從H.W.C變成C.H.W data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device) # 手動(dòng)歸一化 data = data/255 data = (data - 0.5) / 0.5 # 求一個(gè)batch的correctness def _batch_correctness(i): images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize] predict = model(images).detach().cpu() correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32) return correctness result = np.array([_batch_correctness(i) for i in range(batch_num)]) return result.flatten().sum()/data.shape[0]
我后面用上面的代碼測(cè)試了四種情況:
- model.eval() + 沒(méi)有歸一化:train_accuracy = 0.10,test_accuracy = 0.10;
- model.eval() + 手動(dòng)歸一化:train_accuracy = 0.95,test_accuracy = 0.84;
- model.train() + 沒(méi)有歸一化:train_accuracy = 0.95,test_accuracy = 0.83;
- model.train() + 手動(dòng)歸一化:train_accuracy = 0.94,test_accuracy = 0.84;
由此可見(jiàn),在model.eval()模式下,數(shù)據(jù)歸一化對(duì)最終的測(cè)試結(jié)果有很大影響。
到此這篇關(guān)于pytorch加載的cifar10數(shù)據(jù)集,到底有沒(méi)有經(jīng)過(guò)歸一化的文章就介紹到這了,更多相關(guān)pytorch加載cifar10數(shù)據(jù)集內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
- Pytorch搭建簡(jiǎn)單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類(lèi)任務(wù)
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處
- Pytorch深度學(xué)習(xí)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)resnet模塊訓(xùn)練
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)resent網(wǎng)絡(luò)實(shí)踐
- pytorch中的模型訓(xùn)練(以CIFAR10數(shù)據(jù)集為例)
- Pytorch使用卷積神經(jīng)網(wǎng)絡(luò)對(duì)CIFAR10圖片進(jìn)行分類(lèi)方式
相關(guān)文章
利用Hyperic調(diào)用Python實(shí)現(xiàn)進(jìn)程守護(hù)
這篇文章主要為大家詳細(xì)介紹了利用Hyperic調(diào)用Python實(shí)現(xiàn)進(jìn)程守護(hù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-01-01聊聊基于pytorch實(shí)現(xiàn)Resnet對(duì)本地?cái)?shù)據(jù)集的訓(xùn)練問(wèn)題
本文項(xiàng)目是使用Resnet模型來(lái)識(shí)別螞蟻和蜜蜂,其一共有三百九十六張的數(shù)據(jù),訓(xùn)練集只有兩百多張(數(shù)據(jù)集很?。?,運(yùn)行十輪后,分別對(duì)訓(xùn)練集和測(cè)試集在每一輪的準(zhǔn)確率,對(duì)pytorch實(shí)現(xiàn)Resnet本地?cái)?shù)據(jù)集的訓(xùn)練感興趣的朋友一起看看吧2022-03-03Python+Selenium實(shí)現(xiàn)在Geoserver批量發(fā)布Mongo矢量數(shù)據(jù)
這篇文章主要為大家詳細(xì)介紹了如何利用Python+Selenium實(shí)現(xiàn)在 Geoserver批量發(fā)布來(lái)自Mongo中的矢量數(shù)據(jù),文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2022-07-07python3自動(dòng)更新緩存類(lèi)的具體使用
本文介紹了使用一個(gè)自動(dòng)更新緩存的Python類(lèi)AutoUpdatingCache,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2025-01-01Python表格處理模塊xlrd在Anaconda中的安裝方法
本文介紹在Anaconda環(huán)境下,安裝Python讀取.xls格式表格文件的庫(kù)xlrd的方法,xlrd是一個(gè)用于讀取Excel文件的Python庫(kù),本文介紹了xlrd庫(kù)的一些主要特點(diǎn)和功能,感興趣的朋友一起看看吧2024-04-04linux環(huán)境下安裝pyramid和新建項(xiàng)目的步驟
這篇文章簡(jiǎn)單介紹了linux環(huán)境下安裝pyramid和新建項(xiàng)目的步驟,大家參考使用2013-11-11詳解Pytorch+PyG實(shí)現(xiàn)GCN過(guò)程示例
這篇文章主要為大家介紹了Pytorch+PyG實(shí)現(xiàn)GCN過(guò)程示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-04-04利用python實(shí)現(xiàn)漢字轉(zhuǎn)拼音的2種方法
這篇文章主要給大家介紹了關(guān)于如何利用python實(shí)現(xiàn)漢字轉(zhuǎn)拼音的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08基于PyQt5自制簡(jiǎn)單的文件內(nèi)容檢索小工具
這篇文章主要為大家詳細(xì)介紹了如何基于PyQt5自制一個(gè)簡(jiǎn)單的文件內(nèi)容檢索小工具,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-05-05