Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作
在深度學(xué)習(xí)中,遷移學(xué)習(xí)經(jīng)常被使用,在大數(shù)據(jù)集上預(yù)訓(xùn)練的模型遷移到特定的任務(wù),往往需要保持模型參數(shù)不變,而微調(diào)與任務(wù)相關(guān)的模型層。
本文主要介紹,使用tensorflow部分更新模型參數(shù)的方法。
1. 根據(jù)Variable scope剔除需要固定參數(shù)的變量
def get_variable_via_scope(scope_lst): vars = [] for sc in scope_lst: sc_variable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope=scope) vars.extend(sc_variable) return vars trainable_vars = tf.trainable_variables() no_change_scope = ['your_unchange_scope_name'] no_change_vars = get_variable_via_scope(no_change_scope) for v in no_change_vars: trainable_vars.remove(v) grads, _ = tf.gradients(loss, trainable_vars) optimizer = tf.train.AdamOptimizer(lr) train_op = optimizer.apply_gradient(zip(grads, trainable_vars), global_step=global_step)
2. 使用tf.stop_gradient()函數(shù)
在建立Graph過(guò)程中使用該函數(shù),非常簡(jiǎn)潔地避免了使用scope獲取參數(shù)
3. 一個(gè)矩陣中部分行或列參數(shù)更新
如果一個(gè)矩陣,只有部分行或列需要更新參數(shù),其它保持不變,該場(chǎng)景很常見(jiàn),例如word embedding中,一些預(yù)定義的領(lǐng)域相關(guān)詞保持不變(使用領(lǐng)域相關(guān)word embedding初始化),而另一些通用詞變化。
import tensorflow as tf import numpy as np def entry_stop_gradients(target, mask): mask_h = tf.abs(mask-1) return tf.stop_gradient(mask_h * target) + mask * target mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1]) mask_h = np.abs(mask-1) emb = tf.constant(np.ones([10, 5])) matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1)) parm = np.random.randn(5, 1) t_parm = tf.constant(parm) loss = tf.reduce_sum(tf.matmul(matrix, t_parm)) grad1 = tf.gradients(loss, emb) grad2 = tf.gradients(loss, matrix) print matrix with tf.Session() as sess: print sess.run(loss) print sess.run([grad1, grad2])
以上這篇Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- tensorflow 實(shí)現(xiàn)自定義梯度反向傳播代碼
- 有關(guān)Tensorflow梯度下降常用的優(yōu)化方法分享
- TensorFlow梯度求解tf.gradients實(shí)例
- 基于TensorFlow中自定義梯度的2種方式
- tensorflow 查看梯度方式
- tensorflow求導(dǎo)和梯度計(jì)算實(shí)例
- Tensorflow的梯度異步更新示例
- 在Tensorflow中實(shí)現(xiàn)梯度下降法更新參數(shù)值
- 運(yùn)用TensorFlow進(jìn)行簡(jiǎn)單實(shí)現(xiàn)線性回歸、梯度下降示例
- Tensorflow 卷積的梯度反向傳播過(guò)程
相關(guān)文章
python基礎(chǔ)入門學(xué)習(xí)筆記(Python環(huán)境搭建)
這篇文章主要介紹了python基礎(chǔ)入門學(xué)習(xí)筆記,這是開(kāi)啟學(xué)習(xí)python基礎(chǔ)知識(shí)的第一篇,夯實(shí)Python基礎(chǔ),才能走的更遠(yuǎn),感興趣的小伙伴們可以參考一下2016-01-01tensorflow-gpu安裝的常見(jiàn)問(wèn)題及解決方案
這篇文章主要介紹了tensorflow-gpu安裝的常見(jiàn)問(wèn)題及解決方案,本文給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友參考下吧,需要的朋友可以參考下2020-01-01python如何爬取網(wǎng)站數(shù)據(jù)并進(jìn)行數(shù)據(jù)可視化
這篇文章主要介紹了python爬取拉勾網(wǎng)數(shù)據(jù)并進(jìn)行數(shù)據(jù)可視化,爬取拉勾網(wǎng)關(guān)于python職位相關(guān)的數(shù)據(jù)信息,并將爬取的數(shù)據(jù)已csv各式存入文件,然后對(duì)csv文件相關(guān)字段的數(shù)據(jù)進(jìn)行清洗,并對(duì)數(shù)據(jù)可視化展示,包括柱狀圖展示、直方圖展示,需要的朋友可以參考下2019-07-07Python pandas如何獲取數(shù)據(jù)的行數(shù)和列數(shù)
這篇文章主要介紹了Python pandas如何獲取數(shù)據(jù)的行數(shù)和列數(shù)問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-02-02python的scikit-learn將特征轉(zhuǎn)成one-hot特征的方法
今天小編就為大家分享一篇python的scikit-learn將特征轉(zhuǎn)成one-hot特征的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-07-07python中數(shù)組array和列表list的基本用法及區(qū)別解析
大家都知道數(shù)組array是同類型數(shù)據(jù)的有限集合,列表list是一系列按特定順序排列的元素組成,可以將任何數(shù)據(jù)放入列表,且其中元素之間沒(méi)有任何關(guān)系,本文介紹python中數(shù)組array和列表list的基本用法及區(qū)別,感興趣的朋友一起看看吧2022-05-05python實(shí)現(xiàn)查找excel里某一列重復(fù)數(shù)據(jù)并且剔除后打印的方法
這篇文章主要介紹了python實(shí)現(xiàn)查找excel里某一列重復(fù)數(shù)據(jù)并且剔除后打印的方法,涉及Python使用xlrd模塊操作Excel的相關(guān)技巧,需要的朋友可以參考下2015-05-05