tensorflow輸出權(quán)重值和偏差的方法
使用tensorflow 訓(xùn)練模型時(shí),我們可以使用 tensorflow自帶的 Save模塊 tf.train.Saver()來保存模型,使用方式很簡單 就是在訓(xùn)練完模型后,調(diào)用saver.save()即可
saver = tf.train.Saver(write_version=tf.train.SaverDef.V2) saver.save(sess, save_dir+"crfmodel.ckpt", global_step=0)
重新載入模型
saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model) saver.restore(sess, ckpt.model_checkpoint_path)
但是這種方式保存的模型中包含特別多的信息,使保存的模型很大,其實(shí)里面有很多不是我們想要的.我們就想要里面最重要的權(quán)重信息和偏差等等數(shù)據(jù),然后再自己寫解密代碼,就可以把模型應(yīng)用于其他的平臺(tái),比如安卓手機(jī).
那么我們可以使用下面的方式獲取訓(xùn)練后的權(quán)重和偏移,
ww, bb = sess.run([self.W,self.b])
其中W,和b都是 Tensor類型的數(shù)據(jù)
with tf.name_scope('weights'): self.W = tf.get_variable( shape=[self.feat_size, self.nb_classes], initializer=tf.truncated_normal_initializer(stddev=0.01), name='weights' # ,regularizer=tf.contrib.layers.l1_regularizer(0.1) ) with tf.name_scope('biases'): self.b = tf.get_variable( shape=[self.nb_classes], initializer=tf.truncated_normal_initializer(stddev=0.01), name='bias' )
tensorflow 輸出權(quán)重 到csv或txt
import numpy as np W_val, b_val = sess.run([weights_tensor, biases_tensor]) np.savetxt("W.csv", W_val, delimiter=",") np.savetxt("b.csv", b_val, delimiter=",")
以上就是本文的全部內(nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
python實(shí)現(xiàn)自動(dòng)生成C++代碼的代碼生成器
這篇文章介紹了python實(shí)現(xiàn)C++代碼生成器的方法,文中通過示例代碼介紹的非常詳細(xì)。對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-07-07使用python PIL庫批量對(duì)圖片添加水印的過程詳解
平常我們想給某些圖片添加文字水印,方法有很多,也有很多的工具可以方便的進(jìn)行,今天主要是對(duì)PIL庫的應(yīng)用,結(jié)合Python語言批量對(duì)圖片添加水印,文章通過代碼示例給大家介紹的非常詳細(xì),感興趣的同學(xué)可以參考一下2023-11-11Selenium自動(dòng)化測試實(shí)現(xiàn)窗口切換
這篇文章主要介紹了Selenium自動(dòng)化測試實(shí)現(xiàn)窗口切換,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-03-03linux下python使用sendmail發(fā)送郵件
這篇文章主要為大家詳細(xì)介紹了linux下python使用sendmail發(fā)送郵件,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-05-05