欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch中關(guān)于RNN輸入和輸出的形狀總結(jié)

 更新時間:2023年06月15日 08:35:29   作者:會唱歌的豬233  
這篇文章主要介紹了Pytorch中關(guān)于RNN輸入和輸出的形狀總結(jié),具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

Pytorch對RNN輸入和輸出的形狀總結(jié)

個人對于RNN的一些總結(jié)。

RNN的輸入和輸出

RNN的經(jīng)典圖如下所示


各個參數(shù)的含義

  • Xt: t時刻的輸入,形狀為[batch_size, input_dim]。對于整個RNN來說,總的X輸入為[seq_len, batch_size, input_dim],具體如何理解batch_size和seq_len在下面有說明。
  • St: t時刻隱藏層的狀態(tài),也有時用ht表示,形狀為[batch_size, hidden_size],St=f(U·Xt+W·St-1),通過W和U矩陣的映射,將embedding后的Xt和上一狀態(tài)St-1轉(zhuǎn)為St
  • Ot: t時刻的輸出,Ot=g(V·St),形狀為[batch_size, hidden_size],總的為輸出O為[seq_len, batch_size, hidden_size]

Pytorch中的使用

Pytorch中RNN函數(shù)如下

RNN的主要參數(shù)如下

nn.RNN(input_size, hidden_size, num_layers=1, bias=True)

參數(shù)解釋

  • input_size: 輸入特征的維度,一般rnn中輸入的是詞向量,那么就為embedding-dim
  • hidden_size: 隱藏層神經(jīng)元的個數(shù),或者也叫輸出的維度
  • num_layers: 隱藏層的個數(shù),默認為1

output=輸出O, 隱藏狀態(tài)St,其中輸出O=[time_step, batch_size, hidden_size],St為t時刻的隱藏層狀態(tài)

理解RNN中的batch_size和seq_len

深度學(xué)習(xí)中采用mini-batch的方法進行迭代優(yōu)化,在CNN中batch的思想較容易理解,一次輸入batch個圖片,進行迭代。但是RNN中引入了seq_len(time_step), 理解較為困難,下面是我自己的一些理解。

首先假如我有五句話,作為訓(xùn)練的語料。

sentences = ["i like dog", "i love coffee", "i hate milk", "i like music", "i hate you"]

那么在輸入RNN之前要先進行embedding,比如one-hot encoding,容易得到這里的embedding-dim為9.

那么輸入的sentences可以表示為如下方式

t=0t=1t=2
batch1ilikedog
batch2ilovecoffee
batch3ihatemilk
batch4ilikemusic
batch5ihateyou

那么在RNN的訓(xùn)練中。

  • t=0時, 輸入第一個batch[i, i, i, i, i]這里用字符表示,其實應(yīng)該是對應(yīng)的one-hot編碼。
  • t=1時,輸入第二個batch[like, love, hate, like, hate]
  • t=2時,輸入第三個batch[dog, coffee, milk, music, you]

那么對應(yīng)的時間t來說,RNN需要對先后輸入的batch_size個字符進行前向計算迭代,得到輸出。

Pytorch雙向RNN隱藏層和輸出層結(jié)果拆分

1 RNN隱藏層和輸出層結(jié)果的形狀

從Pytorch官方文檔可以得到,對于批量化輸入的RNN來講,其隱藏層的shape為(num_directions*num_layers, batch_size, hidden_size)。

其輸出的shape為(seq_len, batch_size, D*hidden_size)。

2 雙向RNN情況下,隱藏層和輸出層結(jié)果拆分

當采用雙向RNN時,其輸出的結(jié)果包含正向和反向兩個方向輸出的結(jié)果。

2.1 輸出層結(jié)果拆分

其中對于輸出output來講,從官方文檔我們可以得到,其拆分正向和反向兩個方向結(jié)果的方法為:

output.shape = (seq_len, batch_size, num_directions*hidden_size)

output.view(seq_len, batch, num_directions, hidden_size)

其中,對于(num_directions)方向維度,正向和反向的維度值分別為??0???和??1?。

2.2 隱藏層結(jié)果拆分

而對于隱藏層,包括初始值h_0以及最終輸出h_n,也都包含兩個方向的隱藏狀態(tài),但是其拆分方式跟輸出層不一樣。

方法如下:

h_0, h_n.shape = (num_directions*num_layers, batch_size, hidden_size)

h_0, h_n.view(num_layers, num_directions, batch_size, hidden_size)

可以從簡單單層雙向RNN的輸出結(jié)果來驗證,此時RNN的輸出結(jié)果與最后一層的隱藏層結(jié)果是一樣的。

