將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作
背景:目前keras框架使用簡(jiǎn)單,很容易上手,深得廣大算法工程師的喜愛,但是當(dāng)部署到客戶端時(shí),可能會(huì)出現(xiàn)各種各樣的bug,甚至不支持使用keras,本文來(lái)解決的是將keras的h5模型轉(zhuǎn)換為客戶端常用的tensorflow的pb模型并使用tensorflow加載pb模型。
h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as osp from keras import backend as K #路徑參數(shù) input_path = 'input path' weight_file = 'weight.h5' weight_file_path = osp.join(input_path,weight_file) output_graph_name = weight_file[:-3] + '.pb' #轉(zhuǎn)換函數(shù) def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True): if osp.exists(output_dir) == False: os.mkdir(output_dir) out_nodes = [] for i in range(len(h5_model.outputs)): out_nodes.append(out_prefix + str(i + 1)) tf.identity(h5_model.output[i],out_prefix + str(i + 1)) sess = K.get_session() from tensorflow.python.framework import graph_util,graph_io init_graph = sess.graph.as_graph_def() main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes) graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False) if log_tensorboard: from tensorflow.python.tools import import_pb_to_tensorboard import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir) #輸出路徑 output_dir = osp.join(os.getcwd(),"trans_model") #加載模型 h5_model = load_model(weight_file_path) h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name) print('model saved')
將轉(zhuǎn)換成的pb模型進(jìn)行加載
load_pb.py import tensorflow as tf from tensorflow.python.platform import gfile def load_pb(pb_file_path): sess = tf.Session() with gfile.FastGFile(pb_file_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') print(sess.run('b:0')) #輸入 input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') #輸出 op = sess.graph.get_tensor_by_name('op_to_store:0') #預(yù)測(cè)結(jié)果 ret = sess.run(op, {input_x: 3, input_y: 4}) print(ret)
補(bǔ)充知識(shí):h5模型轉(zhuǎn)化為pb模型,代碼及排坑
我是在實(shí)際工程中要用到tensorflow訓(xùn)練的pb模型,但是訓(xùn)練的代碼是用keras寫的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函數(shù)。
附上h5_to_pb.py(python3)
#*-coding:utf-8-* """ 將keras的.h5的模型文件,轉(zhuǎn)換成TensorFlow的pb文件 """ # ========================================================== from keras.models import load_model import tensorflow as tf import os.path as osp import os from keras import backend #from keras.models import Sequential def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True): """.h5模型文件轉(zhuǎn)換成pb模型文件 Argument: h5_model: str .h5模型文件 output_dir: str pb模型文件保存路徑 model_name: str pb模型文件名稱 out_prefix: str 根據(jù)訓(xùn)練,需要修改 log_tensorboard: bool 是否生成日志文件 Return: pb模型文件 """ if os.path.exists(output_dir) == False: os.mkdir(output_dir) out_nodes = [] for i in range(len(h5_model.outputs)): out_nodes.append(out_prefix + str(i + 1)) tf.identity(h5_model.output[i], out_prefix + str(i + 1)) sess = backend.get_session() from tensorflow.python.framework import graph_util, graph_io # 寫入pb模型文件 init_graph = sess.graph.as_graph_def() main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes) graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False) # 輸出日志文件 if log_tensorboard: from tensorflow.python.tools import import_pb_to_tensorboard import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir) if __name__ == '__main__': # .h模型文件路徑參數(shù) input_path = 'D:/CSP' weight_file = 'xingren.h5' weight_file_path = os.path.join(input_path, weight_file) output_graph_name = weight_file[:-3] + '.pb' # pb模型文件輸出輸出路徑 output_dir = osp.join(os.getcwd(),"trans_model") #model.save(xingren.h5) # 加載模型 #h5_model = Sequential() h5_model = load_model(weight_file_path) #h5_model.save(weight_file_path) #h5_model.save('xingren.h5') h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name) print ('Finished')
在運(yùn)行的時(shí)候遇到了下面問(wèn)題:
原因:我們訓(xùn)練模型的時(shí)候用save_weights函數(shù)保存模型,但是這個(gè)函數(shù)只保存了權(quán)重文件,并沒(méi)有又保存模型的參數(shù)。要把save_weights改為save。
下邊是兩個(gè)函數(shù)介紹:
save()保存的模型結(jié)果,它既保持了模型的圖結(jié)構(gòu),又保存了模型的參數(shù)。
save_weights()保存的模型結(jié)果,它只保存了模型的參數(shù),但并沒(méi)有保存模型的圖結(jié)構(gòu)
以上這篇將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
windows下pycharm安裝、創(chuàng)建文件、配置默認(rèn)模板
這篇文章主要為大家詳細(xì)介紹了windows下pycharm安裝、創(chuàng)建文件、配置默認(rèn)模板,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-07-07兩個(gè)命令把 Vim 打造成 Python IDE的方法
這篇文章主要介紹了兩個(gè)命令把 Vim 打造成 Python IDE,需要的朋友可以參考下2016-03-03Python實(shí)現(xiàn)批量將word轉(zhuǎn)換成pdf
這篇文章主要為大家詳細(xì)介紹了如何利用Python實(shí)現(xiàn)批量將word文檔轉(zhuǎn)換成pdf文件,文中的示例代碼簡(jiǎn)潔易懂,感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2023-08-08python基礎(chǔ)教程之對(duì)象和類的實(shí)際運(yùn)用
這篇文章主要介紹了python基礎(chǔ)教程之對(duì)象和類的實(shí)際運(yùn)用,本文講解對(duì)象和類的一方法技巧,例如屬性、內(nèi)置方法、self關(guān)鍵字的運(yùn)用等,需要的朋友可以參考下2014-08-08python自制簡(jiǎn)易mysql連接池的實(shí)現(xiàn)示例
本文主要介紹了python自制簡(jiǎn)易mysql連接池的實(shí)現(xiàn)示例,文中通過(guò)示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-11-11Python利用format函數(shù)實(shí)現(xiàn)對(duì)齊打印(左對(duì)齊、右對(duì)齊與居中對(duì)齊)
format是字符串內(nèi)嵌的一個(gè)方法,用于格式化字符串,下面這篇文章主要給大家介紹了關(guān)于Python利用format函數(shù)實(shí)現(xiàn)對(duì)齊打印(左對(duì)齊、右對(duì)齊與居中對(duì)齊)的相關(guān)資料,需要的朋友可以參考下2022-04-04Python基礎(chǔ)之函數(shù)用法實(shí)例詳解
這篇文章主要介紹了Python中函數(shù)用法,包括了函數(shù)的創(chuàng)建、定義、參數(shù)等,需要的朋友可以參考下2014-09-09Python多線程threading join和守護(hù)線程setDeamon原理詳解
這篇文章主要介紹了Python多線程threading join和守護(hù)線程setDeamon原理詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-03-03