Pytorch中關(guān)于RNN輸入和輸出的形狀總結(jié)
Pytorch對RNN輸入和輸出的形狀總結(jié)
個人對于RNN的一些總結(jié)。
RNN的輸入和輸出
RNN的經(jīng)典圖如下所示
各個參數(shù)的含義
- Xt: t時刻的輸入,形狀為[batch_size, input_dim]。對于整個RNN來說,總的X輸入為[seq_len, batch_size, input_dim],具體如何理解batch_size和seq_len在下面有說明。
- St: t時刻隱藏層的狀態(tài),也有時用ht表示,形狀為[batch_size, hidden_size],St=f(U·Xt+W·St-1),通過W和U矩陣的映射,將embedding后的Xt和上一狀態(tài)St-1轉(zhuǎn)為St
- Ot: t時刻的輸出,Ot=g(V·St),形狀為[batch_size, hidden_size],總的為輸出O為[seq_len, batch_size, hidden_size]
Pytorch中的使用
Pytorch中RNN函數(shù)如下
RNN的主要參數(shù)如下
nn.RNN(input_size, hidden_size, num_layers=1, bias=True)
參數(shù)解釋
input_size
: 輸入特征的維度,一般rnn中輸入的是詞向量,那么就為embedding-dimhidden_size
: 隱藏層神經(jīng)元的個數(shù),或者也叫輸出的維度num_layers
: 隱藏層的個數(shù),默認為1
output=輸出O, 隱藏狀態(tài)St,其中輸出O=[time_step, batch_size, hidden_size],St為t時刻的隱藏層狀態(tài)
理解RNN中的batch_size和seq_len
深度學(xué)習(xí)中采用mini-batch的方法進行迭代優(yōu)化,在CNN中batch的思想較容易理解,一次輸入batch個圖片,進行迭代。但是RNN中引入了seq_len(time_step), 理解較為困難,下面是我自己的一些理解。
首先假如我有五句話,作為訓(xùn)練的語料。
sentences = ["i like dog", "i love coffee", "i hate milk", "i like music", "i hate you"]
那么在輸入RNN之前要先進行embedding,比如one-hot encoding,容易得到這里的embedding-dim為9.
那么輸入的sentences可以表示為如下方式
t=0 | t=1 | t=2 | |
---|---|---|---|
batch1 | i | like | dog |
batch2 | i | love | coffee |
batch3 | i | hate | milk |
batch4 | i | like | music |
batch5 | i | hate | you |
那么在RNN的訓(xùn)練中。
- t=0時, 輸入第一個batch[i, i, i, i, i]這里用字符表示,其實應(yīng)該是對應(yīng)的one-hot編碼。
- t=1時,輸入第二個batch[like, love, hate, like, hate]
- t=2時,輸入第三個batch[dog, coffee, milk, music, you]
那么對應(yīng)的時間t來說,RNN需要對先后輸入的batch_size個字符進行前向計算迭代,得到輸出。
Pytorch雙向RNN隱藏層和輸出層結(jié)果拆分
1 RNN隱藏層和輸出層結(jié)果的形狀
從Pytorch官方文檔可以得到,對于批量化輸入的RNN來講,其隱藏層的shape為(num_directions*num_layers, batch_size, hidden_size)。
其輸出的shape為(seq_len, batch_size, D*hidden_size)。
2 雙向RNN情況下,隱藏層和輸出層結(jié)果拆分
當采用雙向RNN時,其輸出的結(jié)果包含正向和反向兩個方向輸出的結(jié)果。
2.1 輸出層結(jié)果拆分
其中對于輸出output來講,從官方文檔我們可以得到,其拆分正向和反向兩個方向結(jié)果的方法為:
output.shape = (seq_len, batch_size, num_directions*hidden_size)
output.view(seq_len, batch, num_directions, hidden_size)
其中,對于(num_directions)方向維度,正向和反向的維度值分別為??0???和??1?。
2.2 隱藏層結(jié)果拆分
而對于隱藏層,包括初始值h_0以及最終輸出h_n,也都包含兩個方向的隱藏狀態(tài),但是其拆分方式跟輸出層不一樣。
方法如下:
h_0, h_n.shape = (num_directions*num_layers, batch_size, hidden_size)
h_0, h_n.view(num_layers, num_directions, batch_size, hidden_size)
可以從簡單單層雙向RNN的輸出結(jié)果來驗證,此時RNN的輸出結(jié)果與最后一層的隱藏層結(jié)果是一樣的。
import torch import torch.nn as nn if __name__ == "__main__": # input_size: 3, hidden_size: 5, num_layers: 3 BiRNN_Net = nn.RNN(3, 5, 3, bidirectional=True, batch_first=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # batch_size: 1, seq_len: 1, input_size: 3 inputs = torch.zeros(1, 1, 3, device=device) # state: (num_directions*num_layers, batch_size, hidden_size) state = torch.randn(6, 1, 5, device=device) BiRNN_Net.to(device) output, hidden = BiRNN_Net(inputs, state) output_re = output.reshape((1, 1, 2, 5)) hidden_re = hidden.reshape((3, 2, 1, 5)) print(output) print(output_re) print(hidden) print(hidden_re)
輸出結(jié)果可以看出,隱藏層的結(jié)果是優(yōu)先num_layers網(wǎng)絡(luò)層數(shù)這一個維度來構(gòu)成的。
tensor([[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225, ?0.0533, ?0.4197, ? ? ? ? ? -0.7200, -0.1262, -0.7975]]], device='cuda:0', ? ? ? ?grad_fn=<CudnnRnnBackward0>) tensor([[[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225], ? ? ? ? ? [ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0', ? ? ? ?grad_fn=<ReshapeAliasBackward0>) tensor([[[-0.2606, ?0.5410, -0.2663, ?0.6418, -0.2902]], ? ? ? ? [[ 0.1367, ?0.7222, -0.3051, -0.6410, -0.3062]], ? ? ? ? [[ 0.2433, ?0.3287, -0.4809, -0.1782, -0.5582]], ? ? ? ? [[ 0.4824, -0.8529, ?0.7604, ?0.8508, -0.1902]], ? ? ? ? [[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225]], ? ? ? ? [[ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]], device='cuda:0', ? ? ? ?grad_fn=<CudnnRnnBackward0>) tensor([[[[-0.2606, ?0.5410, -0.2663, ?0.6418, -0.2902]], ? ? ? ? ?[[ 0.1367, ?0.7222, -0.3051, -0.6410, -0.3062]]], ? ? ? ? [[[ 0.2433, ?0.3287, -0.4809, -0.1782, -0.5582]], ? ? ? ? ?[[ 0.4824, -0.8529, ?0.7604, ?0.8508, -0.1902]]], ? ? ? ? [[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225]], ? ? ? ? ?[[ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0', ? ? ? ?grad_fn=<ReshapeAliasBackward0>)
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實現(xiàn)遠程調(diào)用MetaSploit的方法
這篇文章主要介紹了Python實現(xiàn)遠程調(diào)用MetaSploit的方法,是很有借鑒價值的一個技巧,需要的朋友可以參考下2014-08-08淺談Python3中strip()、lstrip()、rstrip()用法詳解
這篇文章主要介紹了淺談Python3中strip()、lstrip()、rstrip()用法詳解,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-04-04Python調(diào)用DeepSeek?API的案例詳細教程
這篇文章主要為大家詳細介紹了以?Python?為例的調(diào)用?DeepSeek?API?的小白入門級詳細教程,文中的示例代碼講解詳細,感興趣的小伙伴可以了解下2025-02-02Python?數(shù)據(jù)篩選功能實現(xiàn)
這篇文章主要介紹了Python?數(shù)據(jù)篩選,無論是在數(shù)據(jù)分析還是數(shù)據(jù)挖掘的時候,數(shù)據(jù)篩選總會涉及到,這里我總結(jié)了一下python中列表,字典,數(shù)據(jù)框中一些常用的數(shù)據(jù)篩選的方法,需要的朋友可以參考下2023-04-04python實現(xiàn)掃描日志關(guān)鍵字的示例
下面小編就為大家分享一篇python實現(xiàn)掃描日志關(guān)鍵字的示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04