tensorflow 保存模型和取出中間權(quán)重例子
更新時間:2020年01月24日 09:39:07 作者:binqiang2wang
今天小編就為大家分享一篇tensorflow 保存模型和取出中間權(quán)重例子,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
下面代碼的功能是先訓練一個簡單的模型,然后保存模型,同時保存到一個pb文件當中,后續(xù)可以從pd文件里讀取權(quán)重值。
import tensorflow as tf import numpy as np import os import h5py import pickle from tensorflow.python.framework import graph_util from tensorflow.python.platform import gfile #設(shè)置使用指定GPU os.environ['CUDA_VISIBLE_DEVICES'] = '1' #下面這段代碼是在訓練好之后將所有的權(quán)重名字和權(quán)重值羅列出來,訓練的時候需要注釋掉 reader = tf.train.NewCheckpointReader('./model.ckpt-100') variables = reader.get_variable_to_shape_map() for ele in variables: print(ele) print(reader.get_tensor(ele)) x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False#設(shè)成True去訓練模型 train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b)) graph_def = tf.get_default_graph().as_graph_def() #通過修改下面的函數(shù),個人覺得理論上能夠?qū)崿F(xiàn)修改權(quán)重,但是很復雜,如果哪位有好辦法,歡迎指教 output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable']) with tf.gfile.FastGFile('./test.pb', 'wb') as f: f.write(output_graph_def.SerializeToString()) with tf.Session() as sess: #對應最后一部分的寫,這里能夠?qū)淖兞咳〕鰜? with gfile.FastGFile('./test.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) res = tf.import_graph_def(graph_def, return_elements=['Variable:0']) print(sess.run(res)) print(sess.run(graph_def))
以上這篇tensorflow 保存模型和取出中間權(quán)重例子就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python連接達夢數(shù)據(jù)庫的實現(xiàn)示例
本文主要介紹了Python連接達夢數(shù)據(jù)庫的實現(xiàn)示例,dmPython是DM提供的依據(jù)Python DB API version 2.0中API使用規(guī)定而開發(fā)的數(shù)據(jù)庫訪問接口,使Python應用程序能夠?qū)M數(shù)據(jù)庫進行訪問2023-12-12使用python如何將數(shù)據(jù)集劃分為訓練集、驗證集和測試集
這篇文章主要介紹了使用python如何將數(shù)據(jù)集劃分為訓練集、驗證集和測試集問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教2023-09-09Python實現(xiàn)wav和pcm的轉(zhuǎn)換方式
這篇文章主要介紹了Python實現(xiàn)wav和pcm的轉(zhuǎn)換方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-05-05Python for循環(huán)搭配else常見問題解決
這篇文章主要介紹了Python for循環(huán)搭配else常見問題解決,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-02-02python使用pandas處理大數(shù)據(jù)節(jié)省內(nèi)存技巧(推薦)
這篇文章主要介紹了python使用pandas處理大數(shù)據(jù)節(jié)省內(nèi)存技巧,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-05-05