tensorflow模型繼續(xù)訓(xùn)練 fineturn實(shí)例
解決tensoflow如何在已訓(xùn)練模型上繼續(xù)訓(xùn)練fineturn的問題。
訓(xùn)練代碼
任務(wù)描述: x = 3.0, y = 100.0, 運(yùn)算公式 x×W+b = y,求 W和b的最優(yōu)解。
# -*- coding: utf-8 -*-) import tensorflow as tf # 聲明占位變量x、y x = tf.placeholder("float", shape=[None, 1]) y = tf.placeholder("float", [None, 1]) # 聲明變量 W = tf.Variable(tf.zeros([1, 1]),name='w') b = tf.Variable(tf.zeros([1]),name='b') # 操作 result = tf.matmul(x, W) + b # 損失函數(shù) lost = tf.reduce_sum(tf.pow((result - y), 2)) # 優(yōu)化 train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost) with tf.Session() as sess: # 初始化變量 sess.run(tf.global_variables_initializer()) saver = tf.train.Saver(max_to_keep=3) # 這里x、y給固定的值 x_s = [[3.0]] y_s = [[100.0]] step = 0 while (True): step += 1 feed = {x: x_s, y: y_s} # 通過sess.run執(zhí)行優(yōu)化 sess.run(train_step, feed_dict=feed) if step % 1000 == 0: print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed)) if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3: print '' # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed)) print 'final result of {0} = {1}(目標(biāo)值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b)) print '' print("模型保存的W值 : %f" % sess.run(W)) print("模型保存的b : %f" % sess.run(b)) break saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型
訓(xùn)練完成之后生成模型文件:
訓(xùn)練輸出:
step: 1000, loss: 4.89526428282e-08 step: 2000, loss: 4.89526428282e-08 step: 3000, loss: 4.89526428282e-08 step: 4000, loss: 4.89526428282e-08 step: 5000, loss: 4.89526428282e-08 final result of x×W+b = [[99.99978]](目標(biāo)值是100.0) 模型保存的W值 : 29.999931 模型保存的b : 9.999982
保存在模型中的W值是 29.999931,b是 9.999982。
以下代碼從保存的模型中恢復(fù)出訓(xùn)練狀態(tài),繼續(xù)訓(xùn)練
任務(wù)描述: x = 3.0, y = 200.0, 運(yùn)算公式 x×W+b = y,從上次訓(xùn)練的模型中恢復(fù)出訓(xùn)練參數(shù),繼續(xù)訓(xùn)練,求 W和b的最優(yōu)解。
# -*- coding: utf-8 -*-) import tensorflow as tf # 聲明占位變量x、y x = tf.placeholder("float", shape=[None, 1]) y = tf.placeholder("float", [None, 1]) with tf.Session() as sess: # 初始化變量 sess.run(tf.global_variables_initializer()) # saver = tf.train.Saver(max_to_keep=3) saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加載模型圖結(jié)構(gòu) saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢復(fù)數(shù)據(jù) # 從保存模型中恢復(fù)變量 graph = tf.get_default_graph() W = graph.get_tensor_by_name("w:0") b = graph.get_tensor_by_name("b:0") print("從保存的模型中恢復(fù)出來的W值 : %f" % sess.run("w:0")) print("從保存的模型中恢復(fù)出來的b值 : %f" % sess.run("b:0")) # 操作 result = tf.matmul(x, W) + b # 損失函數(shù) lost = tf.reduce_sum(tf.pow((result - y), 2)) # 優(yōu)化 train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost) # 這里x、y給固定的值 x_s = [[3.0]] y_s = [[200.0]] step = 0 while (True): step += 1 feed = {x: x_s, y: y_s} # 通過sess.run執(zhí)行優(yōu)化 sess.run(train_step, feed_dict=feed) if step % 1000 == 0: print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed)) if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3: print '' # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed)) print 'final result of {0} = {1}(目標(biāo)值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b)) print("模型保存的W值 : %f" % sess.run(W)) print("模型保存的b : %f" % sess.run(b)) break saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型
訓(xùn)練輸出:
從保存的模型中恢復(fù)出來的W值 : 29.999931 從保存的模型中恢復(fù)出來的b值 : 9.999982 step: 1000, loss: 1.95810571313e-07 step: 2000, loss: 1.95810571313e-07 step: 3000, loss: 1.95810571313e-07 step: 4000, loss: 1.95810571313e-07 step: 5000, loss: 1.95810571313e-07 final result of x×W+b = [[199.99956]](目標(biāo)值是200.0) 模型保存的W值 : 59.999866 模型保存的b : 19.999958
從保存的模型中恢復(fù)出來的W值是 29.999931,b是 9.999982,跟模型保存的值一致,說明加載成功。
總結(jié)
從頭開始訓(xùn)練一個模型,需要通過 tf.train.Saver創(chuàng)建一個保存器,完成之后使用save方法保存模型到本地:
saver = tf.train.Saver(max_to_keep=3) …… saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型
在訓(xùn)練好的模型上繼續(xù)訓(xùn)練,fineturn一個模型,可以使用tf.train.import_meta_graph方法加載圖結(jié)構(gòu),使用restore方法恢復(fù)訓(xùn)練數(shù)據(jù),最后使用同樣的save方法保存到本地:
saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加載模型圖結(jié)構(gòu) saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢復(fù)數(shù)據(jù) saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型
注:特殊情況下(如本例)需要從恢復(fù)的模型中加載出數(shù)據(jù):
# 從保存模型中恢復(fù)變量 graph = tf.get_default_graph() W = graph.get_tensor_by_name("w:0") b = graph.get_tensor_by_name("b:0")
以上這篇tensorflow模型繼續(xù)訓(xùn)練 fineturn實(shí)例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
- Tensorflow訓(xùn)練MNIST手寫數(shù)字識別模型
- tensorflow實(shí)現(xiàn)訓(xùn)練變量checkpoint的保存與讀取
- Tensorflow訓(xùn)練模型越來越慢的2種解決方案
- TensorFlow實(shí)現(xiàn)保存訓(xùn)練模型為pd文件并恢復(fù)
- 解決TensorFlow訓(xùn)練內(nèi)存不斷增長,進(jìn)程被殺死問題
- tensorflow獲取預(yù)訓(xùn)練模型某層參數(shù)并賦值到當(dāng)前網(wǎng)絡(luò)指定層方式
- tensorflow如何繼續(xù)訓(xùn)練之前保存的模型實(shí)例
- Tensorflow實(shí)現(xiàn)在訓(xùn)練好的模型上進(jìn)行測試
- tensorflow保持每次訓(xùn)練結(jié)果一致的簡單實(shí)現(xiàn)
相關(guān)文章
python實(shí)現(xiàn)拉普拉斯特征圖降維示例
今天小編就為大家分享一篇python實(shí)現(xiàn)拉普拉斯特征圖降維示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-11-11jupyter notebook運(yùn)行代碼沒反應(yīng)且in[ ]沒有*
本文主要介紹了jupyter notebook運(yùn)行代碼沒反應(yīng)且in[ ]沒有*,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2022-03-03Python+Dlib+Opencv實(shí)現(xiàn)人臉采集并表情判別功能的代碼
這篇文章主要介紹了Python+Dlib+Opencv實(shí)現(xiàn)人臉采集并表情判別,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-07-07PyQt5 QTreeWidget 樹形結(jié)構(gòu)遞歸遍歷當(dāng)前所有節(jié)點(diǎn)的實(shí)現(xiàn)
Qt中實(shí)現(xiàn)樹形結(jié)構(gòu)可以使用QTreeWidget類,也可以使用QTreeView類,本文主要介紹了PyQt5 QTreeWidget 樹形結(jié)構(gòu)遞歸遍歷當(dāng)前所有節(jié)點(diǎn)的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2021-11-11探索Python內(nèi)置數(shù)據(jù)類型的精髓與應(yīng)用
本文探索Python內(nèi)置數(shù)據(jù)類型的精髓與應(yīng)用,包括字符串、列表、元組、字典和集合。通過深入了解它們的特性、操作和常見用法,讀者將能夠更好地利用這些數(shù)據(jù)類型解決實(shí)際問題。2023-09-09