TensorFlow saver指定變量的存取
今天和大家分享一下用TensorFlow的saver存取訓(xùn)練好的模型那點事。
1. 用saver存取變量;
2. 用saver存取指定變量。
用saver存取變量。
話不多說,先上代碼
# coding=utf-8 import os import tensorflow as tf import numpy os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集沒有裝,加這個不顯示那些警告 w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32) b = tf.Variable([[4,5,6]],dtype=tf.float32,) s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32) init = tf.global_variables_initializer() saver =tf.train.Saver() with tf.Session() as sess: sess.run(init) save_path = saver.save(sess, "save_net.ckpt")#路徑可以自己定 print("save to path:",save_path)
這里我隨便定義了幾個變量然后進行存操作,運行后,變量w,b,s會被保存下來。保存會生成如下幾個文件:
- cheakpoint
- save_net.ckpt.data-*
- save_net.ckpt.index
- save_net.ckpt.meta
接下來是讀取的代碼
import tensorflow as tf import os import numpy as np os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32) b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32) a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32) saver =tf.train.Saver() with tf.Session() as sess: saver.restore(sess,'save_net.ckpt') print ("weights",sess.run(w)) print ("b",sess.run(b)) print ("s",sess.run(a))
在寫讀取代碼時要注意變量定義的類型、大小和變量的數(shù)量以及順序等要與存的時候一致,不然會報錯。你存的時候順序是w,b,s,取的時候同樣這個順序。存的時候w定義了dtype沒有 定義name,取的時候同樣要這樣,因為TensorFlow存取是按照鍵值對來存取的,所以必須一致。這里變量名,也就是w,s之類可以不同。
如下是我成功讀取的效果
用saver存取指定變量。
在我們做訓(xùn)練時候,有些變量是沒有必要保存的,但是如果直接用tf.train.Saver()。程序會將所有的變量保存下來,這時候我們可以指定保存,只保存我們需要的變量,其他的統(tǒng)統(tǒng)丟掉。
其實很簡單,只需要在上面代碼基礎(chǔ)上稍加修改,只需把tf.train.Saver()替換成如下代碼
program = [] program += [w,b] tf.train.Saver(program)
這樣,程序就只會存w和b了。同樣,讀取程序里面的tf.train.Saver()也要做如上修改。dtype,name之類依舊必須一致。
最后附上最終代碼:
# coding=utf-8 # saver保存變量測試 import os import tensorflow as tf import numpy os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集沒有裝,加這個不顯示那些警告 w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32) b = tf.Variable([[4,5,6]],dtype=tf.float32,) s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32) init = tf.global_variables_initializer() program = [] program += [w, b] saver =tf.train.Saver(program) with tf.Session() as sess: sess.run(init) save_path = saver.save(sess, "save_net.ckpt")#路徑可以自己定 print("save to path:",save_path)
#saver提取變量測試 import tensorflow as tf import os import numpy as np os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32) b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32) a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32) program = [] program +=[w,b] saver =tf.train.Saver(program) with tf.Session() as sess: saver.restore(sess,'save_net.ckpt') print ("weights",sess.run(w)) print ("b",sess.run(b)) #print ("s",sess.run(a))
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
關(guān)于django 數(shù)據(jù)庫遷移(migrate)應(yīng)該知道的一些事
今天小編就為大家分享一篇關(guān)于django 數(shù)據(jù)庫遷移(migrate)應(yīng)該知道的一些事,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-05-05Python中DataFrame與內(nèi)置數(shù)據(jù)結(jié)構(gòu)相互轉(zhuǎn)換的實現(xiàn)
pandas?支持我們從?Excel、CSV、數(shù)據(jù)庫等不同數(shù)據(jù)源當(dāng)中讀取數(shù)據(jù),來構(gòu)建?DataFrame。但有時數(shù)據(jù)并不來自這些外部數(shù)據(jù)源,這就涉及到了?DataFrame?和?Python?內(nèi)置數(shù)據(jù)結(jié)構(gòu)之間的相互轉(zhuǎn)換,本文就來和大家詳細(xì)聊聊2023-02-02Python根據(jù)字符串調(diào)用函數(shù)過程解析
這篇文章主要介紹了Python根據(jù)字符串調(diào)用函數(shù)過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-11-11python3 webp轉(zhuǎn)gif格式的實現(xiàn)示例
這篇文章主要介紹了python3 webp轉(zhuǎn)gif格式的實現(xiàn)示例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12Python?Prometheus接口揭秘數(shù)據(jù)科學(xué)新技巧
本篇文章將分享Prometheus?API的基本概念到PromQL查詢語言的應(yīng)用,再到如何通過Python與Prometheus?API進行無縫交互,通過豐富的示例代碼和詳細(xì)的講解,將解鎖使用Python進行實時監(jiān)控的奇妙世界,為讀者打開更廣闊的數(shù)據(jù)分析視野2024-01-01