TensorFlow pb to tflite: a detailed look at the accuracy drop
I wanted to run a deep model for OCR on mobile, so I tried deploying a TensorFlow image-classification model on the phone.
The plan was to ship the model to Android via tflite, but the tflite model's accuracy dropped so sharply that it could no longer meet the business requirement, so in the end the OCR model was called on the server side instead. I still have not found the cause of the accuracy drop, so I am writing it down here.
The workflow:
1. Train the image classification model.
2. Freeze the model into a pb file.
3. Convert the pb file into a tflite file.
However, the accuracy drop already shows up when the tflite file is run with the Python TensorFlow Lite Interpreter; the Android deployment behaves the same way.
1. Network structure
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim


def ttnet(images, num_classes=10, is_training=False,
          dropout_keep_prob=0.5,
          prediction_fn=slim.softmax,
          scope='TtNet'):
    end_points = {}

    with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
        net = slim.conv2d(images, 32, [3, 3], scope='conv1')
        # net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
        net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
        # net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
        net = slim.conv2d(net, 64, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
        net = slim.conv2d(net, 128, [3, 3], scope='conv3')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
        net = slim.conv2d(net, 256, [3, 3], scope='conv4')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
        net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
        # net = slim.conv2d(net, 512, [3, 3], scope='conv5')
        # net = slim.max_pool2d(net, [2, 2], 2, scope='pool5')
        net = slim.flatten(net)
        end_points['Flatten'] = net

        # net = slim.fully_connected(net, 1024, scope='fc3')
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout3')
        logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                      scope='fc4')

    end_points['Logits'] = logits
    end_points['Predictions'] = prediction_fn(logits, scope='Predictions')

    return logits, end_points

ttnet.default_image_size = 28


def ttnet_arg_scope(weight_decay=0.0):
    with slim.arg_scope(
        [slim.conv2d, slim.fully_connected],
        weights_regularizer=slim.l2_regularizer(weight_decay),
        weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
        activation_fn=tf.nn.relu) as sc:
        return sc
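One detail worth flagging for later: slim.batch_norm is called here without an is_training argument, and its default is True, so even the "inference" graph keeps training-mode batch-norm behaviour. The more usual wiring threads the flag through, roughly like this (purely illustrative, not the author's code):

# Illustrative only: passing is_training through batch norm as well, so the
# frozen inference graph uses the moving averages instead of per-batch
# statistics. Converters that fold batch norm assume inference mode.
net = slim.batch_norm(net, activation_fn=tf.nn.relu,
                      is_training=is_training, scope='bn1')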
The network is built on slim. Since this is a fairly simple classification problem, the architecture is simple as well: a few convolution layers with pooling.
Its test results were excellent: 99%+ accuracy on a test set of real samples.
2. Freezing the model into a pb file
#coding:utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from nets import nets_factory
import cv2
import os
import numpy as np
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from tensorflow.python.platform import gfile

slim = tf.contrib.slim

# todo
# support arbitrary image size and num_class
tf.app.flags.DEFINE_string(
    'checkpoint_path', '/tmp/tfmodel/',
    'The directory where the model was written to or an absolute path to a '
    'checkpoint file.')
tf.app.flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
    'as `None`, then the model_name flag is used.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer(
    'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
    'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
    'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
    'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pb file')
FLAGS = tf.app.flags.FLAGS

NUM_CLASSES = 37


def main(_):
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=NUM_CLASSES,
        is_training=False)
    # pre_image = tf.placeholder(tf.float32, [None, None, 3], name='input_data')
    # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    #     preprocessing_name, is_training=False)
    # image = image_preprocessing_fn(pre_image, FLAGS.eval_image_height,
    #                                FLAGS.eval_image_width)
    # images2 = tf.expand_dims(image, 0)
    images2 = tf.placeholder(tf.float32, (None, 32, 32, 3), name='input_data')
    logits, endpoints = network_fn(images2)

    with tf.Session() as sess:
        output = tf.identity(endpoints['Predictions'], name="output_data")
        with gfile.GFile(FLAGS.export_path, 'wb') as f:
            f.write(sess.graph_def.SerializeToString())


if __name__ == '__main__':
    tf.app.run()
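One thing worth noting: as written, the script serializes sess.graph_def, which contains only the graph structure, not the trained weights. A conventional freeze restores the checkpoint and folds the variables into constants first, along these lines (a minimal sketch, assuming a checkpoint under FLAGS.checkpoint_path; it would replace the tf.Session block above):

from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    # Fix the output node name before taking the graph snapshot.
    output = tf.identity(endpoints['Predictions'], name='output_data')
    # Restore the trained weights from the checkpoint directory.
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_path))
    # Fold the restored variables into constants so the pb is self-contained.
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['output_data'])
    with gfile.GFile(FLAGS.export_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())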
3. Generating the tflite file
import tensorflow as tf

graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
input_arrays = ["input_data"]
output_arrays = ["output_data"]

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
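Since input_data was defined with a None batch dimension, the converter may not be able to infer a concrete shape on its own. If from_frozen_graph complains about the shape, pinning it explicitly is a common workaround (a sketch using the same node names as above):

# Pin the None batch dimension to 1 so the converter sees a static shape.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays,
    input_shapes={"input_data": [1, 32, 32, 3]})
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)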
Testing with the pb file gives normal results; testing with the tflite file shows a severe accuracy drop. The pb and tflite test code is attached below.
pb test code
with tf.gfile.GFile(graph_filename, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)
    input_node = graph.get_tensor_by_name('import/input_data:0')
    output_node = graph.get_tensor_by_name('import/output_data:0')
    with tf.Session() as sess:
        for image_file in image_files:
            abs_path = os.path.join(image_folder, image_file)
            img = cv2.imread(abs_path).astype(np.float32)
            img = cv2.resize(img, (int(input_node.shape[1]),
                                   int(input_node.shape[2])))
            output_data = sess.run(output_node,
                                   feed_dict={input_node: [img]})
            index = np.argmax(output_data)
            label = dict_laebl[index]
            dst_floder = os.path.join(result_folder, label)
            if not os.path.exists(dst_floder):
                os.mkdir(dst_floder)
            cv2.imwrite(os.path.join(dst_floder, image_file), img)
            count += 1
tflite test code
model_path = "converted_model.tflite"  # "/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
interpreter = tf.contrib.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

for image_file in image_files:
    abs_path = os.path.join(image_folder, image_file)
    img = cv2.imread(abs_path).astype(np.float32)
    img = cv2.resize(img, tuple(input_details[0]['shape'][1:3]))
    # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], [img])
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    index = np.argmax(output_data)
    label = dict_laebl[index]
    dst_floder = os.path.join(result_folder, label)
    if not os.path.exists(dst_floder):
        os.mkdir(dst_floder)
    cv2.imwrite(os.path.join(dst_floder, image_file), img)
    count += 1
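To check whether the drop comes from the conversion itself rather than from small preprocessing differences between the two test scripts, it helps to feed the exact same tensor through both models and compare the raw outputs (a debugging sketch, assuming the pb session and the interpreter are set up as above):

import numpy as np

# One fixed, identically preprocessed input for both backends.
img = cv2.resize(cv2.imread(abs_path).astype(np.float32), (32, 32))
batch = np.expand_dims(img, 0)

# pb output.
pb_out = sess.run(output_node, feed_dict={input_node: batch})

# tflite output.
interpreter.set_tensor(input_details[0]['index'], batch)
interpreter.invoke()
tflite_out = interpreter.get_tensor(output_details[0]['index'])

# If this is far from zero, the divergence is inside the converted graph
# (batch-norm folding is a frequent suspect), not in the input pipeline.
print('max abs diff:', np.abs(pb_out - tflite_out).max())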
In the end I worked around the problem and met the business requirement; when I find the time I still intend to dig into the cause.
If anyone knows the reason, please do share.
Bonus knowledge: pb-to-tflite conversion code that uses quantization to shrink the file size, via converter.post_training_quantize = True.
import tensorflow as tf

path = "/home/python/Downloads/a.pb"  # location and name of the pb file
inputs = ["input_images"]  # input node names of the model file
classes = ['feature_fusion/Conv_7/Sigmoid', 'feature_fusion/concat_3']  # output node names of the model file

# converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
#     path, inputs, classes, input_shapes={'input_images': [1, 320, 320, 3]})
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    path, inputs, classes, input_shapes={'input_images': [1, 320, 320, 3]})
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/home/python/Downloads/aNew.tflite", "wb").write(tflite_model)
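Note that post_training_quantize stores the weights as 8-bit values, so it trades some accuracy for a file that is roughly 4x smaller; in newer TensorFlow versions the same effect is requested with converter.optimizations = [tf.lite.Optimize.DEFAULT]. A quick size comparison confirms the shrink (paths as above):

import os

# Compare file sizes before and after weight quantization.
print('pb size     :', os.path.getsize("/home/python/Downloads/a.pb"))
print('tflite size :', os.path.getsize("/home/python/Downloads/aNew.tflite"))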
This concludes the detailed look at the TensorFlow pb to tflite accuracy drop. I hope it gives you a useful reference, and that you will keep supporting 腳本之家.