Keras: reading a trained model's parameters and assigning them to other models, explained
Introduction
The code in this post loads the trained model model_halcon_resenet.h5 and assigns its parameters to two different new models.
Functional model
The official documentation shows how to call a trained model and output the features of an arbitrary layer:
model = Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output)
But this raises a question: what if the new model's input size differs from the trained model's? For example, suppose I want to build a new model with a 600x600x3 input while the trained model takes 200x200x3, and I still want to reuse the trained model's convolution kernels. What then?
If you think about it, the trained parameters can still be applied even when the input size differs; only the size of the output feature maps changes. So how exactly do we assign the parameters? It is actually very simple.
When defining the new model, give each layer a name, and make that name identical to the corresponding layer's name in the trained model, as the following code shows:
inputs = Input(shape=(400, 500, 3))
X = Conv2D(32, (3, 3), name="conv2d_1")(inputs)
X = BatchNormalization(name="batch_normalization_1")(X)
X = Activation('relu', name="activation_1")(X)
Finally, the following code builds the new model and gives it the trained model's parameters:
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
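To verify that the weights really were transferred, you can compare a layer with a matching name in both models. A minimal sanity check (it assumes base_model has been loaded from the same file via load_model, as in the full source below):

w_old = base_model.get_layer('conv2d_1').get_weights()[0]
w_new = model.get_layer('conv2d_1').get_weights()[0]
print(np.allclose(w_old, w_new))  # True if the kernels were copied over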
Source code
from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
import tensorflow as tf
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses
#ReadMe: following the Fast R-CNN family of methods, this code first extracts a feature map from the whole image, then maps each position in the original image onto the feature map,
# slices the feature map to extract the features at that position, and feeds those features into a downstream recognition network for classification.
keras.losses.focal_loss = focal_loss  # register the custom loss so that load_model can resolve it
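# Alternatively (standard Keras API), the custom loss can be supplied at load time
# instead of patching keras.losses:
# base_model = load_model('model_halcon_resenet.h5', custom_objects={'focal_loss': focal_loss})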
base_model=load_model('model_halcon_resenet.h5')
base_model.summary()
inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)
# First residual block
X_1=Conv2D(32, (3, 3),padding='same',name="conv2d_2")(X)
X_1=BatchNormalization(name="batch_normalization_2")(X_1)
X_1= Activation('relu',name="activation_2")(X_1)
X_1 = Conv2D(32, (3, 3),padding='same',name="conv2d_3")(X_1)
X_1 = BatchNormalization(name="batch_normalization_3")(X_1)
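# Note: merge([...], mode='sum') is the Keras 1.x API; in Keras 2 the equivalent
# element-wise residual addition would be keras.layers.add([X_1, X]).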
merge_data = merge([X_1, X], mode='sum',name="merge_1")
X = Activation('relu',name="activation_3")(merge_data)
# End of the first residual block
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_1")(X)
X=Conv2D(64, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_4")(X)
X=BatchNormalization(name="batch_normalization_4")(X)
X=Activation('relu',name="activation_4")(X)
# Second residual block
X_2=Conv2D(64, (3, 3),padding='same',name="conv2d_5")(X)
X_2=BatchNormalization(name="batch_normalization_5")(X_2)
X_2= Activation('relu',name="activation_5")(X_2)
X_2 = Conv2D(64, (3, 3),padding='same',name="conv2d_6")(X_2)
X_2 = BatchNormalization(name="batch_normalization_6")(X_2)
merge_data = merge([X_2, X], mode='sum',name="merge_2")
X = Activation('relu',name="activation_6")(merge_data)
# End of the second residual block
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_2")(X)
X=Conv2D(64, (3, 3),name="conv2d_7")(X)
X=BatchNormalization(name="batch_normalization_7")(X)
X=Activation('relu',name="activation_7")(X)
X=MaxPooling2D(pool_size=(2, 2),strides=(2,2),name="max_pooling2d_3")(X)
# Third residual block
X_3=Conv2D(64, (3, 3),padding='same',name="conv2d_8")(X)
X_3=BatchNormalization(name="batch_normalization_8")(X_3)
X_3= Activation('relu',name="activation_8")(X_3)
X_3 = Conv2D(64, (3, 3),padding='same',name="conv2d_9")(X_3)
X_3 = BatchNormalization(name="batch_normalization_9")(X_3)
merge_data = merge([X_3, X], mode='sum',name="merge_3")
X = Activation('relu',name="activation_9")(merge_data)
# End of the third residual block
X=Conv2D(32, (3, 3),kernel_regularizer=regularizers.l2(0.01),name="conv2d_10")(X)
X=BatchNormalization(name="batch_normalization_10")(X)
X=Activation('relu',name="activation_10")(X)
# Fourth residual block
X_4=Conv2D(32, (3, 3),padding='same',name="conv2d_11")(X)
X_4=BatchNormalization(name="batch_normalization_11")(X_4)
X_4= Activation('relu',name="activation_11")(X_4)
X_4 = Conv2D(32, (3, 3),padding='same',name="conv2d_12")(X_4)
X_4 = BatchNormalization(name="batch_normalization_12")(X_4)
merge_data = merge([X_4, X], mode='sum',name="merge_4")
X = Activation('relu',name="activation_12")(merge_data)
# End of the fourth residual block
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2),name="max_pooling2d_4")(X)
X = Conv2D(64, (3, 3),name="conv2d_13")(X)
X = BatchNormalization(name="batch_normalization_13")(X)
X = Activation('relu',name="activation_13")(X)
# Fifth residual block
X_5=Conv2D(64, (3, 3),padding='same',name="conv2d_14")(X)
X_5=BatchNormalization(name="batch_normalization_14")(X_5)
X_5= Activation('relu',name="activation_14")(X_5)
X_5 = Conv2D(64, (3, 3),padding='same',name="conv2d_15")(X_5)
X_5 = BatchNormalization(name="batch_normalization_15")(X_5)
merge_data = merge([X_5, X], mode='sum',name="merge_5")
X = Activation('relu',name="activation_15")(merge_data)
# End of the fifth residual block
model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)
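# by_name=True matches layers by name: only layers whose names exist in
# model_halcon_resenet.h5 receive weights; unmatched layers keep their
# random initialization.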
# Load the specified image
image_dir='C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
# Predict the feature map with the first model, then slice the features
feature_map=model.predict(x)
T=np.array(feature_map)
f_1=T[:,16:21,0:10,:]
print(f_1.shape)
print(feature_map.shape)
# The first model works as expected
# Define the second model
inputs_sec=Input(shape=(1,5,10,64))
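# The leading 1 mirrors np.expand_dims in the loop below: each sliced feature map
# of shape (1, 5, 10, 64) is wrapped into a batch of one sample.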
X_= Flatten(name="flatten_1")(inputs_sec)
X_ = Dense(256, activation='relu',name="dense_1")(X_)
X_ = Dropout(0.5,name="dropout_1")(X_)
predictions = Dense(6, activation='softmax',name="dense_2")(X_)
model_sec=Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
# End of the second model's definition
model_sec.summary()
# Slice the whole image and record the coordinates of each slice
pic=cv2.imread(image_dir)
cor_list=[]
name_list=['blank','green_blank','red_blank','yellow','yellow_balnk','yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):
    for j in range(5):
        if (i == 2):
            cut_feature = T[:, 4 * j:4 * j + 5, 17:27, :]
            data = np.expand_dims(cut_feature, axis=0)
            result = model_sec.predict(data)
            print(result)
            result_data = result[0].tolist()
            # discard predictions whose confidence is too low
            # if (max(result_data) <= 0.7):
            #     continue
            index_num = result_data.index(max(result_data))
            name = name_list[index_num]
            cor_list = [i * 160 + 6, j * 80]  # top-left corner of the detection box for this slice, mapped back onto the original image
            x = cor_list[0]
            y = cor_list[1]
            cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)
            cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)
        else:
            cut_feature = T[:, 4 * j:4 * j + 5, 9 * i:9 * i + 10, :]
            data = np.expand_dims(cut_feature, axis=0)
            result = model_sec.predict(data)
            print(result)
            result_data = result[0].tolist()
            # discard predictions whose confidence is too low
            # if (max(result_data) <= 0.7):
            #     continue
            index_num = result_data.index(max(result_data))
            name = name_list[index_num]
            cor_list = [i * 160 + 6, j * 80]  # top-left corner of the detection box for this slice, mapped back onto the original image
            x = cor_list[0]
            y = cor_list[1]
            cv2.rectangle(pic, (160 * i + 6, 80 * j), ((i + 1) * 160 + 6, 80 * (j + 1)), (0, 255, 0), 2)
            cv2.putText(pic, name, (x + 40, y + 40), font, 0.5, (0, 0, 255), 1)
cv2.imshow('pic',pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#The second model predicts correctly end-to-end; no problems
補(bǔ)充知識(shí):加載訓(xùn)練好的模型參數(shù),但是權(quán)重一直變化

Running variable initialization after the weights have been loaded re-initializes them, which is what makes them keep changing; simply remove the initialization call.
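A minimal sketch of the pitfall (the TF1-style session and initializer calls here are illustrative; pass custom_objects for the custom loss if the model file needs it):

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import load_model

model = load_model('model_halcon_resenet.h5')  # weights are restored here
x = np.zeros((1, 400, 500, 3))                 # dummy input, for illustration only
# BUG: running the initializer after loading overwrites the restored weights
# with fresh random values, so predictions change from run to run.
# K.get_session().run(tf.global_variables_initializer())
print(model.predict(x))  # stable output once the initializer call is removed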
That's all for this walkthrough of reading a trained model's parameters in Keras and assigning them to other models. I hope it gives you a useful reference, and thank you for supporting 脚本之家.