tensorflow 加載部分變量的實例講解
tensorflow模型保存為saver = tf.train.Saver()函數(shù),saver.save()保存模型,代碼如下:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") saver = tf.train.Saver() with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) saver.save(sess,"checkpoint/model_test",global_step=1)
當我們保存模型后,我們可以通過saver.restore()來加載模型,初始化變量:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") saver = tf.train.Saver() with tf.Session() as sess: # init_op = tf.global_variables_initializer() # sess.run(init_op) saver.restore(sess, "checkpoint/model_test-1") # saver.save(sess,"checkpoint/model_test",global_step=1)
神經網絡訓練時,有時候我們需要從預訓練的模型中加載部分參數(shù),初始化當前模型,例如加入CNN有6層,我們需要從已有的模型初始化CNN前5層參數(shù).這可以通過saver.restore()實現(xiàn).
之前我們已經介紹可以通過tf.train.Saver()的保存部分變量的方法,即需要保存的變量列表,同樣的,在變量初始化的時候,我們可以對需要單獨初始化的變量分別定義一個tf.train.Saver()函數(shù),這樣就可以單獨對該部分變量初始化,例如下面代碼,saver1用于初始化變量v1,saver2用于初始化變量v2,v3:
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") v3= tf.Variable(tf.zeros([100]), name="v3") #saver = tf.train.Saver() saver1 = tf.train.Saver([v1]) saver2 = tf.train.Saver([v2]+[v3]) with tf.Session() as sess: # init_op = tf.global_variables_initializer() # sess.run(init_op) saver1.restore(sess, "checkpoint/model_test-1") saver2.restore(sess, "checkpoint/model_test-1") # saver.save(sess,"checkpoint/model_test",global_step=1)
以上這篇tensorflow 加載部分變量的實例講解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python按順序重命名文件并分類轉移到各個文件夾中的實現(xiàn)代碼
這篇文章主要介紹了python按順序重命名文件并分類轉移到各個文件夾中,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-07-07Centos環(huán)境部署django項目的全過程(永久復用)
Django是一款針對Python環(huán)境的WEB開發(fā)框架,能夠幫助我們構架快捷,下面這篇文章主要給大家介紹了關于Centos環(huán)境部署django項目的相關資料,需要的朋友可以參考下2022-10-10Python實現(xiàn)基于TCP UDP協(xié)議的IPv4 IPv6模式客戶端和服務端功能示例
這篇文章主要介紹了Python實現(xiàn)基于TCP UDP協(xié)議的IPv4 IPv6模式客戶端和服務端功能,結合實例形式分析了Python基于TCP UDP協(xié)議的IPv4 IPv6模式客戶端和服務端數(shù)據(jù)發(fā)送與接收相關操作技巧,需要的朋友可以參考下2018-03-03基于Python?OpenCV和?dlib實現(xiàn)眨眼檢測
這篇文章主要介紹了基于Python?OPenCV及dlib實現(xiàn)檢測視頻流中的眨眼次數(shù)。文中的代碼對我們的學習和工作有一定價值,感興趣的同學可以參考一下2021-12-12