TensorFlow實(shí)現(xiàn)模型斷點(diǎn)訓(xùn)練,checkpoint模型載入方式
深度學(xué)習(xí)中,模型訓(xùn)練一般都需要很長(zhǎng)的時(shí)間,由于很多原因,導(dǎo)致模型中斷訓(xùn)練,下面介紹繼續(xù)斷點(diǎn)訓(xùn)練的方法。
方法一:載入模型時(shí),不必指定迭代次數(shù),一般默認(rèn)最新
# 保存模型 saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型 # 開(kāi)啟會(huì)話 with tf.Session() as sess: # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000)) sess.run(tf.global_variables_initializer()) ckpt = tf.train.get_checkpoint_state('./log/') # 注意此處是checkpoint存在的目錄,千萬(wàn)不要寫(xiě)成‘./log' if ckpt and ckpt.model_checkpoint_path: saver.restore(sess,ckpt.model_checkpoint_path) # 自動(dòng)恢復(fù)model_checkpoint_path保存模型一般是最新 print("Model restored...") else: print('No Model')
方法二:載入時(shí),指定想要載入模型的迭代次數(shù)
需要到Log文件夾中,查看當(dāng)前迭代的次數(shù),如下:此時(shí)為111000次。
# 保存模型 saver = tf.train.Saver(max_to_keep=1) # 開(kāi)啟會(huì)話 with tf.Session() as sess: saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000)) sess.run(tf.global_variables_initializer())
載入模型后,會(huì)繼續(xù)端點(diǎn)處的變量繼續(xù)訓(xùn)練,那么是否可以減小剩余的需要的迭代次數(shù)?
模型斷點(diǎn)訓(xùn)練效果展示:
訓(xùn)練到167000次后,載入模型重新訓(xùn)練。設(shè)置迭代次數(shù)為10000次,(d_step=1000)。原始設(shè)置的迭代的次數(shù)為1000000,已經(jīng)訓(xùn)練了167000次。
Model restored... Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713 Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298 Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354 Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371 Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707 Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368 Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039 Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496 Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996 Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727 done!
保存的圖片仍然從頭開(kāi)始編號(hào),會(huì)覆蓋掉之前的圖片。
以前對(duì)應(yīng)編號(hào)的采樣圖片為:
若有朋友有高見(jiàn),還請(qǐng)不吝賜教。
補(bǔ)充知識(shí):tensorflow加載訓(xùn)練好的模型及參數(shù)(讀取checkpoint)
checkpoint 保存路徑
model_path下存有包含多個(gè)迭代次數(shù)的模型
1.獲取最新保存的模型
即上圖中的model-9400
import tensorflow as tf graph=tf.get_default_graph() # 獲取當(dāng)前圖 sess=tf.Session() sess.run(tf.global_variables_initializer()) checkpoint_file=tf.train.latest_checkpoint(model_path) saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess,checkpoint_file)
2.獲取某個(gè)迭代次數(shù)的模型
比如上圖中的model-9200
import tensorflow as tf graph=tf.get_default_graph() # 獲取當(dāng)前圖 sess=tf.Session() sess.run(tf.global_variables_initializer()) checkpoint_file=os.path.join(model_path,'model-9200') saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess,checkpoint_file)
獲取變量值
## 得到當(dāng)前圖中所有變量的名稱 tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] # 查看所有變量 print(tensor_name_list) # 獲取input_x和input_y的變量值 input_x = graph.get_operation_by_name("input_x").outputs[0] input_y = graph.get_operation_by_name("input_y").outputs[0]
以上這篇TensorFlow實(shí)現(xiàn)模型斷點(diǎn)訓(xùn)練,checkpoint模型載入方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python Web框架之Django框架cookie和session用法分析
這篇文章主要介紹了Python Web框架之Django框架cookie和session用法,結(jié)合實(shí)例形式分析了Django框架cookie和session的常見(jiàn)使用技巧與操作注意事項(xiàng),需要的朋友可以參考下2019-08-08如何使用PyCharm將代碼上傳到GitHub上(圖文詳解)
這篇文章主要介紹了如何使用PyCharm將代碼上傳到GitHub上(圖文詳解),文中通過(guò)圖文介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04python庫(kù)pycryptodom加密技術(shù)探索(公鑰加密私鑰加密)
這篇文章主要為大家介紹了python庫(kù)pycryptodom加密技術(shù)探索(公鑰加密私鑰加密),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2024-01-01Python GUI編程學(xué)習(xí)筆記之tkinter事件綁定操作詳解
這篇文章主要介紹了Python GUI編程學(xué)習(xí)筆記之tkinter事件綁定操作,結(jié)合實(shí)例形式分析了Python GUI編程tkinter事件綁定常見(jiàn)操作技巧與使用注意事項(xiàng),需要的朋友可以參考下2020-03-03Python爬蟲(chóng)DOTA排行榜爬取實(shí)例(分享)
下面小編就為大家?guī)?lái)一篇Python爬蟲(chóng)DOTA排行榜爬取實(shí)例(分享)。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-06-06