tensorflow如何繼續(xù)訓練之前保存的模型實例
一:需重定義神經(jīng)網(wǎng)絡繼續(xù)訓練的方法
1.訓練代碼
import numpy as np import tensorflow as tf x_data=np.random.rand(100).astype(np.float32) y_data=x_data*0.1+0.3 weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w") biases=tf.Variable(tf.zeros([1]),name="b") y=weight*x_data+biases loss=tf.reduce_mean(tf.square(y-y_data)) #loss optimizer=tf.train.GradientDescentOptimizer(0.5) train=optimizer.minimize(loss) init=tf.global_variables_initializer() sess=tf.Session() sess.run(init) saver=tf.train.Saver(max_to_keep=0) for step in range(10): sess.run(train) saver.save(sess,"./save_mode",global_step=step) #保存 print("當前進行:",step)
第一次訓練截圖:
2.恢復上一次的訓練
import numpy as np import tensorflow as tf sess=tf.Session() saver=tf.train.import_meta_graph(r'save_mode-9.meta') saver.restore(sess,tf.train.latest_checkpoint(r'./')) print(sess.run("w:0"),sess.run("b:0")) graph=tf.get_default_graph() weight=graph.get_tensor_by_name("w:0") biases=graph.get_tensor_by_name("b:0") x_data=np.random.rand(100).astype(np.float32) y_data=x_data*0.1+0.3 y=weight*x_data+biases loss=tf.reduce_mean(tf.square(y-y_data)) optimizer=tf.train.GradientDescentOptimizer(0.5) train=optimizer.minimize(loss) saver=tf.train.Saver(max_to_keep=0) for step in range(10): sess.run(train) saver.save(sess,r"./save_new_mode",global_step=step) print("當前進行:",step," ",sess.run(weight),sess.run(biases))
使用上次保存下的數(shù)據(jù)進行繼續(xù)訓練和保存:
#最后要提一下的是:
checkpoint文件
meta保存了TensorFlow計算圖的結構信息
datat保存每個變量的取值
index保存了 表
加載restore時的文件路徑名是以checkpoint文件中的“model_checkpoint_path”值決定的
這個方法需要重新定義神經(jīng)網(wǎng)絡
二:不需要重新定義神經(jīng)網(wǎng)絡的方法:
在上面訓練的代碼中加入:tf.add_to_collection("name",參數(shù))
import numpy as np import tensorflow as tf x_data=np.random.rand(100).astype(np.float32) y_data=x_data*0.1+0.3 weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w") biases=tf.Variable(tf.zeros([1]),name="b") y=weight*x_data+biases loss=tf.reduce_mean(tf.square(y-y_data)) optimizer=tf.train.GradientDescentOptimizer(0.5) train=optimizer.minimize(loss) tf.add_to_collection("new_way",train) init=tf.global_variables_initializer() sess=tf.Session() sess.run(init) saver=tf.train.Saver(max_to_keep=0) for step in range(10): sess.run(train) saver.save(sess,"./save_mode",global_step=step) print("當前進行:",step)
在下面的載入代碼中加入:tf.get_collection("name"),就可以直接使用了
import numpy as np import tensorflow as tf sess=tf.Session() saver=tf.train.import_meta_graph(r'save_mode-9.meta') saver.restore(sess,tf.train.latest_checkpoint(r'./')) print(sess.run("w:0"),sess.run("b:0")) graph=tf.get_default_graph() weight=graph.get_tensor_by_name("w:0") biases=graph.get_tensor_by_name("b:0") y=tf.get_collection("new_way")[0] saver=tf.train.Saver(max_to_keep=0) for step in range(10): sess.run(y) saver.save(sess,r"./save_new_mode",global_step=step) print("當前進行:",step," ",sess.run(weight),sess.run(biases))
總的來說,下面這種方法好像是要便利一些
以上這篇tensorflow如何繼續(xù)訓練之前保存的模型實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
使用Python3 poplib模塊刪除服務器多天前的郵件實現(xiàn)代碼
這篇文章主要介紹了使用Python3 poplib模塊刪除多天前的郵件的實現(xiàn)代碼,代碼簡單易懂,非常不錯,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04在Python的Django框架的視圖中使用Session的方法
這篇文章主要介紹了在Python的Django框架的視圖中使用Session的方法,包括相關的設置測試Cookies的方法,需要的朋友可以參考下2015-07-07Python實現(xiàn)方便使用的級聯(lián)進度信息實例
這篇文章主要介紹了Python實現(xiàn)方便使用的級聯(lián)進度信息,實例分析了Python顯示級聯(lián)進度信息的相關技巧,非常具有實用價值,需要的朋友可以參考下2015-05-05在Django的視圖中使用數(shù)據(jù)庫查詢的方法
這篇文章主要介紹了在Django的視圖中使用數(shù)據(jù)庫查詢的方法,是Python的Django框架使用的基礎操作,需要的朋友可以參考下2015-07-07python 如何用map()函數(shù)創(chuàng)建多線程任務
這篇文章主要介紹了python 使用map()函數(shù)創(chuàng)建多線程任務的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2021-04-04