pytorch中nn.RNN()匯總
nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)
參數(shù)說明
- input_size輸入特征的維度, 一般rnn中輸入的是詞向量,那么 input_size 就等于一個(gè)詞向量的維度
- hidden_size隱藏層神經(jīng)元個(gè)數(shù),或者也叫輸出的維度(因?yàn)閞nn輸出為各個(gè)時(shí)間步上的隱藏狀態(tài))
- num_layers網(wǎng)絡(luò)的層數(shù)
- nonlinearity激活函數(shù)
- bias是否使用偏置
- batch_first輸入數(shù)據(jù)的形式,默認(rèn)是 False,就是這樣形式,(seq(num_step), batch, input_dim),也就是將序列長(zhǎng)度放在第一位,batch 放在第二位
- dropout是否應(yīng)用dropout, 默認(rèn)不使用,如若使用將其設(shè)置成一個(gè)0-1的數(shù)字即可
- birdirectional是否使用雙向的 rnn,默認(rèn)是 False
- 注意某些參數(shù)的默認(rèn)值在標(biāo)題中已注明
輸入輸出shape
- input_shape = [時(shí)間步數(shù), 批量大小, 特征維度] = [num_steps(seq_length), batch_size, input_dim]
- 在前向計(jì)算后會(huì)分別返回輸出和隱藏狀態(tài)h,其中輸出指的是隱藏層在各個(gè)時(shí)間步上計(jì)算并輸出的隱藏狀態(tài),它們通常作為后續(xù)輸出層的輸?。需要強(qiáng)調(diào)的是,該“輸出”本身并不涉及輸出層計(jì)算,形狀為(時(shí)間步數(shù), 批量大小, 隱藏單元個(gè)數(shù));隱藏狀態(tài)指的是隱藏層在最后時(shí)間步的隱藏狀態(tài):當(dāng)隱藏層有多層時(shí),每?層的隱藏狀態(tài)都會(huì)記錄在該變量中;對(duì)于像?短期記憶(LSTM),隱藏狀態(tài)是?個(gè)元組(h, c),即hidden state和cell state(此處普通rnn只有一個(gè)值)隱藏狀態(tài)h的形狀為(層數(shù), 批量大小,隱藏單元個(gè)數(shù))
代碼
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, ) # 定義模型, 其中vocab_size = 1027, hidden_size = 256
num_steps = 35 batch_size = 2 state = None # 初始隱藏層狀態(tài)可以不定義 X = torch.rand(num_steps, batch_size, vocab_size) Y, state_new = rnn_layer(X, state) print(Y.shape, len(state_new), state_new.shape)
輸出
torch.Size([35, 2, 256]) 1 torch.Size([1, 2, 256])
具體計(jì)算過程
H t = i n p u t ∗ W x h + H t − 1 ∗ W h h + b i a s H_t = input * W_{xh} + H_{t-1} * W_{hh} + bias Ht?=input∗Wxh?+Ht−1?∗Whh?+bias[batch_size, input_dim] * [input_dim, num_hiddens] + [batch_size, num_hiddens] *[num_hiddens, num_hiddens] +bias
可以發(fā)現(xiàn)每個(gè)隱藏狀態(tài)形狀都是[batch_size, num_hiddens], 起始輸出也是一樣的
注意:上面為了方便假設(shè)num_step=1
GRU/LSTM等參數(shù)同上面RNN
到此這篇關(guān)于pytorch中nn.RNN()總結(jié)的文章就介紹到這了,更多相關(guān)pytorch nn.RNN()內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Pytorch 解決自定義子Module .cuda() tensor失敗的問題
這篇文章主要介紹了Pytorch 解決自定義子Module .cuda() tensor失敗的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-06-06關(guān)于Tensorflow和Keras版本對(duì)照及環(huán)境安裝
這篇文章主要介紹了關(guān)于Tensorflow和Keras版本對(duì)照及環(huán)境安裝方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-08-08Python實(shí)現(xiàn)的數(shù)據(jù)結(jié)構(gòu)與算法之基本搜索詳解
這篇文章主要介紹了Python實(shí)現(xiàn)的數(shù)據(jù)結(jié)構(gòu)與算法之基本搜索,詳細(xì)分析了Python順序搜索、二分搜索的使用技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-04-04詳解Selenium如何使用input標(biāo)簽上傳文件完整流程
這篇文章主要介紹了詳解Selenium如何使用input標(biāo)簽上傳文件完整流程,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-05-05PyTorch模型轉(zhuǎn)TensorRT是怎么實(shí)現(xiàn)的?
今天給大家?guī)淼氖顷P(guān)于Python的相關(guān)知識(shí),文章圍繞著PyTorch模型轉(zhuǎn)TensorRT是怎么實(shí)現(xiàn)的展開,文中有非常詳細(xì)的介紹及代碼示例,需要的朋友可以參考下2021-06-06pycharm引入其他目錄的包報(bào)錯(cuò),import報(bào)錯(cuò)的解決
這篇文章主要介紹了pycharm引入其他目錄的包報(bào)錯(cuò),import報(bào)錯(cuò)的解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-08-08python鏈表的基礎(chǔ)概念和基礎(chǔ)用法詳解
這篇文章主要為大家詳細(xì)介紹了python鏈表的基礎(chǔ)概念和基礎(chǔ)用法,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2022-05-05python多線程高級(jí)鎖condition簡(jiǎn)單用法示例
這篇文章主要介紹了python多線程高級(jí)鎖condition簡(jiǎn)單用法,結(jié)合實(shí)例形式分析了condition對(duì)象常用方法及相關(guān)使用技巧,需要的朋友可以參考下2019-11-11OpenCV結(jié)合selenium實(shí)現(xiàn)滑塊驗(yàn)證碼
本文主要介紹了OpenCV結(jié)合selenium實(shí)現(xiàn)滑塊驗(yàn)證碼,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-08-08