欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

tensorflow實現(xiàn)從.ckpt文件中讀取任意變量

 更新時間:2020年05月26日 10:01:17   作者:黑龍江小伙er  
這篇文章主要介紹了tensorflow實現(xiàn)從.ckpt文件中讀取任意變量,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

思路有些混亂,希望大家能理解我的意思。

看了faster rcnn的tensorflow代碼,關(guān)于fix_variables的作用我不是很明白,所以寫了以下代碼,讀取了預(yù)訓(xùn)練模型vgg16得fc6和fc7的參數(shù),以及faster rcnn中heat_to_tail中的fc6和fc7,將它們做了對比,發(fā)現(xiàn)結(jié)果不一樣,說明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被訓(xùn)練。

具體讀取任意變量的代碼如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路徑
name_variable_to_restore = 'vgg_16/fc7/weights' #要讀取權(quán)重的變量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #輸出這個變量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定義接收權(quán)重的變量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定義恢復(fù)變量的對象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必須初始化
restorer_fc.restore(sess, file_name) #恢復(fù)變量
print(sess.run(fc7_conv)) #輸出結(jié)果

用以上的代碼分別讀取兩個網(wǎng)絡(luò)的fc6 和 fc7 ,對應(yīng)參數(shù)尺寸和權(quán)值都不同,但參數(shù)量相同。

再看lib/nets/vgg16.py中的:

(注意注釋)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定義接收權(quán)重的變量,不可被訓(xùn)練
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定義恢復(fù)變量的對象
   restorer_fc.restore(sess, pretrained_model) #恢復(fù)這些變量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #將vgg16中的fc6、fc7中的權(quán)重reshape賦給faster-rcnn中的fc6、fc7

我的理解:faster rcnn的網(wǎng)絡(luò)繼承了分類網(wǎng)絡(luò)的特征提取權(quán)重和分類器的權(quán)重,讓網(wǎng)絡(luò)從一個比較好的起點開始被訓(xùn)練,有利于訓(xùn)練結(jié)果的快速收斂。

補充知識:TensorFlow:加載部分ckpt文件變量&不同命名空間中加載模型

TensorFlow中,在加載和保存模型時,一般會直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,當需要選擇性加載模型參數(shù)時,則需要利用pywrap_tensorflow讀取模型,分析模型內(nèi)的變量關(guān)系。

例子:Faster-RCNN中,模型加載vgg16.ckpt,需要利用pywrap_tensorflow讀取ckpt文件中的參數(shù)

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此處構(gòu)建vgg16模型
variables = tf.global_variables()#獲取模型中所有變量
 
file_name='vgg16.ckpt'#vgg16網(wǎng)絡(luò)模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#獲取ckpt模型中的變量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空間名
variables_to_restore={}#構(gòu)建字典:需要的變量和對應(yīng)的模型變量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#將需要加載的變量作為參數(shù)輸入
restorer.restore(sess, file_name)

實際中,F(xiàn)aster RCNN中所構(gòu)建的vgg16網(wǎng)絡(luò)的fc6和fc7權(quán)重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7權(quán)重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上這篇tensorflow實現(xiàn)從.ckpt文件中讀取任意變量就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • flask框架自定義過濾器示例【markdown文件讀取和展示功能】

    flask框架自定義過濾器示例【markdown文件讀取和展示功能】

    這篇文章主要介紹了flask框架自定義過濾器,結(jié)合實例形式分析了flask基于自定義過濾器實現(xiàn)markdown文件讀取和展示功能相關(guān)操作技巧,需要的朋友可以參考下
    2019-11-11
  • Pytorch數(shù)據(jù)類型Tensor張量操作的實現(xiàn)

    Pytorch數(shù)據(jù)類型Tensor張量操作的實現(xiàn)

    本文主要介紹了Pytorch數(shù)據(jù)類型Tensor張量操作的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-07-07
  • Python中順序表的實現(xiàn)簡單代碼分享

    Python中順序表的實現(xiàn)簡單代碼分享

    這篇文章主要介紹了Python中順序表的實現(xiàn)簡單代碼分享,展示了代碼運行結(jié)果,然后分享了相關(guān)實例代碼,具有一定借鑒價值,需要的朋友可以參考下
    2018-01-01
  • python深度學(xué)習(xí)tensorflow安裝調(diào)試教程

    python深度學(xué)習(xí)tensorflow安裝調(diào)試教程

    這篇文章主要為大家介紹了python深度學(xué)習(xí)tensorflow安裝調(diào)試教程示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2022-06-06
  • Windows10下Tensorflow2.0 安裝及環(huán)境配置教程(圖文)

    Windows10下Tensorflow2.0 安裝及環(huán)境配置教程(圖文)

    這篇文章主要介紹了Windows10下Tensorflow2.0 安裝及環(huán)境配置教程(圖文),文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-11-11
  • pandas如何使用列表和字典創(chuàng)建?Series

    pandas如何使用列表和字典創(chuàng)建?Series

    這篇文章主要介紹了pandas如何使用列表和字典創(chuàng)建?Series,pandas 是基于NumPy的一種工具,該工具是為解決數(shù)據(jù)分析任務(wù)而創(chuàng)建的,下文我們就來看看文章是怎樣介紹pandas,需要的朋友也可以參考一下
    2021-12-12
  • Python使用Selenium+BeautifulSoup爬取淘寶搜索頁

    Python使用Selenium+BeautifulSoup爬取淘寶搜索頁

    這篇文章主要為大家詳細介紹了Python使用Selenium+BeautifulSoup爬取淘寶搜索頁,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-02-02
  • Python使用xpath對解析內(nèi)容進行數(shù)據(jù)提取

    Python使用xpath對解析內(nèi)容進行數(shù)據(jù)提取

    XPath 使用路徑表達式來選取HTML/ XML 文檔中的節(jié)點或節(jié)點集,節(jié)點是通過沿著路徑 (path) 或者步 (steps) 來選取的,本文將給大家介紹Python使用xpath對解析內(nèi)容進行數(shù)據(jù)提取的方法,需要的朋友可以參考下
    2024-05-05
  • Python中最強大的錯誤重試庫(tenacity庫)

    Python中最強大的錯誤重試庫(tenacity庫)

    本文要給大家介紹的tenacity庫,可能是目前Python生態(tài)中最好用的錯誤重試庫,主要介紹tenacity的主要使用方法和特性,具有一定的參考價值,感興趣的可以了解一下
    2022-04-04
  • Python3內(nèi)置模塊之base64編解碼方法詳解

    Python3內(nèi)置模塊之base64編解碼方法詳解

    這篇文章主要介紹了Python3內(nèi)置模塊之base64編解碼方法詳解,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-07-07

最新評論