淺談keras2 predict和fit_generator的坑
1、使用predict時,必須設(shè)置batch_size,否則效率奇低。
查看keras文檔中,predict函數(shù)原型:
predict(self, x, batch_size=32, verbose=0)
說明:
只使用batch_size=32,也就是說每次將batch_size=32的數(shù)據(jù)通過PCI總線傳到GPU,然后進(jìn)行預(yù)測。在一些問題中,batch_size=32明顯是非常小的。而通過PCI傳數(shù)據(jù)是非常耗時的。
所以,使用的時候會發(fā)現(xiàn)預(yù)測數(shù)據(jù)時效率奇低,其原因就是batch_size太小了。
經(jīng)驗:
使用predict時,必須人為設(shè)置好batch_size,否則PCI總線之間的數(shù)據(jù)傳輸次數(shù)過多,性能會非常低下。
2、fit_generator
說明:keras 中 fit_generator參數(shù)steps_per_epoch已經(jīng)改變含義了,目前的含義是一個epoch分成多少個batch_size。舊版的含義是一個epoch的樣本數(shù)目。
如果說訓(xùn)練樣本樹N=1000,steps_per_epoch = 10,那么相當(dāng)于一個batch_size=100,如果還是按照舊版來設(shè)置,那么相當(dāng)于
batch_size = 1,會性能非常低。
經(jīng)驗:
必須明確fit_generator參數(shù)steps_per_epoch
補充知識:Keras:創(chuàng)建自己的generator(適用于model.fit_generator),解決內(nèi)存問題
為什么要使用model.fit_generator?
在現(xiàn)實的機器學(xué)習(xí)中,訓(xùn)練一個model往往需要數(shù)量巨大的數(shù)據(jù),如果使用fit進(jìn)行數(shù)據(jù)訓(xùn)練,很有可能導(dǎo)致內(nèi)存不夠,無法進(jìn)行訓(xùn)練。
fit_generator的定義如下:
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
其中各項的具體解釋,請參考Keras中文文檔
我們重點關(guān)注的是generator參數(shù):
generator: 一個生成器,或者一個 Sequence (keras.utils.Sequence) 對象的實例, 以在使用多進(jìn)程時避免數(shù)據(jù)的重復(fù)。 生成器的輸出應(yīng)該為以下之一:
一個 (inputs, targets) 元組
一個 (inputs, targets, sample_weights) 元組。
那么,問題來了,如何構(gòu)建這個generator呢?有以下幾種辦法:
自己創(chuàng)建一個generator生成器
自己定義一個 Sequence (keras.utils.Sequence) 對象
使用Keras自帶的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory來生成一個generator
1.自己創(chuàng)建一個generator生成器
使用Keras自帶的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 靈活度不高,只有當(dāng)數(shù)據(jù)集滿足一定格式(例如,按照分類文件夾存放)或者具備一定條件時,使用才使用才較為方便。
此時,自己創(chuàng)建一個generator就很重要了,關(guān)于python的generator是什么原理,怎么使用,就不加贅述,可以查看python的基本語法。
此處,我們用yield來返回數(shù)據(jù)組,標(biāo)簽組,從而使fit_generator可以調(diào)用我們的generator來成批處理數(shù)據(jù)。
具體實現(xiàn)如下:
def myGenerator(batch_size): # loading data X_train,Y_train=load_data(...) # data processing # ................ total_size=X_train.size #batch_size means how many data you want to train one step while 1: for i in range(total_size//batch_size): yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size] return myGenerator
接著你可以調(diào)用該生成器:
self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)
以上這篇淺談keras2 predict和fit_generator的坑就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python使用Selenium WebDriver的入門介紹及安裝教程(最新推薦)
這篇文章主要介紹了Python使用Selenium WebDriver的入門介紹及安裝教程,本文使用環(huán)境為python3.11+win10 64位+firefox瀏覽器,所以本文使用的瀏覽器驅(qū)動是Firefox的geckodriver ,如果你使用的是其他瀏覽器,那么選擇自己對應(yīng)的瀏覽器驅(qū)動程序即可,需要的朋友可以參考下2023-04-04關(guān)于Python?中IndexError:list?assignment?index?out?of?rang
這篇文章主要介紹了Python?中IndexError:list?assignment?index?out?of?range?錯誤解決,概述了兩個常見的列表函數(shù),它們可以幫助我們在替換兩個列表時幫助我們處理?Python?中的索引錯誤,需要的朋友可以參考下2023-05-05Python數(shù)據(jù)庫編程之pymysql詳解
本文主要介紹了Python數(shù)據(jù)庫編程中pymysql,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-05-05利用Python?Matplotlib繪圖并輸出圖像到文件中的方式
這篇文章主要介紹了利用Python?Matplotlib繪圖并輸出圖像到文件中的方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-09-09關(guān)于Python中flask-httpauth庫用法詳解
這篇文章主要介紹了關(guān)于Python中flask-httpauth庫用法詳解,Flask-HTTPAuth是一個?Flask?擴展,它簡化了?HTTP?身份驗證與?Flask?路由的使用,需要的朋友可以參考下2023-04-04