import torch
import torch.nn as nn
if __name__ == "__main__":
    # input_size: 3, hidden_size: 5, num_layers: 3
    BiRNN_Net = nn.RNN(3, 5, 3, bidirectional=True, batch_first=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # batch_size: 1, seq_len: 1, input_size: 3
    inputs = torch.zeros(1, 1, 3, device=device)
    # state: (num_directions*num_layers, batch_size, hidden_size)
    state = torch.randn(6, 1, 5, device=device)
    BiRNN_Net.to(device)
    output, hidden = BiRNN_Net(inputs, state)
    output_re = output.reshape((1, 1, 2, 5))
    hidden_re = hidden.reshape((3, 2, 1, 5))
    print(output)
    print(output_re)
    print(hidden)
    print(hidden_re)

輸出結(jié)果可以看出,隱藏層的結(jié)果是優(yōu)先num_layers網(wǎng)絡(luò)層數(shù)這一個維度來構(gòu)成的。

tensor([[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225, ?0.0533, ?0.4197,
? ? ? ? ? -0.7200, -0.1262, -0.7975]]], device='cuda:0',
? ? ? ?grad_fn=<CudnnRnnBackward0>)
tensor([[[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225],
? ? ? ? ? [ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0',
? ? ? ?grad_fn=<ReshapeAliasBackward0>)
tensor([[[-0.2606, ?0.5410, -0.2663, ?0.6418, -0.2902]],
? ? ? ? [[ 0.1367, ?0.7222, -0.3051, -0.6410, -0.3062]],
? ? ? ? [[ 0.2433, ?0.3287, -0.4809, -0.1782, -0.5582]],
? ? ? ? [[ 0.4824, -0.8529, ?0.7604, ?0.8508, -0.1902]],
? ? ? ? [[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225]],
? ? ? ? [[ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]], device='cuda:0',
? ? ? ?grad_fn=<CudnnRnnBackward0>)
tensor([[[[-0.2606, ?0.5410, -0.2663, ?0.6418, -0.2902]],
? ? ? ? ?[[ 0.1367, ?0.7222, -0.3051, -0.6410, -0.3062]]],
? ? ? ? [[[ 0.2433, ?0.3287, -0.4809, -0.1782, -0.5582]],
? ? ? ? ?[[ 0.4824, -0.8529, ?0.7604, ?0.8508, -0.1902]]],
? ? ? ? [[[ 0.3939, -0.9160, ?0.5054, ?0.2949, -0.5225]],
? ? ? ? ?[[ 0.0533, ?0.4197, -0.7200, -0.1262, -0.7975]]]], device='cuda:0',
? ? ? ?grad_fn=<ReshapeAliasBackward0>)

總結(jié)

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python實現(xiàn)遠程調(diào)用MetaSploit的方法

    Python實現(xiàn)遠程調(diào)用MetaSploit的方法

    這篇文章主要介紹了Python實現(xiàn)遠程調(diào)用MetaSploit的方法,是很有借鑒價值的一個技巧,需要的朋友可以參考下
    2014-08-08
  • 淺談Python3中strip()、lstrip()、rstrip()用法詳解

    淺談Python3中strip()、lstrip()、rstrip()用法詳解

    這篇文章主要介紹了淺談Python3中strip()、lstrip()、rstrip()用法詳解,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2019-04-04
  • 利用OpenCV給彩色圖像添加椒鹽噪聲的方法

    利用OpenCV給彩色圖像添加椒鹽噪聲的方法

    椒鹽噪聲是數(shù)字圖像中的常見噪聲,一般是圖像傳感器、傳輸信道及解碼處理等產(chǎn)生的黑白相間的亮暗點噪聲,椒鹽噪聲常由圖像切割產(chǎn)生,這篇文章主要給大家介紹了關(guān)于利用OpenCV給彩色圖像添加椒鹽噪聲的相關(guān)資料,需要的朋友可以參考下
    2021-10-10
  • Python中的index()方法使用教程

    Python中的index()方法使用教程

    這篇文章主要介紹了Python中的index()方法使用教程,是Python入門學(xué)習(xí)中的基礎(chǔ)知識,需要的朋友可以參考下
    2015-05-05
  • python繪制柱形圖的方法

    python繪制柱形圖的方法

    這篇文章主要為大家詳細介紹了python繪制柱形圖的方法,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2022-04-04
  • Python調(diào)用DeepSeek?API的案例詳細教程

    Python調(diào)用DeepSeek?API的案例詳細教程

    這篇文章主要為大家詳細介紹了以?Python?為例的調(diào)用?DeepSeek?API?的小白入門級詳細教程,文中的示例代碼講解詳細,感興趣的小伙伴可以了解下
    2025-02-02
  • python爬蟲如何解決圖片驗證碼

    python爬蟲如何解決圖片驗證碼

    這篇文章主要介紹了python爬蟲如何解決圖片驗證碼,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2021-02-02
  • Python?數(shù)據(jù)篩選功能實現(xiàn)

    Python?數(shù)據(jù)篩選功能實現(xiàn)

    這篇文章主要介紹了Python?數(shù)據(jù)篩選,無論是在數(shù)據(jù)分析還是數(shù)據(jù)挖掘的時候,數(shù)據(jù)篩選總會涉及到,這里我總結(jié)了一下python中列表,字典,數(shù)據(jù)框中一些常用的數(shù)據(jù)篩選的方法,需要的朋友可以參考下
    2023-04-04
  • python 五子棋如何獲得鼠標點擊坐標

    python 五子棋如何獲得鼠標點擊坐標

    這篇文章主要介紹了python 五子棋如何獲得鼠標點擊坐標,文中通過示例代碼介紹的非常詳細,對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-11-11
  • python實現(xiàn)掃描日志關(guān)鍵字的示例

    python實現(xiàn)掃描日志關(guān)鍵字的示例

    下面小編就為大家分享一篇python實現(xiàn)掃描日志關(guān)鍵字的示例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-04-04

最新評論