tensorflow2.0保存和恢復(fù)模型3種方法
方法1:只保存模型的權(quán)重和偏置
這種方法不會(huì)保存整個(gè)網(wǎng)絡(luò)的結(jié)構(gòu),只是保存模型的權(quán)重和偏置,所以在后期恢復(fù)模型之前,必須手動(dòng)創(chuàng)建和之前模型一模一樣的模型,以保證權(quán)重和偏置的維度和保存之前的相同。
tf.keras.model類中的save_weights方法和load_weights方法,參數(shù)解釋我就直接搬運(yùn)官網(wǎng)的內(nèi)容了。
save_weights( filepath, overwrite=True, save_format=None )
Arguments:
filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.
overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.
load_weights( filepath, by_name=False )
實(shí)例1:
import tensorflow as tf from tensorflow import keras from tensorflow.keras import datasets, layers, optimizers # step1 加載訓(xùn)練集和測(cè)試集合 mnist = tf.keras.datasets.mnist (x_train, y_train),(x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 # step2 創(chuàng)建模型 def create_model(): return tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(512, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ]) model = create_model() # step3 編譯模型 主要是確定優(yōu)化方法,損失函數(shù)等 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # step4 模型訓(xùn)練 訓(xùn)練一個(gè)epochs model.fit(x=x_train, y=y_train, epochs=1, ) # step5 模型測(cè)試 loss, acc = model.evaluate(x_test, y_test) print("train model, accuracy:{:5.2f}%".format(100 * acc)) # step6 保存模型的權(quán)重和偏置 model.save_weights('./save_weights/my_save_weights') # step7 刪除模型 del model # step8 重新創(chuàng)建模型 model = create_model() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # step9 恢復(fù)權(quán)重 model.load_weights('./save_weights/my_save_weights') # step10 測(cè)試模型 loss, acc = model.evaluate(x_test, y_test) print("Restored model, accuracy:{:5.2f}%".format(100 * acc))
train model, accuracy:96.55%
Restored model, accuracy:96.55%
可以看到在模型的權(quán)重和偏置恢復(fù)之后,在測(cè)試集合上同樣達(dá)到了訓(xùn)練之前相同的準(zhǔn)確率。
方法2:直接保存整個(gè)模型
這種方法會(huì)將網(wǎng)絡(luò)的結(jié)構(gòu),權(quán)重和優(yōu)化器的狀態(tài)等參數(shù)全部保存下來(lái),后期恢復(fù)的時(shí)候就沒必要?jiǎng)?chuàng)建新的網(wǎng)絡(luò)了。
tf.keras.model類中的save方法和load_model方法
save( filepath, overwrite=True, include_optimizer=True, save_format=None )
Arguments:
filepath: String, path to SavedModel or H5 file to save the model.
overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).
實(shí)例2:
import tensorflow as tf from tensorflow import keras from tensorflow.keras import datasets, layers, optimizers # step1 加載訓(xùn)練集和測(cè)試集合 mnist = tf.keras.datasets.mnist (x_train, y_train),(x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 # step2 創(chuàng)建模型 def create_model(): return tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(512, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ]) model = create_model() # step3 編譯模型 主要是確定優(yōu)化方法,損失函數(shù)等 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # step4 模型訓(xùn)練 訓(xùn)練一個(gè)epochs model.fit(x=x_train, y=y_train, epochs=1, ) # step5 模型測(cè)試 loss, acc = model.evaluate(x_test, y_test) print("train model, accuracy:{:5.2f}%".format(100 * acc)) # step6 保存模型的權(quán)重和偏置 model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' # step7 刪除模型 del model # deletes the existing model # step8 恢復(fù)模型 # returns a compiled model # identical to the previous one restored_model = tf.keras.models.load_model('my_model.h5') # step9 測(cè)試模型 loss, acc = restored_model.evaluate(x_test, y_test) print("Restored model, accuracy:{:5.2f}%".format(100 * acc))
train model, accuracy:96.94%
Restored model, accuracy:96.94%
方法3:使用tf.keras.callbacks.ModelCheckpoint方法在訓(xùn)練過程中保存模型
該方法繼承自tf.keras.callbacks類,一般配合mode.fit函數(shù)使用
以上這篇tensorflow2.0保存和恢復(fù)模型3種方法就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python導(dǎo)包的幾種方法(自定義包的生成以及導(dǎo)入詳解)
這篇文章主要介紹了python導(dǎo)包的幾種方法(自定義包的生成以及導(dǎo)入詳解),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-07-07Python3之簡(jiǎn)單搭建自帶服務(wù)器的實(shí)例講解
今天小編就為大家分享一篇Python3之簡(jiǎn)單搭建自帶服務(wù)器的實(shí)例講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2018-06-06Python采集貓眼兩萬(wàn)條數(shù)據(jù) 對(duì)《無(wú)名之輩》影評(píng)進(jìn)行分析
這篇文章主要給大家介紹了關(guān)于利用Python榮國(guó)采集兩萬(wàn)條貓眼數(shù)據(jù),對(duì)《無(wú)名之輩》影評(píng)進(jìn)行分析的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考借鑒,下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2018-12-12Python實(shí)現(xiàn)圖書管理系統(tǒng)設(shè)計(jì)
這篇文章主要為大家詳細(xì)介紹了Python實(shí)現(xiàn)圖書管理系統(tǒng)設(shè)計(jì),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-03-03Python操作csv文件之csv.writer()和csv.DictWriter()方法的基本使用
csv文件是一種逗號(hào)分隔的純文本形式存儲(chǔ)的表格數(shù)據(jù),Python內(nèi)置了CSV模塊,可直接通過該模塊實(shí)現(xiàn)csv文件的讀寫操作,下面這篇文章主要給大家介紹了關(guān)于Python操作csv文件之csv.writer()和csv.DictWriter()方法的基本使用,需要的朋友可以參考下2022-09-09python畫圖——實(shí)現(xiàn)在圖上標(biāo)注上具體數(shù)值的方法
今天小編就為大家分享一篇python畫圖——實(shí)現(xiàn)在圖上標(biāo)注上具體數(shù)值的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2019-07-07python爬取代理IP并進(jìn)行有效的IP測(cè)試實(shí)現(xiàn)
這篇文章主要介紹了python爬取代理IP并進(jìn)行有效的IP測(cè)試實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10淺談python 導(dǎo)入模塊和解決文件句柄找不到問題
今天小編就為大家分享一篇淺談python 導(dǎo)入模塊和解決文件句柄找不到問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧2018-12-12