解讀torch.nn.GRU的輸入及輸出示例
我們有時會看到GRU中輸入的參數(shù)有時是一個,但是有時又有兩個。這難免會讓人們感到疑惑,那么這些參數(shù)到底是什么呢。
一、輸入到GRU的參數(shù)
輸入的參數(shù)有兩個,分別是input和h_0。
Inputs: input, h_0
①input的shape
The shape of input:(seq_len, batch, input_size) : tensor containing the feature of the input sequence. The input can also be a packed variable length sequence。
See functorch.nn.utils.rnn.pack_padded_sequencefor details.
②h_0的shape
從下面的解釋中也可以看出,這個參數(shù)可以不提供,那么就默認為0.
The shape of h_0:(num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch.
Defaults to zero if not provided. If the RNN is bidirectional num_directions should be 2, else it should be 1.
綜上,可以只輸入一個參數(shù)。當(dāng)輸入兩個參數(shù)的時候,那么第二個參數(shù)相當(dāng)于是一個隱含層的輸出。
為了便于理解,下面是一幅圖:
二、GRU返回的數(shù)據(jù)
輸出有兩個,分別是output和h_n
①output
output 的shape是:(seq_len, batch, num_directions * hidden_size): tensor containing the output features h_t from the last layer of the GRU, for each t.
If a class:torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence.
For the unpacked case, the directions can be separated using output.view(seq_len, batch, num_directions, hidden_size), with forward and backward being direction 0 and 1 respectively.
Similarly, the directions can be separated in the packed case.
②h_n
h_n的shape是:(num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len
Like output, the layers can be separated using
h_n.view(num_layers, num_directions, batch, hidden_size).
三、代碼示例
數(shù)據(jù)的shape是[batch,seq_len,emb_dim]
RNN接收輸入的數(shù)據(jù)的shape是[seq_len,batch,emb_dim]
即前兩個維度調(diào)換就行了。
可以知道,加入批處理的時候一次處理128個句子,每個句子中有5個單詞,那么上圖中展示的input_data的shape是:[128,5,emb_dim]。
結(jié)合代碼分析,本例子將演示有1個句子和5個句子的情況。假設(shè)每個句子中有9個單詞,所以seq_len=9,并且每個單詞對應(yīng)的emb_dim=3,所以對應(yīng)數(shù)據(jù)的shape是: [batch,9,3],由于輸入到RNN中數(shù)據(jù)格式的格式,所以為[9,batch,3]
import torch import torch.nn as nn emb_dim = 3 hidden_dim = 2 rnn = nn.GRU(emb_dim,hidden_dim) #rnn = nn.GRU(9,1,3) print(type(rnn)) tensor1 = torch.tensor([[-0.5502, -0.1920, 1.1845], [-0.8003, 2.0783, 0.0175], [ 0.6761, 0.7183, -1.0084], [ 0.9514, 1.4772, -0.2271], [-1.0146, 0.7912, 0.2003], [-0.5502, -0.1920, 1.1845], [-0.8003, 2.0783, 0.0175], [ 0.1718, 0.1070, 0.4255], [-2.6727, -1.5680, -0.8369]]) tensor2 = torch.tensor([[-0.5502, -0.1920]]) # 假設(shè)input只有一個句子,那么batch為1 print('--------------batch=1時------------') data = tensor1.unsqueeze(0) h_0 = tensor2[0].unsqueeze(0).unsqueeze(0) print('data.shape: [batch,seq_len,emb_dim]',data.shape) print('') input = data.transpose(0,1) print('input.shape: [seq_len,batch,emb_dim]',input.shape) print('h_0.shape: [1,batch,hidden_dim]',h_0.shape) print('') # 輸入到rnn中 output,h_n = rnn(input,h_0) print('output.shape: [seq_len,batch,hidden_dim]',output.shape) print('h_n.shape: [1,batch,hidden_dim]',h_n.shape) # 假設(shè)input中有5個句子,所以,batch = 5 print('\n--------------batch=5時------------') data = tensor1.unsqueeze(0).repeat(5,1,1) # 由于batch為5 h_0 = tensor2[0].unsqueeze(0).repeat(1,5,1) # 由于batch為5 print('data.shape: [batch,seq_len,emb_dim]',data.shape) print('') input = data.transpose(0,1) print('input.shape: [seq_len,batch,emb_dim]',input.shape) print('h_0.shape: [1,batch,hidden_dim]',h_0.shape) print('') # 輸入到rnn中 output,h_n = rnn(input,h_0) print('output.shape: [seq_len,batch,hidden_dim]',output.shape) print('h_n.shape: [1,batch,hidden_dim]',h_n.shape)
四、輸出
<class ‘torch.nn.modules.rnn.GRU’>
--------------batch=1時------------
data.shape: [batch,seq_len,emb_dim] torch.Size([1, 9, 3])input.shape: [seq_len,batch,emb_dim] torch.Size([9, 1, 3])
h_0.shape: [1,batch,hidden_dim] torch.Size([1, 1, 2])output.shape: [seq_len,batch,hidden_dim] torch.Size([9, 1, 2])
h_n.shape: [1,batch,hidden_dim] torch.Size([1, 1, 2])--------------batch=5時------------
data.shape: [batch,seq_len,emb_dim] torch.Size([5, 9, 3])input.shape: [seq_len,batch,emb_dim] torch.Size([9, 5, 3])
h_0.shape: [1,batch,hidden_dim] torch.Size([1, 5, 2])output.shape: [seq_len,batch,hidden_dim] torch.Size([9, 5, 2])
h_n.shape: [1,batch,hidden_dim] torch.Size([1, 5, 2])
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python環(huán)境搭建以及Python與PyCharm安裝詳細圖文教程
PyCharm是一種PythonIDE,帶有一整套可以幫助用戶在使用Python語言開發(fā)時提高其效率的工具,這篇文章主要給大家介紹了關(guān)于Python環(huán)境搭建以及Python與PyCharm安裝的詳細圖文教程,需要的朋友可以參考下2024-03-03Python實現(xiàn)全角半角字符互轉(zhuǎn)的方法
大家都知道在自然語言處理過程中,全角、半角的的不一致會導(dǎo)致信息抽取不一致,因此需要統(tǒng)一。這篇文章通過示例代碼給大家詳細的介紹了Python實現(xiàn)全角半角字符互轉(zhuǎn)的方法,有需要的朋友們可以參考借鑒,下面跟著小編一起學(xué)習(xí)學(xué)習(xí)吧。2016-11-11python 解決動態(tài)的定義變量名,并給其賦值的方法(大數(shù)據(jù)處理)
今天小編就為大家分享一篇python 解決動態(tài)的定義變量名,并給其賦值的方法(大數(shù)據(jù)處理),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-11-11Python使用thread模塊實現(xiàn)多線程的操作
線程(Threads)是操作系統(tǒng)提供的一種輕量級的執(zhí)行單元,可以在一個進程內(nèi)并發(fā)執(zhí)行多個任務(wù),每個線程都有自己的執(zhí)行上下文,包括棧、寄存器和程序計數(shù)器,本文給大家介紹了Python使用thread模塊實現(xiàn)多線程的操作,需要的朋友可以參考下2024-10-10