Tensorflow 卷積的梯度反向傳播過(guò)程
一. valid卷積的梯度
我們分兩種不同的情況討論valid卷積的梯度:第一種情況,在已知卷積核的情況下,對(duì)未知張量求導(dǎo)(即對(duì)張量中每一個(gè)變量求導(dǎo));第二種情況,在已知張量的情況下,對(duì)未知卷積核求導(dǎo)(即對(duì)卷積核中每一個(gè)變量求導(dǎo))
1.已知卷積核,對(duì)未知張量求導(dǎo)
我們用一個(gè)簡(jiǎn)單的例子理解valid卷積的梯度反向傳播。假設(shè)有一個(gè)3x3的未知張量x,以及已知的2x2的卷積核K
Tensorflow提供函數(shù)tf.nn.conv2d_backprop_input實(shí)現(xiàn)了valid卷積中對(duì)未知變量的求導(dǎo),以上示例對(duì)應(yīng)的代碼如下:
import tensorflow as tf
# 卷積核
kernel=tf.constant(
[
[[[3]],[[4]]],
[[[5]],[[6]]]
]
,tf.float32
)
# 某一函數(shù)針對(duì)sigma的導(dǎo)數(shù)
out=tf.constant(
[
[
[[-1],[1]],
[[2],[-2]]
]
]
,tf.float32
)
# 針對(duì)未知變量的導(dǎo)數(shù)的方向計(jì)算
inputValue=tf.nn.conv2d_backprop_input((1,3,3,1),kernel,out,[1,1,1,1],'VALID')
session=tf.Session()
print(session.run(inputValue))
[[[[ -3.]
[ -1.]
[ 4.]]
[[ 1.]
[ 1.]
[ -2.]]
[[ 10.]
[ 2.]
[-12.]]]]
2.已知輸入張量,對(duì)未知卷積核求導(dǎo)
假設(shè)已知3行3列的張量x和未知的2行2列的卷積核K
Tensorflow提供函數(shù)tf.nn.conv2d_backprop_filter實(shí)現(xiàn)valid卷積對(duì)未知卷積核的求導(dǎo),以上示例的代碼如下:
import tensorflow as tf
# 輸入張量
x=tf.constant(
[
[
[[1],[2],[3]],
[[4],[5],[6]],
[[7],[8],[9]]
]
]
,tf.float32
)
# 某一個(gè)函數(shù)F對(duì)sigma的導(dǎo)數(shù)
partial_sigma=tf.constant(
[
[
[[-1],[-2]],
[[-3],[-4]]
]
]
,tf.float32
)
# 某一個(gè)函數(shù)F對(duì)卷積核k的導(dǎo)數(shù)
partial_sigma_k=tf.nn.conv2d_backprop_filter(x,(2,2,1,1),partial_sigma,[1,1,1,1],'VALID')
session=tf.Session()
print(session.run(partial_sigma_k))
[[[[-37.]]
[[-47.]]]
[[[-67.]]
[[-77.]]]]
二. same卷積的梯度
1.已知卷積核,對(duì)輸入張量求導(dǎo)
假設(shè)有3行3列的已知張量x,2行2列的未知卷積核K
import tensorflow as tf
# 卷積核
kernel=tf.constant(
[
[[[3]],[[4]]],
[[[5]],[[6]]]
]
,tf.float32
)
# 某一函數(shù)針對(duì)sigma的導(dǎo)數(shù)
partial_sigma=tf.constant(
[
[
[[-1],[1],[3]],
[[2],[-2],[-4]],
[[-3],[4],[1]]
]
]
,tf.float32
)
# 針對(duì)未知變量的導(dǎo)數(shù)的方向計(jì)算
partial_x=tf.nn.conv2d_backprop_input((1,3,3,1),kernel,partial_sigma,[1,1,1,1],'SAME')
session=tf.Session()
print(session.run(inputValue))
[[[[ -3.]
[ -1.]
[ 4.]]
[[ 1.]
[ 1.]
[ -2.]]
[[ 10.]
[ 2.]
[-12.]]]]
2.已知輸入張量,對(duì)未知卷積核求導(dǎo)
假設(shè)已知3行3列的張量x和未知的2行2列的卷積核K
import tensorflow as tf
# 卷積核
x=tf.constant(
[
[
[[1],[2],[3]],
[[4],[5],[6]],
[[7],[8],[9]]
]
]
,tf.float32
)
# 某一函數(shù)針對(duì)sigma的導(dǎo)數(shù)
partial_sigma=tf.constant(
[
[
[[-1],[-2],[1]],
[[-3],[-4],[2]],
[[-2],[1],[3]]
]
]
,tf.float32
)
# 針對(duì)未知變量的導(dǎo)數(shù)的方向計(jì)算
partial_sigma_k=tf.nn.conv2d_backprop_filter(x,(2,2,1,1),partial_sigma,[1,1,1,1],'SAME')
session=tf.Session()
print(session.run(partial_sigma_k))
[[[[ -1.]]
[[-54.]]]
[[[-43.]]
[[-77.]]]]
以上這篇Tensorflow 卷積的梯度反向傳播過(guò)程就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- tensorflow 實(shí)現(xiàn)自定義梯度反向傳播代碼
- 有關(guān)Tensorflow梯度下降常用的優(yōu)化方法分享
- TensorFlow梯度求解tf.gradients實(shí)例
- 基于TensorFlow中自定義梯度的2種方式
- tensorflow 查看梯度方式
- tensorflow求導(dǎo)和梯度計(jì)算實(shí)例
- Tensorflow的梯度異步更新示例
- 在Tensorflow中實(shí)現(xiàn)梯度下降法更新參數(shù)值
- Tensorflow實(shí)現(xiàn)部分參數(shù)梯度更新操作
- 運(yùn)用TensorFlow進(jìn)行簡(jiǎn)單實(shí)現(xiàn)線性回歸、梯度下降示例
相關(guān)文章
Python實(shí)現(xiàn)PS濾鏡特效Marble Filter玻璃條紋扭曲效果示例
這篇文章主要介紹了Python實(shí)現(xiàn)PS濾鏡特效Marble Filter玻璃條紋扭曲效果,涉及Python基于skimage庫(kù)實(shí)現(xiàn)圖形條紋扭曲效果的相關(guān)操作技巧,需要的朋友可以參考下2018-01-01
pytorch下的unsqueeze和squeeze的用法說(shuō)明
這篇文章主要介紹了pytorch下的unsqueeze和squeeze的用法說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-02-02
python thrift搭建服務(wù)端和客戶(hù)端測(cè)試程序
這篇文章主要為大家詳細(xì)介紹了python thrift搭建服務(wù)端和客戶(hù)端測(cè)試程序,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-01-01
python中如何利用matplotlib畫(huà)多個(gè)并列的柱狀圖
python是一個(gè)很有趣的語(yǔ)言,可以在命令行窗口運(yùn)行,下面這篇文章主要給大家介紹了關(guān)于python中如何利用matplotlib畫(huà)多個(gè)并列的柱狀圖的相關(guān)資料,需要的朋友可以參考下2022-01-01
Python3交互式shell ipython3安裝及使用詳解
這篇文章主要介紹了Python3交互式shell ipython3安裝及使用詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-07-07

