tensorflow 保存模型和取出中間權(quán)重例子
下面代碼的功能是先訓(xùn)練一個(gè)簡(jiǎn)單的模型,然后保存模型,同時(shí)保存到一個(gè)pb文件當(dāng)中,后續(xù)可以從pd文件里讀取權(quán)重值。
import tensorflow as tf import numpy as np import os import h5py import pickle from tensorflow.python.framework import graph_util from tensorflow.python.platform import gfile #設(shè)置使用指定GPU os.environ['CUDA_VISIBLE_DEVICES'] = '1' #下面這段代碼是在訓(xùn)練好之后將所有的權(quán)重名字和權(quán)重值羅列出來(lái),訓(xùn)練的時(shí)候需要注釋掉 reader = tf.train.NewCheckpointReader('./model.ckpt-100') variables = reader.get_variable_to_shape_map() for ele in variables: print(ele) print(reader.get_tensor(ele)) x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False#設(shè)成True去訓(xùn)練模型 train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b)) graph_def = tf.get_default_graph().as_graph_def() #通過(guò)修改下面的函數(shù),個(gè)人覺(jué)得理論上能夠?qū)崿F(xiàn)修改權(quán)重,但是很復(fù)雜,如果哪位有好辦法,歡迎指教 output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable']) with tf.gfile.FastGFile('./test.pb', 'wb') as f: f.write(output_graph_def.SerializeToString()) with tf.Session() as sess: #對(duì)應(yīng)最后一部分的寫,這里能夠?qū)?duì)應(yīng)的變量取出來(lái) with gfile.FastGFile('./test.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) res = tf.import_graph_def(graph_def, return_elements=['Variable:0']) print(sess.run(res)) print(sess.run(graph_def))
以上這篇tensorflow 保存模型和取出中間權(quán)重例子就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python連接達(dá)夢(mèng)數(shù)據(jù)庫(kù)的實(shí)現(xiàn)示例
本文主要介紹了Python連接達(dá)夢(mèng)數(shù)據(jù)庫(kù)的實(shí)現(xiàn)示例,dmPython是DM提供的依據(jù)Python DB API version 2.0中API使用規(guī)定而開(kāi)發(fā)的數(shù)據(jù)庫(kù)訪問(wèn)接口,使Python應(yīng)用程序能夠?qū)M數(shù)據(jù)庫(kù)進(jìn)行訪問(wèn)2023-12-12python實(shí)例小練習(xí)之Turtle繪制南方的雪花
Turtle庫(kù)是Python語(yǔ)言中一個(gè)很流行的繪制圖像的函數(shù)庫(kù),想象一個(gè)小烏龜,在一個(gè)橫軸為x、縱軸為y的坐標(biāo)系原點(diǎn),(0,0)位置開(kāi)始,它根據(jù)一組函數(shù)指令的控制,在這個(gè)平面坐標(biāo)系中移動(dòng),從而在它爬行的路徑上繪制了圖形2021-09-09使用python如何將數(shù)據(jù)集劃分為訓(xùn)練集、驗(yàn)證集和測(cè)試集
這篇文章主要介紹了使用python如何將數(shù)據(jù)集劃分為訓(xùn)練集、驗(yàn)證集和測(cè)試集問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-09-09Python實(shí)現(xiàn)wav和pcm的轉(zhuǎn)換方式
這篇文章主要介紹了Python實(shí)現(xiàn)wav和pcm的轉(zhuǎn)換方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-05-05Python for循環(huán)搭配else常見(jiàn)問(wèn)題解決
這篇文章主要介紹了Python for循環(huán)搭配else常見(jiàn)問(wèn)題解決,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-02-02python使用pandas處理大數(shù)據(jù)節(jié)省內(nèi)存技巧(推薦)
這篇文章主要介紹了python使用pandas處理大數(shù)據(jù)節(jié)省內(nèi)存技巧,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-05-05Python實(shí)現(xiàn)免費(fèi)音樂(lè)下載器
本文主要為大家介紹了通過(guò)Python實(shí)現(xiàn)的免費(fèi)音樂(lè)下載器,文中的示例代碼講解詳細(xì),對(duì)我們的學(xué)習(xí)或工作有一定的幫助,需要的小伙伴可以學(xué)習(xí)一下2021-12-12