欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

淺談keras保存模型中的save()和save_weights()區(qū)別

 更新時(shí)間:2020年05月21日 14:05:09   作者:木盞  
這篇文章主要介紹了淺談keras保存模型中的save()和save_weights()區(qū)別,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧

今天做了一個(gè)關(guān)于keras保存模型的實(shí)驗(yàn),希望有助于大家了解keras保存模型的區(qū)別。

我們知道keras的模型一般保存為后綴名為h5的文件,比如final_model.h5。同樣是h5文件用save()和save_weight()保存效果是不一樣的。

我們用宇宙最通用的數(shù)據(jù)集MNIST來做這個(gè)實(shí)驗(yàn),首先設(shè)計(jì)一個(gè)兩層全連接網(wǎng)絡(luò):

inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)

然后,導(dǎo)入MNIST數(shù)據(jù)訓(xùn)練,分別用兩種方式保存模型,在這里我還把未訓(xùn)練的模型也保存下來,如下:

from keras.models import Model
from keras.layers import Input, Dense
from keras.datasets import mnist
from keras.utils import np_utils
 
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train=x_train.reshape(x_train.shape[0],-1)/255.0
x_test=x_test.reshape(x_test.shape[0],-1)/255.0
y_train=np_utils.to_categorical(y_train,num_classes=10)
y_test=np_utils.to_categorical(y_test,num_classes=10)
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
 
model.save('m1.h5')
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=10)
#loss,accuracy=model.evaluate(x_test,y_test)
 
model.save('m2.h5')
model.save_weights('m3.h5')

如上可見,我一共保存了m1.h5, m2.h5, m3.h5 這三個(gè)h5文件。那么,我們來看看這三個(gè)玩意兒有什么區(qū)別。首先,看看大?。?/p>

m2表示save()保存的模型結(jié)果,它既保持了模型的圖結(jié)構(gòu),又保存了模型的參數(shù)。所以它的size最大的。

m1表示save()保存的訓(xùn)練前的模型結(jié)果,它保存了模型的圖結(jié)構(gòu),但應(yīng)該沒有保存模型的初始化參數(shù),所以它的size要比m2小很多。

m3表示save_weights()保存的模型結(jié)果,它只保存了模型的參數(shù),但并沒有保存模型的圖結(jié)構(gòu)。所以它的size也要比m2小很多。

通過可視化工具,我們發(fā)現(xiàn):(打開m1和m2均可以顯示出以下結(jié)構(gòu))

而打開m3的時(shí)候,可視化工具報(bào)錯(cuò)了。由此可以論證, save_weights()是不含有模型結(jié)構(gòu)信息的。

加載模型

兩種不同方法保存的模型文件也需要用不同的加載方法。

from keras.models import load_model
 
model = load_model('m1.h5')
#model = load_model('m2.h5')
#model = load_model('m3.h5')
model.summary()

只有加載m3.h5的時(shí)候,這段代碼才會報(bào)錯(cuò)。其他輸出如下:

可見,由save()保存下來的h5文件才可以直接通過load_model()打開!

那么,我們保存下來的參數(shù)(m3.h5)該怎么打開呢?

這就稍微復(fù)雜一點(diǎn)了,因?yàn)閙3不含有模型結(jié)構(gòu)信息,所以我們需要把模型結(jié)構(gòu)再描述一遍才可以加載m3,如下:

from keras.models import Model
from keras.layers import Input, Dense
 
inputs = Input(shape=(784, ))
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
 
model = Model(inputs=inputs, outputs=y)
model.load_weights('m3.h5')

以上把m3換成m1和m2也是沒有問題的!可見,save()保存的模型除了占用內(nèi)存大一點(diǎn)以外,其他的優(yōu)點(diǎn)太明顯了。所以,在不怎么缺硬盤空間的情況下,還是建議大家多用save()來存。

注意!如果要load_weights(),必須保證你描述的有參數(shù)計(jì)算結(jié)構(gòu)與h5文件中完全一致!什么叫有參數(shù)計(jì)算結(jié)構(gòu)呢?就是有參數(shù)坑,直接填進(jìn)去就行了。我們把上面的非參數(shù)結(jié)構(gòu)換了一下,發(fā)現(xiàn)h5文件依然可以加載成功,比如將softmax換成relu,依然不影響加載。

對于keras的save()和save_weights(),完全沒問題了吧

以上這篇淺談keras保存模型中的save()和save_weights()區(qū)別就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評論