pytorch+sklearn實現(xiàn)數(shù)據(jù)加載的流程
之前在訓(xùn)練網(wǎng)絡(luò)的時候加載數(shù)據(jù)都是稀里糊涂的放進(jìn)去的,也沒有理清楚里面的流程,今天整理一下,加深理解,也方便以后查閱。
pytorch+sklearn實現(xiàn)數(shù)據(jù)加載
epoch & batch_size & iteration
epoch
:1個epoch等于使用訓(xùn)練集中的全部樣本訓(xùn)練一次,通俗的講epoch的值就是整個數(shù)據(jù)集被輪幾次。batch_size
:批大小。在深度學(xué)習(xí)中,一般采用SGD訓(xùn)練,即每次訓(xùn)練在訓(xùn)練集中取batchsize個樣本訓(xùn)練;iteration
:1個iteration等于使用batch_size個樣本訓(xùn)練一次;
優(yōu)化算法——梯度下降
深度學(xué)習(xí)的優(yōu)化算法,說白了就是梯度下降。每次的參數(shù)更新有兩種方式。
Batch gradient descent
第一種,遍歷全部數(shù)據(jù)集
算一次損失函數(shù),然后算函數(shù)對各個參數(shù)的梯度,更新梯度,這稱為批梯度下降(Batch gradient descent)
。
這樣做至少有 2 個好處:其一,由全數(shù)據(jù)集確定的方向能夠更好地代表樣本總體,從而更準(zhǔn)確地朝向極值所在的方向。其二,由于不同權(quán)重的梯度值差別巨大,因此選取一個全局的學(xué)習(xí)率很困難。 Full Batch Learning 可以使用 Rprop 只基于梯度符號并且針對性單獨更新各權(quán)值。
對于更大的數(shù)據(jù)集,以上 2 個好處又變成了 2 個壞處:其一,隨著數(shù)據(jù)集的海量增長和內(nèi)存限制,一次性載入所有的數(shù)據(jù)進(jìn)來變得越來越不可行。其二,以 Rprop 的方式迭代,會由于各個 Batch 之間的采樣差異性,各次梯度修正值相互抵消,無法修正。這才有了后來 RMSProp 的妥協(xié)方案。
Stochastic gradient descent
另一種,每看一個數(shù)據(jù)就算一下?lián)p失函數(shù)
,然后求梯度更新參數(shù),這個稱為隨機梯度下降(Stochastic gradient descent)
。這個方法速度比較快,但是收斂性能不太好,可能在最優(yōu)點附近晃來晃去,達(dá)不到最優(yōu)點。兩次參數(shù)的更新也有可能互相抵消掉,造成目標(biāo)函數(shù)震蕩的比較劇烈。
Mini-batch gradient decent
為了克服兩種方法的缺點,現(xiàn)在一般采用的是一種折中手段,mini-batch gradient decent
,小批的梯度下降,這種方法把數(shù)據(jù)分為若干個批,按批來更新參數(shù),這樣,一個批中的一組數(shù)據(jù)共同決定了本次梯度的方向,下降起來就不容易跑偏,減少了隨機性。另一方面因為批的樣本數(shù)與整個數(shù)據(jù)集相比小了很多,計算量也不是很大。
現(xiàn)在用的優(yōu)化器SGD是stochastic gradient descent的縮寫,但不代表是一個樣本就更新一回,還是基于mini-batch的。
- 批量梯度下降:批量大小=訓(xùn)練集的大小
- 隨機梯度下降:批量大小= 1
- 小批量梯度下降:1 <批量大小<訓(xùn)練集的大小
在小批量梯度下降的情況下,流行的批量大小包括32,64和128個樣本。
再談Batch_Size
在合理范圍內(nèi),增大 Batch_Size 有何好處?
- 內(nèi)存利用率提高了,大矩陣乘法的并行化效率提高。
- 跑完一次 epoch(全數(shù)據(jù)集)所需的迭代次數(shù)減少,對于相同數(shù)據(jù)量的處理速度進(jìn)一步加快。
- 在一定范圍內(nèi),一般來說 Batch_Size 越大,其確定的下降方向越準(zhǔn),引起訓(xùn)練震蕩越小。
盲目增大 Batch_Size 有何壞處?
- 內(nèi)存利用率提高了,但是內(nèi)存容量可能撐不住了。
- 跑完一次 epoch(全數(shù)據(jù)集)所需的迭代次數(shù)減少,要想達(dá)到相同的精度,其所花費的時間大大增加了,從而對參數(shù)的修正也就顯得更加緩慢。
- Batch_Size 增大到一定程度,其確定的下降方向已經(jīng)基本不再變化。
深度學(xué)習(xí)的第一項任務(wù)——數(shù)據(jù)加載
數(shù)據(jù)加載流程——重要
以BCICIV_2a數(shù)據(jù)為例
import mne import numpy as np import torch import torch.nn as nn
class LoadData: def __init__(self,eeg_file_path: str): self.eeg_file_path = eeg_file_path def load_raw_data_gdf(self,file_to_load): self.raw_eeg_subject = mne.io.read_raw_gdf(self.eeg_file_path + '/' + file_to_load) return self def load_raw_data_mat(self,file_to_load): import scipy.io as sio self.raw_eeg_subject = sio.loadmat(self.eeg_file_path + '/' + file_to_load) def get_all_files(self,file_path_extension: str =''): if file_path_extension: return glob.glob(self.eeg_file_path+'/'+file_path_extension) return os.listdir(self.eeg_file_path)
class LoadBCIC(LoadData): '''Subclass of LoadData for loading BCI Competition IV Dataset 2a''' def __init__(self, file_to_load, *args): self.stimcodes=('769','770','771','772') # self.epoched_data={} self.file_to_load = file_to_load self.channels_to_remove = ['EOG-left', 'EOG-central', 'EOG-right'] super(LoadBCIC,self).__init__(*args) def get_epochs(self, tmin=0,tmax=1,baseline=None): self.load_raw_data_gdf(self.file_to_load) raw_data = self.raw_eeg_subject # raw_downsampled = raw_data.copy().resample(sfreq=128) self.fs = raw_data.info.get('sfreq') events, event_ids = mne.events_from_annotations(raw_data) stims =[value for key, value in event_ids.items() if key in self.stimcodes] epochs = mne.Epochs(raw_data, events, event_id=stims, tmin=tmin, tmax=tmax, event_repeated='drop', baseline=baseline, preload=True, proj=False, reject_by_annotation=False) epochs = epochs.drop_channels(self.channels_to_remove) self.y_labels = epochs.events[:, -1] - min(epochs.events[:, -1]) self.x_data = epochs.get_data()*1e6 eeg_data={'x_data':self.x_data, 'y_labels':self.y_labels, 'fs':self.fs} return eeg_data
data_path = "/home/pytorch/LiangXiaohan/MI_Dataverse/BCICIV_2a_gdf" file_to_load = 'A01T.gdf'
'''for BCIC Dataset''' bcic_data = LoadBCIC(file_to_load, data_path) eeg_data = bcic_data.get_epochs() # {'x_data':, 'y_labels':, 'fs':}
X = eeg_data.get('x_data') Y = eeg_data.get('y_labels') Y.shape
from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0) X_train.shape
from sklearn.model_selection import StratifiedKFold train_idx = {} eval_idx = {}
skf = StratifiedKFold(n_splits=4, shuffle=True) i = 0 for train_indices, eval_indices in skf.split(X_train, y_train): train_idx.update({i: train_indices}) eval_idx.update({i: eval_indices}) i += 1 train_idx.get(1).shape
def split_xdata(eeg_data, train_idx, eval_idx): x_train=np.copy(eeg_data[train_idx,:,:]) x_eval=np.copy(eeg_data[eval_idx,:,:]) x_train = torch.from_numpy(x_train).to(torch.float32) x_eval = torch.from_numpy(x_eval).to(torch.float32) return x_train, x_eval
def split_ydata(y_true, train_idx, eval_idx): y_train = np.copy(y_true[train_idx]) y_eval = np.copy(y_true[eval_idx]) y_train = torch.from_numpy(y_train) y_eval = torch.from_numpy(y_eval) return y_train, y_eval
x_train, x_eval = split_xdata(X_train, train_idx.get(1), eval_idx.get(1)) y_train, y_eval = split_ydata(Y_train, train_idx.get(1), eval_idx.get(1)) y_train.shape
from torch.utils.data import Dataset, DataLoader, TensorDataset from tqdm import tqdm def BCICDataLoader(x_train, y_train, batch_size=64, num_workers=2, shuffle=True): data = TensorDataset(x_train, y_train) train_data = DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) return train_data
train_data = BCICDataLoader(x_train, y_train, batch_size=32) for inputs, target in tqdm(train_data): print(target)
到此數(shù)據(jù)就讀出來了?。?!
相關(guān)API解釋
sklearn.model_selection.train_test_split
sklearn.model_selection.StratifiedKFold
torch.utils.data.TensorDataset
https://pytorch.org/docs/stable/data.html?highlight=tensordataset#torch.utils.data.TensorDataset
torch.utils.data.DataLoader
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
參考資料
深度學(xué)習(xí)中的batch、epoch、iteration的含義
神經(jīng)網(wǎng)絡(luò)中Batch和Epoch之間的區(qū)別是什么?
談?wù)勆疃葘W(xué)習(xí)中的 Batch_Size
到此這篇關(guān)于pytorch+sklearn實現(xiàn)數(shù)據(jù)加載的文章就介紹到這了,更多相關(guān)pytorch數(shù)據(jù)加載內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python中使用logging模塊代替print(logging簡明指南)
這篇文章主要介紹了Python中使用logging模塊代替print的好處說明,主旨是logging模塊簡明指南,logging模塊的使用方法介紹,需要的朋友可以參考下2014-07-07Python實現(xiàn)csv文件(點表和線表)轉(zhuǎn)換為shapefile文件的方法
這篇文章主要介紹了Python實現(xiàn)csv文件(點表和線表)轉(zhuǎn)換為shapefile文件的方法,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-10-10使用pickle存儲數(shù)據(jù)dump 和 load實例講解
今天小編就為大家分享一篇使用pickle存儲數(shù)據(jù)dump 和 load實例講解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12django 多數(shù)據(jù)庫及分庫實現(xiàn)方式
這篇文章主要介紹了django 多數(shù)據(jù)庫及分庫實現(xiàn)方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04