使用tensorflow保存和恢復(fù)模型saver.restore
tensorflow保存和恢復(fù)模型saver.restore
本文只對(duì)一些細(xì)節(jié)點(diǎn)做補(bǔ)充,大體的步驟就不詳述了
保存模型
① 首先我使用的是tensorflow-gpu 1.4.0
② 這個(gè)版本生成的ckpt文件是這樣的:
其中.meta存放的是網(wǎng)絡(luò)模型和所有的變量;
.index 和.data一起存放變量數(shù)據(jù)
-0 -500表示checkpoint點(diǎn)
③ 保存的配置(一定細(xì)看代碼注釋?。?!)
import tensorflow as tf w1 = tf.Variable(變量的初始化, name='w1') w2 = tf.Variable(變量的初始化, name='w2') saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2) # 這里是細(xì)節(jié)部分,可以指定保存的變量,每兩小時(shí)保存最近的5個(gè)模型 sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False)) # 因?yàn)槟P蜎]必要多次保存,所以寫為False
恢復(fù)模型(一定細(xì)看代碼注釋!?。?
代碼:
import tensorflow as tf with tf.Session() as sess: saver = tf.train.import_meta_graph(模型路徑) # 模型路徑中必須指定到具體的模型下如:xx.ckpt-500.meta,且一般來講,所有模型都是一樣的,如果沒有改變模型的條件下。 # 下面的restore就是在當(dāng)前的sess下恢復(fù)了所有的變量 saver.restore(sess,數(shù)據(jù)路徑) # 數(shù)據(jù)路徑也必須指定到具體某個(gè)模型的數(shù)據(jù),但創(chuàng)建這個(gè)路徑的方法很多,比如調(diào)用最后一個(gè)保存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,并且這兩個(gè)是等效的,如果是xx.ckpt-0.data,就是第一個(gè)模型的數(shù)據(jù) print(sess.run('w1:0')) # 這里的w1必須加上:0
tensorflow里的,保存和恢復(fù)模型的方式
重點(diǎn)在于,第一個(gè)文件用于 訓(xùn)練,保存圖meta和訓(xùn)練好的參數(shù)data(后綴),在另一個(gè)文件中導(dǎo)入這個(gè)圖和訓(xùn)練好的參數(shù),用于預(yù)測或者接著訓(xùn)練。
大大減少了另一個(gè)文件里的 重復(fù)
第一種情況
產(chǎn)生變量的代碼和恢復(fù)變量的代碼在同一個(gè)文件時(shí),可以直接如下調(diào)用:
# 建模型 saver = tf.train.Saver() with tf.Session() as sess: # 存模型,注意此處的model是文件名,不是路徑 saver.save(sess, "/tmp/model") with tf.Session() as sess: # 恢復(fù)模型 saver.restore(sess, "/tmp/model")
第二種情況
不想在另一個(gè)文件中,把產(chǎn)生變量的 一大堆代碼重敲一遍,可以直接從保存好的 meta文件和data文件中恢復(fù)出來
#!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2019/9/9 20:49 # @Author : ZZL # @File : 保存檢查點(diǎn)文件,并恢復(fù).py import tensorflow as tf # Saving contents and operations. v1 = tf.placeholder(tf.float32, name="v1") v2 = tf.placeholder(tf.float32, name="v2") v3 = tf.multiply(v1, v2) vx = tf.Variable(10.0, name="vx") v4 = tf.add(v3, vx, name="v4") saver = tf.train.Saver([vx]) with tf.Session() as sess: with tf.device('/cpu:0'): sess.run(tf.global_variables_initializer()) sess.run(vx.assign(tf.add(vx, vx))) result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3}) print(result) print(saver.save(sess, "./model_ex1")) # 該方法返回新創(chuàng)建的檢查點(diǎn)文件的路徑前綴。這個(gè)字符串可以直接傳遞給對(duì)“restore()”的調(diào)用。
#!/usr/bin/env python # -*- coding: utf-8 -*- # @Time : 2019/9/9 20:54 # @Author : ZZL # @File : 恢復(fù)文件.py import tensorflow as tf saver = tf.train.import_meta_graph("./model_ex1.meta") sess = tf.Session() saver.restore(sess, "./model_ex1") result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) print(result)
先來個(gè)空?qǐng)D,loaded_graph,在會(huì)話中,導(dǎo)入之前構(gòu)建好的圖的文件 后綴 meta,loader.restore(sess, save_model_path)
在當(dāng)前的loaded_graph中,導(dǎo)入構(gòu)建好的圖和圖上的變量值。
def test_model(): test_features, test_labels = pickle.load(open('preprocess_test.p', mode='rb')) loaded_graph = tf.Graph() # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320> # print( loaded_graph) # print(tf.get_default_graph()) # <tensorflow.python.framework.ops.Graph object at 0x0000017C9A0C0C50> with tf.Session(graph=loaded_graph) as sess: # 讀取模型 loader = tf.train.import_meta_graph(save_model_path + '.meta') print(loader) loader.restore(sess, save_model_path) print(tf.get_default_graph()) # <tensorflow.python.framework.ops.Graph object at 0x0000017CB3702320> # 從已經(jīng)讀入的模型中 獲取tensors loaded_x = loaded_graph.get_tensor_by_name('x:0') loaded_y = loaded_graph.get_tensor_by_name('y:0') loaded_keep_prob = loaded_graph.get_tensor_by_name('keep_prob:0') loaded_logits = loaded_graph.get_tensor_by_name('logits:0') loaded_acc = loaded_graph.get_tensor_by_name('accuracy:0') # 獲取每個(gè)batch的準(zhǔn)確率,再求平均值,這樣可以節(jié)約內(nèi)存 test_batch_acc_total = 0 test_batch_count = 0 for test_feature_batch, test_label_batch in helper.batch_features_labels(test_features, test_labels, batch_size): test_batch_acc_total += sess.run( loaded_acc, feed_dict={loaded_x: test_feature_batch, loaded_y: test_label_batch, loaded_keep_prob: 1.0}) test_batch_count += 1
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
查看keras的默認(rèn)backend實(shí)現(xiàn)方式
這篇文章主要介紹了查看keras的默認(rèn)backend實(shí)現(xiàn)方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-06-06PyQt5 QTableView設(shè)置某一列不可編輯的方法
今天小編就為大家分享一篇PyQt5 QTableView設(shè)置某一列不可編輯的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-06-06對(duì)PyQt5中樹結(jié)構(gòu)的實(shí)現(xiàn)方法詳解
今天小編就為大家分享一篇對(duì)PyQt5中樹結(jié)構(gòu)的實(shí)現(xiàn)方法詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-06-06Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)從字典中提取子集的方法示例
這篇文章主要介紹了Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)從字典中提取子集的方法,涉及Python字典推導(dǎo)式的相關(guān)使用技巧,需要的朋友可以參考下2018-03-03使用Python文件讀寫,自定義分隔符(custom delimiter)
這篇文章主要介紹了使用Python文件讀寫,自定義分隔符(custom delimiter),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-07-07pycharm調(diào)試功能如何實(shí)現(xiàn)跳到循環(huán)的某一步
這篇文章主要介紹了pycharm調(diào)試功能如何實(shí)現(xiàn)跳到循環(huán)的某一步問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-08-08Pytorch精準(zhǔn)記錄函數(shù)運(yùn)行時(shí)間的方法
參考Pytorch官方文檔對(duì)CUDA的描述,GPU的運(yùn)算是異步執(zhí)行的,一般來說,異步計(jì)算的效果對(duì)于調(diào)用者來說是不可見的,異步計(jì)算的后果是,沒有同步的時(shí)間測量是不準(zhǔn)確的,所以本文給大家介紹了Pytorch如何精準(zhǔn)記錄函數(shù)運(yùn)行時(shí)間,需要的朋友可以參考下2024-11-11python 多線程將大文件分開下載后在合并的實(shí)例
今天小編就為大家分享一篇python 多線程將大文件分開下載后在合并的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-11-11