雙向RNN:bidirectional_dynamic_rnn()函數(shù)的使用詳解
雙向RNN:bidirectional_dynamic_rnn()函數(shù)的使用詳解
先說下為什么要使用到雙向RNN,在讀一篇文章的時候,上文提到的信息十分的重要,但這些信息是不足以捕捉文章信息的,下文隱含的信息同樣會對該時刻的語義產(chǎn)生影響。
舉一個不太恰當?shù)睦?,某次工作會議上,領導進行“簡潔地”總結,他會在第一句告訴你:“下面,為了節(jié)約時間,我簡單地說兩點…”,(…此處略去五百字…),“首先,….”,(…此處略去一萬字…),“礙于時間的關系,我要加快速度了,下面我簡要說下第二點…”(…此處再次略去五千字…)“好的,我想說的大概就是這些”(…此處又略去了二百字…),“謝謝大家!”如果將這篇發(fā)言交給一個單層的RNN網(wǎng)絡去學習,因為“首先”和“第二點”中間隔得實在太久,等到開始學習“第二點”時,網(wǎng)絡已經(jīng)忘記了“簡單地說兩點”這個重要的信息,最終的結果就只剩下在風中凌亂了。。。于是我們決定加一個反向的網(wǎng)絡,從后開始往前聽,對于這層網(wǎng)絡,他首先聽到的就是“第二點”,然后是“首先”,最后,他對比了一下果然僅僅是“簡要地兩點”,在于前向的網(wǎng)絡進行結合,就深入學習了領導的指導精神。

上圖是一個雙向LSTM的結構圖,對于最后輸出的每個隱藏狀態(tài)
都是前向網(wǎng)絡和后向網(wǎng)絡的元組,即
其中每一個
或者
又是一個由隱藏狀態(tài)和細胞狀態(tài)組成的元組(或者是concat)。同樣最終的output也是需要將前向和后向的輸出concat起來的,這樣就保證了在最終時刻,無論是輸出還是隱藏狀態(tài)都是有考慮了上文和下文信息的。
下面就來看下tensorflow中已經(jīng)集成的 tf.nn.bidirectional_dynamic_rnn() 函數(shù)。似乎雙向的暫時只有這一個動態(tài)的RNN方法,不過想想也能理解,這種結構暫時也只會在encoder端出現(xiàn),無論你的輸入是pad到了定長或者是不定長的,動態(tài)RNN都是可以處理的。
具體的定義如下:
tf.nn.bidirectional_dynamic_rnn( cell_fw, cell_bw, inputs, sequence_length=None, initial_state_fw=None, initial_state_bw=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None )
仔細看這個方法似乎和dynamic_rnn()沒有太大區(qū)別,無非是多加了一個bw的部分,事實上也的確如此。先看下前向傳播的部分:
with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
with vs.variable_scope("fw") as fw_scope:
output_fw, output_state_fw = dynamic_rnn(
cell=cell_fw, inputs=inputs,
sequence_length=sequence_length,
initial_state=initial_state_fw,
dtype=dtype,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
scope=fw_scope)
完全就是一個dynamic_rnn(),至于你選擇LSTM或者GRU,只是cell的定義不同罷了。而雙向RNN的核心就在于反向的bw部分。剛才說過,反向部分就是從后往前讀,而這個翻轉的部分,就要用到一個reverse_sequence()的方法,來看一下這一部分:
with vs.variable_scope("bw") as bw_scope:
# ———————————— 此處是重點 ————————————
inputs_reverse = _reverse(
inputs, seq_lengths=sequence_length,
seq_dim=time_dim, batch_dim=batch_dim)
# ————————————————————————————————————
tmp, output_state_bw = dynamic_rnn(
cell=cell_bw,
inputs=inputs_reverse,
sequence_length=sequence_length,
initial_state=initial_state_bw,
dtype=dtype,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
time_major=time_major,
scope=bw_scope)
我們可以看到,這里的輸入不再是inputs,而是一個inputs_reverse,根據(jù)time_major的取值,time_dim和batch_dim組合的 {0,1} 取值正好相反,也就對應了時間維和批量維的詞序關系。
而最終的輸出:
outputs = (output_fw, output_bw) output_states = (output_state_fw, output_state_bw)
這里還有最后的一個小問題,output_states是一個元組的元組,我個人的處理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分別將c和h狀態(tài)concat起來,用tf.contrib.rnn.LSTMStateTuple()函數(shù)生成decoder端的初始狀態(tài)。
以上這篇雙向RNN:bidirectional_dynamic_rnn()函數(shù)的使用詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
vscode 配置 python3開發(fā)環(huán)境的方法
這篇文章主要介紹了vscode 配置 python3開發(fā)環(huán)境的方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-09-09
Python中一個for循環(huán)循環(huán)多個變量的示例
今天小編就為大家分享一篇Python中一個for循環(huán)循環(huán)多個變量的示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07
Python3+OpenCV實現(xiàn)簡單交通標志識別流程分析
這篇文章主要介紹了Python3+OpenCV實現(xiàn)簡單交通標志識別,主要思路是解析XML文檔,根據(jù)<name>標簽進行分類,如果是直行、右轉、左轉、停止就把它從原圖中裁剪下來并重命名,感謝的朋友跟隨小編一起看看示例代碼2021-12-12
matplotlib.subplot()畫子圖并共享y坐標軸的方法
Matplotlib的可以把很多張圖畫到一個顯示界面,本文主要介紹matplotlib.subplot()畫子圖并共享y坐標軸的方法,需要的朋友們下面隨著小編來一起學習學習吧2021-05-05
Python中使用Flask、MongoDB搭建簡易圖片服務器
這篇文章主要介紹了Python中使用Flask、MongoDB搭建簡易圖片服務器,本文是一個詳細完整的教程,需要的朋友可以參考下2015-02-02

