欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

基于TensorFlow中自定義梯度的2種方式

 更新時(shí)間:2020年02月04日 11:07:28   作者:FesianXu  
今天小編就為大家分享一篇基于TensorFlow中自定義梯度的2種方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧

前言

在深度學(xué)習(xí)中,有時(shí)候我們需要對(duì)某些節(jié)點(diǎn)的梯度進(jìn)行一些定制,特別是該節(jié)點(diǎn)操作不可導(dǎo)(比如階梯除法如 ),如果實(shí)在需要對(duì)這個(gè)節(jié)點(diǎn)進(jìn)行操作,而且希望其可以反向傳播,那么就需要對(duì)其進(jìn)行自定義反向傳播時(shí)的梯度。在有些場(chǎng)景,如[2]中介紹到的梯度反轉(zhuǎn)(gradient inverse)中,就必須在某層節(jié)點(diǎn)對(duì)反向傳播的梯度進(jìn)行反轉(zhuǎn),也就是需要更改正常的梯度傳播過程,如下圖的 所示。

在tensorflow中有若干可以實(shí)現(xiàn)定制梯度的方法,這里介紹兩種。

1. 重寫梯度法

重寫梯度法指的是通過tensorflow自帶的機(jī)制,將某個(gè)節(jié)點(diǎn)的梯度重寫(override),這種方法的適用性最廣。我們這里舉個(gè)例子[3].

符號(hào)函數(shù)的前向傳播采用的是階躍函數(shù)y=sign(x) y = \rm{sign}(x)y=sign(x),如下圖所示,我們知道階躍函數(shù)不是連續(xù)可導(dǎo)的,因此我們?cè)诜聪騻鞑r(shí),將其替代為一個(gè)可以連續(xù)求導(dǎo)的函數(shù)y=Htanh(x) y = \rm{Htanh(x)}y=Htanh(x),于是梯度就是大于1和小于-1時(shí)為0,在-1和1之間時(shí)是1。

使用重寫梯度的方法如下,主要是涉及到tf.RegisterGradient()和tf.get_default_graph().gradient_override_map(),前者注冊(cè)新的梯度,后者重寫圖中具有名字name='Sign'的操作節(jié)點(diǎn)的梯度,用在新注冊(cè)的QuantizeGrad替代。

#使用修飾器,建立梯度反向傳播函數(shù)。其中op.input包含輸入值、輸出值,grad包含上層傳來(lái)的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op, grad):
 input = op.inputs[0] # 取出當(dāng)前的輸入
 cond = (input>=-1)&(input<=1) # 大于1或者小于-1的值的位置
 zeros = tf.zeros_like(grad) # 定義出0矩陣用于掩膜
 return tf.where(cond, grad, zeros) 
 # 將大于1或者小于-1的上一層的梯度置為0
 
#使用with上下文管理器覆蓋原始的sign梯度函數(shù)
def binary(input):
 x = input
 with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
 #重寫梯度
  x = tf.sign(x)
 return x
 
#使用
x = binary(x)

其中的def sign_grad(op, grad):是注冊(cè)新的梯度的套路,其中的op是當(dāng)前操作的輸入值/張量等,而grad指的是從反向而言的上一層的梯度。

通常來(lái)說,在tensorflow中自定義梯度,函數(shù)tf.identity()是很重要的,其API手冊(cè)如下:

tf.identity(
 input,
 name=None
)

其會(huì)返回一個(gè)形狀和內(nèi)容都和輸入完全一樣的輸出,但是你可以自定義其反向傳播時(shí)的梯度,因此在梯度反轉(zhuǎn)等操作中特別有用。

這里再舉個(gè)反向梯度[2]的例子,也就是梯度為 而不是

import tensorflow as tf
x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)
@tf.RegisterGradient('CustomGrad')
def CustomGrad(op, grad):
#  tf.Print(grad)
 return -grad
 
g = tf.get_default_graph()
oo = x1+x2
with g.gradient_override_map({"Identity": "CustomGrad"}):
 output = tf.identity(oo)
grad_1 = tf.gradients(output, oo)
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(grad_1))

因?yàn)?grad,所以這里的梯度輸出是[-1]而不是[1]。有一個(gè)我們需要注意的是,在自定義函數(shù)def CustomGrad()中,返回的值得是一個(gè)張量,而不能返回一個(gè)參數(shù),比如return 0,這樣會(huì)報(bào)錯(cuò),如:

AttributeError: 'int' object has no attribute 'name'

顯然,這是因?yàn)閠ensorflow的內(nèi)部操作需要取返回值的名字而int類型沒有名字。

PS:def CustomGrad()這個(gè)函數(shù)簽名是隨便你取的。

2. stop_gradient法

對(duì)于自定義梯度,還有一種比較簡(jiǎn)潔的操作,就是利用tf.stop_gradient()函數(shù),我們看下例子[1]:

t = g(x)
y = t + tf.stop_gradient(f(x) - t)

