keras導(dǎo)入weights方式
keras源碼engine中toplogy.py定義了加載權(quán)重的函數(shù):
load_weights(self, filepath, by_name=False)
其中默認(rèn)by_name為False,這時候加載權(quán)重按照網(wǎng)絡(luò)拓撲結(jié)構(gòu)加載,適合直接使用keras中自帶的網(wǎng)絡(luò)模型,如VGG16
VGG19/resnet50等,源碼描述如下:
If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.
若將by_name改為True則加載權(quán)重按照layer的name進行,layer的name相同時加載權(quán)重,適合用于改變了
模型的相關(guān)結(jié)構(gòu)或增加了節(jié)點但利用了原網(wǎng)絡(luò)的主體結(jié)構(gòu)情況下使用,源碼描述如下:
If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.
在進行邊緣檢測時,利用VGG網(wǎng)絡(luò)的主體結(jié)構(gòu),網(wǎng)絡(luò)中增加反卷積層,這時加載權(quán)重應(yīng)該使用
model.load_weights(filepath,by_name=True)
補充知識:Keras下實現(xiàn)mnist手寫數(shù)字
之前一直在用tensorflow,被同學(xué)推薦來用keras了,把之前文檔中的mnist手寫數(shù)字?jǐn)?shù)據(jù)集拿來練手,
代碼如下。
import struct
import numpy as np
import os
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
def load_mnist(path, kind):
labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
return images, labels
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
#define models
model = Sequential()
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax'))
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0]
print('Training accuracy: %.2f%%' % (train_acc * 100))
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0]
print('Test accuracy: %.2f%%' % (test_acc * 100))
訓(xùn)練結(jié)果如下:
Epoch 45/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323 Epoch 46/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358 Epoch 47/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347 Epoch 48/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350 Epoch 49/50 42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359 Epoch 50/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346 Training accuracy: 94.11% Test accuracy: 93.61%
以上這篇keras導(dǎo)入weights方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
在Python中調(diào)用Ping命令,批量IP的方法
今天小編就為大家分享一篇在Python中調(diào)用Ping命令,批量IP的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01
PyQt教程之自定義組件Switch?Button的實現(xiàn)
這篇文章主要為大家詳細介紹了PyQt中如何實現(xiàn)自定義組件Switch?Button,文中的示例代碼簡潔易懂,具有一定的學(xué)習(xí)價值,感興趣的可以了解一下2023-05-05
python dataframe astype 字段類型轉(zhuǎn)換方法
下面小編就為大家分享一篇python dataframe astype 字段類型轉(zhuǎn)換方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04
對Python subprocess.Popen子進程管道阻塞詳解
今天小編就為大家分享一篇對Python subprocess.Popen子進程管道阻塞詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10

