欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作

 更新時(shí)間:2020年05月25日 10:28:51   作者:mishidemudong  
這篇文章主要介紹了將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧

背景:目前keras框架使用簡(jiǎn)單,很容易上手,深得廣大算法工程師的喜愛,但是當(dāng)部署到客戶端時(shí),可能會(huì)出現(xiàn)各種各樣的bug,甚至不支持使用keras,本文來(lái)解決的是將keras的h5模型轉(zhuǎn)換為客戶端常用的tensorflow的pb模型并使用tensorflow加載pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路徑參數(shù)
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#轉(zhuǎn)換函數(shù)
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#輸出路徑
output_dir = osp.join(os.getcwd(),"trans_model")
#加載模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

將轉(zhuǎn)換成的pb模型進(jìn)行加載

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #輸入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #輸出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #預(yù)測(cè)結(jié)果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)
 

補(bǔ)充知識(shí):h5模型轉(zhuǎn)化為pb模型,代碼及排坑

我是在實(shí)際工程中要用到tensorflow訓(xùn)練的pb模型,但是訓(xùn)練的代碼是用keras寫的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函數(shù)。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
將keras的.h5的模型文件,轉(zhuǎn)換成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件轉(zhuǎn)換成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路徑
    model_name: str
      pb模型文件名稱
    out_prefix: str
      根據(jù)訓(xùn)練,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 寫入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 輸出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路徑參數(shù)
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件輸出輸出路徑
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加載模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在運(yùn)行的時(shí)候遇到了下面問(wèn)題:

原因:我們訓(xùn)練模型的時(shí)候用save_weights函數(shù)保存模型,但是這個(gè)函數(shù)只保存了權(quán)重文件,并沒(méi)有又保存模型的參數(shù)。要把save_weights改為save。

下邊是兩個(gè)函數(shù)介紹:

save()保存的模型結(jié)果,它既保持了模型的圖結(jié)構(gòu),又保存了模型的參數(shù)。

save_weights()保存的模型結(jié)果,它只保存了模型的參數(shù),但并沒(méi)有保存模型的圖結(jié)構(gòu)

以上這篇將keras的h5模型轉(zhuǎn)換為tensorflow的pb模型操作就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評(píng)論