這里,我們本來(lái)的前向傳遞函數(shù)是f(x),但是想要在反向時(shí)傳遞的函數(shù)是g(x),因?yàn)樵谇跋蜻^程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向傳遞;而在反向過程中,因?yàn)閠f.stop_gradient()的作用,使得f(x)-t的梯度變?yōu)榱?,從而只剩下g(x)在反向傳遞。

我們看下完整的例子:

import tensorflow as tf

x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)

f = x1+x2*x3
t = -f

y1 = t + tf.stop_gradient(f-t)
y2 = f

grad_1 = tf.gradients(y1, x1)
grad_2 = tf.gradients(y2, x1)
with tf.Session(config=config) as sess:
 sess.run(tf.global_variables_initializer())

 print(sess.run(grad_1))
 print(sess.run(grad_2))

第一個(gè)輸出為[-1],第二個(gè)輸出為[1],顯然也實(shí)現(xiàn)了梯度的反轉(zhuǎn)。

以上這篇基于TensorFlow中自定義梯度的2種方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • django 實(shí)現(xiàn)編寫控制登錄和訪問權(quán)限控制的中間件方法

    django 實(shí)現(xiàn)編寫控制登錄和訪問權(quán)限控制的中間件方法

    今天小編就為大家分享一篇django 實(shí)現(xiàn)編寫控制登錄和訪問權(quán)限控制的中間件方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來(lái)看看吧
    2019-01-01
  • Python如何存儲(chǔ)數(shù)據(jù)到j(luò)son文件

    Python如何存儲(chǔ)數(shù)據(jù)到j(luò)son文件

    這篇文章主要介紹了Python如何存儲(chǔ)數(shù)據(jù)到j(luò)son文件,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-03-03
  • flask與數(shù)據(jù)庫(kù)的交互操作示例

    flask與數(shù)據(jù)庫(kù)的交互操作示例

    這篇文章主要為大家介紹了flask與數(shù)據(jù)庫(kù)的交互操作示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-08-08
  • Python+xlwings制作天氣預(yù)報(bào)表

    Python+xlwings制作天氣預(yù)報(bào)表

    python操作Excel的模塊,網(wǎng)上提到的模塊大致有:xlwings、xlrd、xlwt、openpyxl、pyxll等。本文將利用xlwings模塊制作一個(gè)天氣預(yù)報(bào)表,需要的可以參考一下
    2022-01-01
  • Python?虛擬機(jī)集合set實(shí)現(xiàn)原理及源碼解析

    Python?虛擬機(jī)集合set實(shí)現(xiàn)原理及源碼解析

    這篇文章主要為大家介紹了Python?虛擬機(jī)集合set實(shí)現(xiàn)原理及源碼解析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-03-03
  • 基于python實(shí)現(xiàn)分析識(shí)別文章/內(nèi)容中的高頻詞和關(guān)鍵詞

    基于python實(shí)現(xiàn)分析識(shí)別文章/內(nèi)容中的高頻詞和關(guān)鍵詞

    要分析一篇文章的高頻詞和關(guān)鍵詞,可以使用 Python 中的 nltk 庫(kù)和 collections 庫(kù)或者jieba庫(kù)來(lái)實(shí)現(xiàn),本篇文章介紹基于兩種庫(kù)分別實(shí)現(xiàn)分析內(nèi)容中的高頻詞和關(guān)鍵詞,需要的朋友可以參考下
    2023-09-09
  • Python模擬鍵盤輸入自動(dòng)登錄TGP

    Python模擬鍵盤輸入自動(dòng)登錄TGP

    這篇文章主要介紹了Python模擬鍵盤輸入自動(dòng)登錄TGP的示例代碼,幫助大家更好的理解和學(xué)習(xí)python,感興趣的朋友可以了解下
    2020-11-11
  • Python裝飾器詳情

    Python裝飾器詳情

    這篇文章主要介紹了Python裝飾器,裝飾器Decorator從字面上理解,就是裝飾對(duì)象的器件,其的特點(diǎn)是特點(diǎn)是函數(shù)是作為其參數(shù)出現(xiàn)的,裝飾器還擁有閉包的特點(diǎn),下面來(lái)看看文中的具體內(nèi)容
    2021-11-11
  • Python學(xué)習(xí)之時(shí)間包使用教程詳解

    Python學(xué)習(xí)之時(shí)間包使用教程詳解

    本文主要介紹了Python中的內(nèi)置時(shí)間包:datetime包?與?time包?,通過學(xué)習(xí)時(shí)間包可以讓我們的開發(fā)過程中對(duì)時(shí)間進(jìn)行輕松的處理,快來(lái)跟隨小編一起學(xué)習(xí)一下吧
    2022-03-03
  • python區(qū)塊鏈地址的簡(jiǎn)版實(shí)現(xiàn)

    python區(qū)塊鏈地址的簡(jiǎn)版實(shí)現(xiàn)

    這篇文章主要為大家介紹了python區(qū)塊鏈地址的簡(jiǎn)版實(shí)現(xiàn),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-05-05

最新評(píng)論