keras讀取訓(xùn)練好的模型參數(shù)并把參數(shù)賦值給其它模型詳解
介紹
本博文中的代碼,實現(xiàn)的是加載訓(xùn)練好的模型model_halcon_resenet.h5,并把該模型的參數(shù)賦值給兩個不同的新的model。
函數(shù)式模型
官網(wǎng)上給出的調(diào)用一個訓(xùn)練好模型,并輸出任意層的feature。
model = Model(inputs=base_model.input, outputs=base_model.get_layer(‘block4_pool').output)
但是這有一個問題,就是新的model,如果輸入inputs和訓(xùn)練好的model的inputs大小不同呢?比如我想建立一個輸入是600x600x3的新model,但是訓(xùn)練好的model輸入是200x200x3,而這時我又想調(diào)用訓(xùn)練好模型的卷積核參數(shù),這時該怎么辦呢?
其實想一下,用訓(xùn)練好的模型參數(shù),即使輸入的尺寸不同,但是這些模型參數(shù)仍然可以處理計算,只是輸出的feature map大小不同。那到底怎么賦值呢?其實很簡單
在定義新的model時,新的model層在定義時,需要加上名字,而這個名字就是訓(xùn)練好的模型的每層名字。如下代碼所示:
inputs=Input(shape=(400,500,3)) X=Conv2D(32, (3, 3),name=“conv2d_1”)(inputs) X=BatchNormalization(name=“batch_normalization_1”)(X) X=Activation(‘relu',name=“activation_1”)(X)
最后通過以下代碼即可建立一個新的模型并擁有訓(xùn)練好模型的參數(shù):
model=Model(inputs=inputs, outputs=X)
model.load_weights(‘model_halcon_resenet.h5', by_name=True)
源代碼
from keras.models import load_model from keras.preprocessing import image from keras.applications.vgg19 import preprocess_input from keras.models import Model import numpy as np from keras.layers import Conv2D, MaxPooling2D,merge from keras.layers import BatchNormalization,Activation from keras.layers import Input, Dense from PIL import Image import numpy as np import keras from keras.models import Sequential from keras.layers import Dense, Dropout, Flatten,Input from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D from keras.layers import BatchNormalization,Activation from sklearn.model_selection import train_test_split from keras.applications.densenet import DenseNet169, DenseNet121 from keras.applications.inception_resnet_v2 import InceptionResNetV2 from keras.applications.inception_v3 import InceptionV3 from keras.optimizers import SGD from keras import regularizers from keras.models import Model import tensorflow as tf from PIL import Image from keras.callbacks import TensorBoard import os import cv2 from keras import backend as K from model import focal_loss import keras.losses #ReadMe 該代碼是參考fast rcnn系列,先對整幅圖像提取特征feature map,然后從原圖對應(yīng)位置上映射到feature map,并對feature map進行 # 切片,從而提取對應(yīng)某個位置上的特征,并把該特征送進后面的識別網(wǎng)絡(luò)進行分類識別。 keras.losses.focal_loss = focal_loss#這句代碼是為了引入定義的loss base_model=load_model('model_halcon_resenet.h5') base_model.summary() inputs=Input(shape=(400,500,3)) X=Conv2D(32, (3, 3),name="conv2d_1")(inputs) X=BatchNormalization(name="batch_normalization_1")(X) X=Activation('relu',name="activation_1")(X) #第一個殘差模塊 X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X) X_1=BatchNormalization(name="batch_normalization_2")(X_1) X_1= Activation('relu',name="activation_2")(X_1) X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1) X_1 = BatchNormalization(name="batch_normalization_3")(X_1) merge_data = merge([X_1, X], mode='sum',name="merge_1") X = Activation('relu',name="activation_3")(merge_data) #第一個殘差模塊結(jié)束 X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X) X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X) X=BatchNormalization(name="batch_normalization_4")(X) X=Activation('relu',name="activation_4")(X) #第二個殘差模塊 X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X) X_2=BatchNormalization(name="batch_normalization_5")(X_2) X_2= Activation('relu',name="activation_5")(X_2) X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2) X_2 = BatchNormalization(name="batch_normalization_6")(X_2) merge_data = merge([X_2, X], mode='sum',name="merge_2") X = Activation('relu',name="activation_6")(merge_data) #第二個殘差模塊結(jié)束 X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X) X=Conv2D(64, (3, 3),name="conv2d_7")(X) X=BatchNormalization(name="batch_normalization_7")(X) X=Activation('relu',name="activation_7")(X) X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X) #第三個殘差模塊開始 X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X) X_3=BatchNormalization(name="batch_normalization_8")(X_3) X_3= Activation('relu',name="activation_8")(X_3) X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3) X_3 = BatchNormalization(name="batch_normalization_9")(X_3) merge_data = merge([X_3, X], mode='sum',name="merge_3") X = Activation('relu',name="activation_9")(merge_data) #第三個殘差模塊結(jié)束 X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X) X=BatchNormalization(name="batch_normalization_10")(X) X=Activation('relu',name="activation_10")(X) #第四個殘差模塊開始 X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X) X_4=BatchNormalization(name="batch_normalization_11")(X_4) X_4= Activation('relu',name="activation_11")(X_4) X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4) X_4 = BatchNormalization(name="batch_normalization_12")(X_4) merge_data = merge([X_4, X], mode='sum',name="merge_4") X = Activation('relu',name="activation_12")(merge_data) #第四個殘差模塊結(jié)束 X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X) X = Conv2D(64, (3, 3),name="conv2d_13")(X) X = BatchNormalization(name="batch_normalization_13")(X) X = Activation('relu',name="activation_13")(X) #第五個殘差模塊開始 X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X) X_5=BatchNormalization(name="batch_normalization_14")(X_5) X_5= Activation('relu',name="activation_14")(X_5) X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5) X_5 = BatchNormalization(name="batch_normalization_15")(X_5) merge_data = merge([X_5, X], mode='sum',name="merge_5") X = Activation('relu',name="activation_15")(merge_data) #第五個殘差模塊結(jié)束 model=Model(inputs=inputs, outputs=X) model.load_weights('model_halcon_resenet.h5', by_name=True) #讀取指定圖像數(shù)據(jù) image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png' img = image.load_img(image_dir, target_size=(400, 500)) x = image.img_to_array(img) x = np.expand_dims(x, axis=0) x = preprocess_input(x) #利用第一個模型預(yù)測出特征數(shù)據(jù),并對特征數(shù)據(jù)進行切片 feature_map=model.predict(x) T=np.array(feature_map) f_1=T[:,16:21,0:10,:] print(f_1.shape) print(feature_map.shape) #第一個模型沒有問題 #定義第二個模型 inputs_sec=Input(shape=(1,5,10,64)) X_= Flatten(name="flatten_1")(inputs_sec) X_ = Dense(256, activation='relu',name="dense_1")(X_) X_ = Dropout(0.5,name="dropout_1")(X_) predictions = Dense(6, activation='softmax',name="dense_2")(X_) model_sec=Model(inputs=inputs_sec, outputs=predictions) model_sec.load_weights('model_halcon_resenet.h5', by_name=True) #第二個模型定義結(jié)束 model_sec.summary() #開始對整幅圖像進行切片,并記錄坐標位置 pic=cv2.imread(image_dir) cor_list=[] name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue'] font = cv2.FONT_HERSHEY_SIMPLEX for i in range(3): for j in range(5): if(i==2): cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :] data = np.expand_dims(cut_feature, axis=0) result = model_sec.predict(data) print(result) result_data=result[0].tolist() #如果置信度過低,則舍棄 # if(max(result_data)<=0.7): # continue index_num = result_data.index(max(result_data)) name=name_list[index_num] cor_list = [i * 160 + 6, j * 80] # 每個切片數(shù)據(jù),映射到原圖上,檢測框?qū)?yīng)的左上角坐標 x=cor_list[0] y=cor_list[1] cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j+ 1)), (0, 255, 0), 2) cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1) else: cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :] data = np.expand_dims(cut_feature, axis=0) result = model_sec.predict(data) print(result) result_data = result[0].tolist() #如果置信度過低,則舍棄 # if (max(result_data) <= 0.7): # continue index_num = result_data.index(max(result_data)) name = name_list[index_num] cor_list = [i * 160 + 6, j * 80] # 每個切片數(shù)據(jù),映射到原圖上,檢測框?qū)?yīng)的左上角坐標 x = cor_list[0] y = cor_list[1] cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2) cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1) cv2.imshow('pic',pic) cv2.waitKey(0) cv2.destroyAllWindows() # data= np.expand_dims(f_1, axis=0) # result=model_sec.predict(data) # print(result) #第二個模型可以完全預(yù)測,沒有問題
補充知識:加載訓(xùn)練好的模型參數(shù),但是權(quán)重一直變化
變量初始化會導(dǎo)致權(quán)重發(fā)生變化,去掉就好了。
以上這篇keras讀取訓(xùn)練好的模型參數(shù)并把參數(shù)賦值給其它模型詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
基于Python實現(xiàn)層次性數(shù)據(jù)和閉包性質(zhì)
這篇文章主要介紹了如何利用Python實現(xiàn)層次性數(shù)據(jù)和閉包性質(zhì),文中的示例代碼講解詳細,對我們學(xué)習(xí)Python有一定幫助,需要的可以了解一下2022-05-05Python list列表中刪除多個重復(fù)元素操作示例
這篇文章主要介紹了Python list列表中刪除多個重復(fù)元素操作,結(jié)合實例形式分析了Python刪除list列表重復(fù)元素的相關(guān)操作技巧與注意事項,需要的朋友可以參考下2019-02-02基于django micro搭建網(wǎng)站實現(xiàn)加水印功能
這篇文章主要介紹了基于django micro搭建網(wǎng)站實現(xiàn)加水印功能,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-05-05