淺談tensorflow模型保存為pb的各種姿勢
一,直接保存pb
1, 首先我們當(dāng)然可以直接在tensorflow訓(xùn)練中直接保存為pb為格式,保存pb的好處就是使用場景是實(shí)現(xiàn)創(chuàng)建模型與使用模型的解耦,使得創(chuàng)建模型與使用模型的解耦,使得前向推導(dǎo)inference代碼統(tǒng)一。另外的好處就是保存為pb的時候,模型的變量會變成固定的,導(dǎo)致模型的大小會大大減小。
這里稍稍解釋下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括計算圖,數(shù)據(jù)流,以及相關(guān)的變量和輸入輸出
主要使用tf.SavedModelBuilder來完成這個工作,并且可以把多個計算圖保存到一個pb文件中,如果有多個MetaGraph,那么只會保留第一個MetaGraph的版本號。
保持pb的文件代碼:
import tensorflow as tf import os from tensorflow.python.framework import graph_util pb_file_path = os.getcwd() with tf.Session(graph=tf.Graph()) as sess: x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) # 這里的輸出需要加上name屬性 op = tf.add(xy, b, name='op_to_store') sess.run(tf.global_variables_initializer()) # convert_variables_to_constants 需要指定output_node_names,list(),可以多個 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store']) # 測試 OP feed_dict = {x: 10, y: 3} print(sess.run(op, feed_dict)) # 寫入序列化的 PB 文件 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f: f.write(constant_graph.SerializeToString()) # 輸出 # INFO:tensorflow:Froze 1 variables. # Converted 1 variables to const ops. # 31
其實(shí)主要是:
# convert_variables_to_constants 需要指定output_node_names,list(),可以多個 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 寫入序列化的 PB 文件 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f: f.write(constant_graph.SerializeToString())
1.1 加載測試代碼
from tensorflow.python.platform import gfile sess = tf.Session() with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') # 導(dǎo)入計算圖 # 需要有一個初始化的過程 sess.run(tf.global_variables_initializer()) # 需要先復(fù)原變量 print(sess.run('b:0')) # 1 # 輸入 input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') op = sess.graph.get_tensor_by_name('op_to_store:0') ret = sess.run(op, feed_dict={input_x: 5, input_y: 5}) print(ret) # 輸出 26
2,第二種就是采用上述的那API來進(jìn)行保存
import tensorflow as tf import os from tensorflow.python.framework import graph_util pb_file_path = os.getcwd() with tf.Session(graph=tf.Graph()) as sess: x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') xy = tf.multiply(x, y) # 這里的輸出需要加上name屬性 op = tf.add(xy, b, name='op_to_store') sess.run(tf.global_variables_initializer()) # convert_variables_to_constants 需要指定output_node_names,list(),可以多個 constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store']) # 測試 OP feed_dict = {x: 10, y: 3} print(sess.run(op, feed_dict)) # 寫入序列化的 PB 文件 with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f: f.write(constant_graph.SerializeToString()) # INFO:tensorflow:Froze 1 variables. # Converted 1 variables to const ops. # 31 # 官網(wǎng)有誤,寫成了 saved_model_builder builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel') # 構(gòu)造模型保存的內(nèi)容,指定要保存的 session,特定的 tag, # 輸入輸出信息字典,額外的信息 builder.add_meta_graph_and_variables(sess, ['cpu_server_1']) # 添加第二個 MetaGraphDef #with tf.Session(graph=tf.Graph()) as sess: # ... # builder.add_meta_graph([tag_constants.SERVING]) #... builder.save() # 保存 PB 模型
核心就是采用了:
# 官網(wǎng)有誤,寫成了 saved_model_builder builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel') # 構(gòu)造模型保存的內(nèi)容,指定要保存的 session,特定的 tag, # 輸入輸出信息字典,額外的信息 builder.add_meta_graph_and_variables(sess, ['cpu_server_1'])
2.1 對應(yīng)的測試代碼為:
with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel') sess.run(tf.global_variables_initializer()) input_x = sess.graph.get_tensor_by_name('x:0') input_y = sess.graph.get_tensor_by_name('y:0') op = sess.graph.get_tensor_by_name('op_to_store:0') ret = sess.run(op, feed_dict={input_x: 5, input_y: 5}) print(ret) # 只需要指定要恢復(fù)模型的 session,模型的 tag,模型的保存路徑即可,使用起來更加簡單
這樣和之前的導(dǎo)入pb模型一樣,也是要知道tensor的name,那么如何在不知道tensor name的情況下使用呢,給add_meta_graph_and_variables方法傳入第三個參數(shù),signature_def_map即可。
二,從ckpt進(jìn)行加載
使用tf.train.saver()保持模型的時候會產(chǎn)生多個文件,會把計算圖的結(jié)構(gòu)和圖上參數(shù)取值分成了不同文件存儲,這種方法是在TensorFlow中最常用的保存方式:
import tensorflow as tf # 聲明兩個變量 v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer() # 初始化全部變量 saver = tf.train.Saver() # 聲明tf.train.Saver類用于保存模型 with tf.Session() as sess: sess.run(init_op) print("v1:", sess.run(v1)) # 打印v1、v2的值一會讀取之后對比 print("v2:", sess.run(v2)) saver_path = saver.save(sess, "save/model.ckpt") # 將模型保存到save/model.ckpt文件 print("Model saved in file:", saver_path)
checkpoint是檢查點(diǎn)的文件,文件保存了一個目錄下所有的模型文件列表
model.ckpt.meta文件保存了Tensorflow計算圖的結(jié)果,可以理解為神經(jīng)網(wǎng)絡(luò)的網(wǎng)絡(luò)結(jié)構(gòu),該文件可以被tf.train.import_meta_graph加載到當(dāng)前默認(rèn)的圖來使用
ckpt.data是保存模型中每個變量的取值
方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型結(jié)構(gòu),將計算圖中的變量取值以常量的形式保存
ckpt轉(zhuǎn)換pb格式過程如下:
1,通過傳入ckpt模型的路徑得到模型的圖和變量數(shù)據(jù)
2,通過import_meta_graph導(dǎo)入模型中的圖
3,通過saver.restore從模型中恢復(fù)圖中各個變量的數(shù)據(jù)
4,通過graph_util.convert_variables_to_constants將模型持久化
import tensorflow as tf from tensorflow.python.framework import graph_util from tensorflow.pyton.platform import gfile def freeze_graph(input_checkpoint,output_graph): ''' :param input_checkpoint: :param output_graph: PB模型保存路徑 :return: ''' # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態(tài)是否可用 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑 # 指定輸出的節(jié)點(diǎn)名稱,該節(jié)點(diǎn)名稱必須是原模型中存在的節(jié)點(diǎn) output_node_names = "InceptionV3/Logits/SpatialSqueeze" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) graph = tf.get_default_graph() # 獲得默認(rèn)的圖 input_graph_def = graph.as_graph_def() # 返回一個序列化的圖代表當(dāng)前的圖 with tf.Session() as sess: saver.restore(sess, input_checkpoint) #恢復(fù)圖并得到數(shù)據(jù) output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定 sess=sess, input_graph_def=input_graph_def,# 等于:sess.graph_def output_node_names=output_node_names.split(","))# 如果有多個輸出節(jié)點(diǎn),以逗號隔開 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型 f.write(output_graph_def.SerializeToString()) #序列化輸出 print("%d ops in the final graph." % len(output_graph_def.node)) #得到當(dāng)前圖有幾個操作節(jié)點(diǎn) # for op in graph.get_operations(): # print(op.name, op.values())
函數(shù)freeze_graph中,最重要的就是指定輸出節(jié)點(diǎn)的名稱,這個節(jié)點(diǎn)名稱是原模型存在的結(jié)點(diǎn),注意節(jié)點(diǎn)名稱與張量名稱的區(qū)別:
如:“input:0”是張量的名稱,而“input”表示的是節(jié)點(diǎn)的名稱
源碼中通過graph = tf.get_default_graph()獲得默認(rèn)圖,這個圖就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢復(fù)的圖,因此就必須執(zhí)行tf.train.import_meta_graph,再執(zhí)行tf.get_default_graph()
1.2 一個小工具
tensorflow打印pb模型的所有節(jié)點(diǎn)
from tensorflow.python.framework import tensor_util from google.protobuf import text_format import tensorflow as tf from tensorflow.python.platform import gfile from tensorflow.python.framework import tensor_util pb_path = './model.pb' with tf.Session() as sess: with gfile.FastGFile(pb_path,'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def,name='') for i,n in enumerate(graph_def.node): print("Name of the node -%s"%n.name) tensorflow打印ckpt的所有節(jié)點(diǎn) from tensorflow.python import pywrap_tensorflow checkpoint_path = './_checkpoint/hed.ckpt-130' reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map = reader.get_variable_to_shape_map() for key in var_to_shape_map: print("tensor_name:",key)
方法二,除了上述辦法外還有一種是需要通過源碼的,這樣既可以得到輸出節(jié)點(diǎn),還可以自定義輸入節(jié)點(diǎn)。
import tensorflow as tf def model(input): net = tf.layers.conv2d(input,filters=32,kernel_size=3) net = tf.layers.batch_normalization(net,fused=False) net = tf.layers.separable_conv2d(net,32,3) net = tf.layers.conv2d(net,filters=32,kernel_size=3,name='output') return net input_node = tf.placeholder(tf.float32,[1,480,480,3],name = 'image') output_node_names = 'head_neck_count/BiasAdd' ckpt = ckpt_path pb = pb_path with tf.Session() as sess: model1 = model(input_node) sess.run(tf.global_variables_initializer()) output_node_names = 'output/BiasAdd' input_graph_def = tf.get_default_graph().as_graph_def() output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(',')) with tf.gfile.GFile(pb,'wb') as f: f.write(output_graph_def.SerializeToString())
注意:
節(jié)點(diǎn)名稱和張量名稱區(qū)別
類似于output是節(jié)點(diǎn)名稱
類似于output:0是張量名稱
方法三,其實(shí)是方法一的延伸可以配合tensorflow自帶的一些工具來進(jìn)行完成
freeze_graph
總共有11個參數(shù),一個個介紹下(必選: 表示必須有值;可選: 表示可以為空):
1、input_graph:(必選)模型文件,可以是二進(jìn)制的pb文件,或文本的meta文件,用input_binary來指定區(qū)分(見下面說明)
2、input_saver:(可選)Saver解析器。保存模型和權(quán)限時,Saver也可以自身序列化保存,以便在加載時應(yīng)用合適的版本。主要用于版本不兼容時使用??梢詾榭?,為空時用當(dāng)前版本的Saver。
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進(jìn)制,為false時,input_graph為文件。默認(rèn)False
4、input_checkpoint:(必選)檢查點(diǎn)數(shù)據(jù)文件。訓(xùn)練時,給Saver用于保存權(quán)重、偏置等變量值。這時用于模型恢復(fù)變量值。
5、output_node_names:(必選)輸出節(jié)點(diǎn)的名字,有多個時用逗號分開。用于指定輸出節(jié)點(diǎn),將沒有在輸出線上的其它節(jié)點(diǎn)剔除。
6、restore_op_name:(可選)從模型恢復(fù)節(jié)點(diǎn)的名字。升級版中已棄用。默認(rèn):save/restore_all
7、filename_tensor_name:(可選)已棄用。默認(rèn):save/Const:0
8、output_graph:(必選)用來保存整合后的模型輸出文件。
9、clear_devices:(可選),默認(rèn)True。指定是否清除訓(xùn)練時節(jié)點(diǎn)指定的運(yùn)算設(shè)備(如cpu、gpu、tpu。cpu是默認(rèn))
10、initializer_nodes:(可選)默認(rèn)空。權(quán)限加載后,可通過此參數(shù)來指定需要初始化的節(jié)點(diǎn),用逗號分隔多個節(jié)點(diǎn)名字。
11、variable_names_blacklist:(可先)默認(rèn)空。變量黑名單,用于指定不用恢復(fù)值的變量,用逗號分隔多個變量名字。
所以還是建議選擇方法三
導(dǎo)出pb后的測試代碼如下:下圖是比較完成的測試代碼與導(dǎo)出代碼。
# -*-coding: utf-8 -*- """ @Project: tensorflow_models_nets @File : convert_pb.py @Author : panjq @E-mail : pan_jinquan@163.com @Date : 2018-08-29 17:46:50 @info : -通過傳入 CKPT 模型的路徑得到模型的圖和變量數(shù)據(jù) -通過 import_meta_graph 導(dǎo)入模型中的圖 -通過 saver.restore 從模型中恢復(fù)圖中各個變量的數(shù)據(jù) -通過 graph_util.convert_variables_to_constants 將模型持久化 """ import tensorflow as tf from create_tf_record import * from tensorflow.python.framework import graph_util resize_height = 299 # 指定圖片高度 resize_width = 299 # 指定圖片寬度 depths = 3 def freeze_graph_test(pb_path, image_path): ''' :param pb_path:pb文件的路徑 :param image_path:測試圖片的路徑 :return: ''' with tf.Graph().as_default(): output_graph_def = tf.GraphDef() with open(pb_path, "rb") as f: output_graph_def.ParseFromString(f.read()) tf.import_graph_def(output_graph_def, name="") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 定義輸入的張量名稱,對應(yīng)網(wǎng)絡(luò)結(jié)構(gòu)的輸入張量 # input:0作為輸入圖像,keep_prob:0作為dropout的參數(shù),測試時值為1,is_training:0訓(xùn)練參數(shù) input_image_tensor = sess.graph.get_tensor_by_name("input:0") input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0") input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0") # 定義輸出的張量名稱 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0") # 讀取測試圖片 im=read_image(image_path,resize_height,resize_width,normalization=True) im=im[np.newaxis,:] # 測試讀出來的模型是否正確,注意這里傳入的是輸出和輸入節(jié)點(diǎn)的tensor的名字,不是操作節(jié)點(diǎn)的名字 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False}) out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im, input_keep_prob_tensor:1.0, input_is_training_tensor:False}) print("out:{}".format(out)) score = tf.nn.softmax(out, name='pre') class_id = tf.argmax(score, 1) print "pre class_id:{}".format(sess.run(class_id)) def freeze_graph(input_checkpoint,output_graph): ''' :param input_checkpoint: :param output_graph: PB模型保存路徑 :return: ''' # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態(tài)是否可用 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑 # 指定輸出的節(jié)點(diǎn)名稱,該節(jié)點(diǎn)名稱必須是原模型中存在的節(jié)點(diǎn) output_node_names = "InceptionV3/Logits/SpatialSqueeze" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) with tf.Session() as sess: saver.restore(sess, input_checkpoint) #恢復(fù)圖并得到數(shù)據(jù) output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定 sess=sess, input_graph_def=sess.graph_def,# 等于:sess.graph_def output_node_names=output_node_names.split(","))# 如果有多個輸出節(jié)點(diǎn),以逗號隔開 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型 f.write(output_graph_def.SerializeToString()) #序列化輸出 print("%d ops in the final graph." % len(output_graph_def.node)) #得到當(dāng)前圖有幾個操作節(jié)點(diǎn) # for op in sess.graph.get_operations(): # print(op.name, op.values()) def freeze_graph2(input_checkpoint,output_graph): ''' :param input_checkpoint: :param output_graph: PB模型保存路徑 :return: ''' # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態(tài)是否可用 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑 # 指定輸出的節(jié)點(diǎn)名稱,該節(jié)點(diǎn)名稱必須是原模型中存在的節(jié)點(diǎn) output_node_names = "InceptionV3/Logits/SpatialSqueeze" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) graph = tf.get_default_graph() # 獲得默認(rèn)的圖 input_graph_def = graph.as_graph_def() # 返回一個序列化的圖代表當(dāng)前的圖 with tf.Session() as sess: saver.restore(sess, input_checkpoint) #恢復(fù)圖并得到數(shù)據(jù) output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定 sess=sess, input_graph_def=input_graph_def,# 等于:sess.graph_def output_node_names=output_node_names.split(","))# 如果有多個輸出節(jié)點(diǎn),以逗號隔開 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型 f.write(output_graph_def.SerializeToString()) #序列化輸出 print("%d ops in the final graph." % len(output_graph_def.node)) #得到當(dāng)前圖有幾個操作節(jié)點(diǎn) # for op in graph.get_operations(): # print(op.name, op.values()) if __name__ == '__main__': # 輸入ckpt模型路徑 input_checkpoint='models/model.ckpt-10000' # 輸出pb模型的路徑 out_pb_path="models/pb/frozen_model.pb" # 調(diào)用freeze_graph將ckpt轉(zhuǎn)為pb freeze_graph(input_checkpoint,out_pb_path) # 測試pb模型 image_path = 'test_image/animal.jpg' freeze_graph_test(pb_path=out_pb_path, image_path=image_path)
以上這篇淺談tensorflow模型保存為pb的各種姿勢就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python多進(jìn)程同步Lock、Semaphore、Event實(shí)例
這篇文章主要介紹了Python多進(jìn)程同步Lock、Semaphore、Event實(shí)例,Lock用來避免訪問沖突、Semaphore用來控制對共享資源的訪問數(shù)量、Event用來實(shí)現(xiàn)進(jìn)程間同步通信,需要的朋友可以參考下2014-11-11python實(shí)現(xiàn)簡單中文詞頻統(tǒng)計示例
本篇文章主要介紹了python實(shí)現(xiàn)簡單中文詞頻統(tǒng)計示例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-11-11vscode 遠(yuǎn)程調(diào)試python的方法
本篇文章主要介紹了vscode 遠(yuǎn)程調(diào)試python的方法,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-12-12python中l(wèi)ist*n生成多維數(shù)組與for循環(huán)生成多維數(shù)組的區(qū)別說明
這篇文章主要介紹了python中l(wèi)ist*n生成多維數(shù)組與for循環(huán)生成多維數(shù)組的區(qū)別說明,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2022-05-05Python使用scrapy采集時偽裝成HTTP/1.1的方法
這篇文章主要介紹了Python使用scrapy采集時偽裝成HTTP/1.1的方法,實(shí)例分析了scrapy采集的使用技巧,非常具有實(shí)用價值,需要的朋友可以參考下2015-04-04Pytorch中關(guān)于model.eval()的作用及分析
這篇文章主要介紹了Pytorch中關(guān)于model.eval()的作用及分析,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-02-02python實(shí)現(xiàn)進(jìn)度條的多種實(shí)現(xiàn)
這篇文章主要介紹了python實(shí)現(xiàn)進(jìn)度條的多種實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04