keras使用Sequence類調(diào)用大規(guī)模數(shù)據(jù)集進行訓(xùn)練的實現(xiàn)
使用Keras如果要使用大規(guī)模數(shù)據(jù)集對網(wǎng)絡(luò)進行訓(xùn)練,就沒辦法先加載進內(nèi)存再從內(nèi)存直接傳到顯存了,除了使用Sequence類以外,還可以使用迭代器去生成數(shù)據(jù),但迭代器無法在fit_generation里開啟多進程,會影響數(shù)據(jù)的讀取和預(yù)處理效率,在本文中就不在敘述了,有需要的可以另外去百度。
下面是我所使用的代碼
class SequenceData(Sequence):
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
#返回長度,通過len(<你的實例>)調(diào)用
def __len__(self):
return self.L - self.batch_size
#即通過索引獲取a[0],a[1]這種
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})
def data_generation(self, batch_datas):
#預(yù)處理操作
return img1s,img2s,audios,labels
然后在代碼里通過fit_generation函數(shù)調(diào)用并訓(xùn)練
這里要注意,use_multiprocessing參數(shù)是是否開啟多進程,由于python的多線程不是真的多線程,所以多進程還是會獲得比較客觀的加速,但不支持windows,windows下python無法使用多進程。
D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)),
epochs=2, workers=20, #callbacks=[checkpoint],
use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))
同樣的,也可以在測試的時候使用
model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)
補充知識:keras數(shù)據(jù)自動生成器,繼承keras.utils.Sequence,結(jié)合fit_generator實現(xiàn)節(jié)約內(nèi)存訓(xùn)練
我就廢話不多說了,大家還是直接看代碼吧~
#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
class DataGenerator(keras.utils.Sequence):
def __init__(self, datas, batch_size=1, shuffle=True):
self.batch_size = batch_size
self.datas = datas
self.indexes = np.arange(len(self.datas))
self.shuffle = shuffle
def __len__(self):
#計算每一個epoch的迭代次數(shù)
return math.ceil(len(self.datas) / float(self.batch_size))
def __getitem__(self, index):
#生成每個batch數(shù)據(jù),這里就根據(jù)自己對數(shù)據(jù)的讀取方式進行發(fā)揮了
# 生成batch_size個索引
batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# 根據(jù)索引獲取datas集合中的數(shù)據(jù)
batch_datas = [self.datas[k] for k in batch_indexs]
# 生成數(shù)據(jù)
X, y = self.data_generation(batch_datas)
return X, y
def on_epoch_end(self):
#在每一次epoch結(jié)束是否需要進行一次隨機,重新隨機一下index
if self.shuffle == True:
np.random.shuffle(self.indexes)
def data_generation(self, batch_datas):
images = []
labels = []
# 生成數(shù)據(jù)
for i, data in enumerate(batch_datas):
#x_train數(shù)據(jù)
image = cv2.imread(data)
image = list(image)
images.append(image)
#y_train數(shù)據(jù)
right = data.rfind("\\",0)
left = data.rfind("\\",0,right)+1
class_name = data[left:right]
if class_name=="dog":
labels.append([0,1])
else:
labels.append([1,0])
#如果為多輸出模型,Y的格式要變一下,外層list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
return np.array(images), np.array(labels)
# 讀取樣本名稱,然后根據(jù)樣本名稱去讀取數(shù)據(jù)
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
file_path = os.path.join("D:/xxx", file)
if os.path.isdir(file_path):
class_num = class_num + 1
for sub_file in os.listdir(file_path):
train_datas.append(os.path.join(file_path, sub_file))
# 數(shù)據(jù)生成器
training_generator = DataGenerator(train_datas)
#構(gòu)建網(wǎng)絡(luò)
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)
以上這篇keras使用Sequence類調(diào)用大規(guī)模數(shù)據(jù)集進行訓(xùn)練的實現(xiàn)就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python中ASCII碼字符與int之間的轉(zhuǎn)換方法
今天小編就為大家分享一篇python中ASCII碼字符與int之間的轉(zhuǎn)換方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07
Python高級數(shù)據(jù)分析之pandas和matplotlib繪圖
Matplotlib是一個強大的Python繪圖和數(shù)據(jù)可視化的工具包,下面這篇文章主要給大家介紹了關(guān)于Python高級數(shù)據(jù)分析之pandas和matplotlib繪圖的相關(guān)資料,文中通過示例代碼介紹的非常詳細,需要的朋友可以參考下2022-05-05
用python結(jié)合jieba和wordcloud實現(xiàn)詞云效果
詞云,顧名思義就是很多個單詞,然后通過出現(xiàn)的頻率或者比重之類的標準匯聚成一個云朵的樣子嘛,其實呢現(xiàn)在網(wǎng)上已經(jīng)有很多能自動生成詞云的工具了,比如Wordle,Tagxedo等等,Python也能實現(xiàn)這樣的效果,我們通過jieba庫和wordcloud庫也能十分輕松的完成詞云的構(gòu)建2017-09-09
Python應(yīng)用開發(fā)之實現(xiàn)串口通信
在嵌入式開發(fā)中我們經(jīng)常會用到串口,串口通信簡單,使用起來方便,且適用場景多。本文為大家準備了Python實現(xiàn)串口通信的示例代碼,需要的可以參考一下2022-11-11

