欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Python使用pytorch動手實(shí)現(xiàn)LSTM模塊

 更新時(shí)間:2022年07月27日 08:53:21   作者:qyhyzard  
這篇文章主要介紹了Python使用pytorch動手實(shí)現(xiàn)LSTM模塊,LSTM是RNN中一個(gè)較為流行的網(wǎng)絡(luò)模塊。主要包括輸入,輸入門,輸出門,遺忘門,激活函數(shù),全連接層(Cell)和輸出

LSTM 簡介:

LSTM是RNN中一個(gè)較為流行的網(wǎng)絡(luò)模塊。主要包括輸入,輸入門,輸出門,遺忘門,激活函數(shù),全連接層(Cell)和輸出。

其結(jié)構(gòu)如下:

上述公式不做解釋,我們只要大概記得以下幾個(gè)點(diǎn)就可以了:

  • 當(dāng)前時(shí)刻LSTM模塊的輸入有來自當(dāng)前時(shí)刻的輸入值,上一時(shí)刻的輸出值,輸入值和隱含層輸出值,就是一共有四個(gè)輸入值,這意味著一個(gè)LSTM模塊的輸入量是原來普通全連接層的四倍左右,計(jì)算量多了許多。
  • 所謂的門就是前一時(shí)刻的計(jì)算值輸入到sigmoid激活函數(shù)得到一個(gè)概率值,這個(gè)概率值決定了當(dāng)前輸入的強(qiáng)弱程度。 這個(gè)概率值和當(dāng)前輸入進(jìn)行矩陣乘法得到經(jīng)過門控處理后的實(shí)際值。
  • 門控的激活函數(shù)都是sigmoid,范圍在(0,1),而輸出輸出單元的激活函數(shù)都是tanh,范圍在(-1,1)。

Pytorch實(shí)現(xiàn)如下:

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import math
class NaiveLSTM(nn.Module):
    """Naive LSTM like nn.LSTM"""
    def __init__(self, input_size: int, hidden_size: int):
        super(NaiveLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # input gate
        self.w_ii = Parameter(Tensor(hidden_size, input_size))
        self.w_hi = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ii = Parameter(Tensor(hidden_size, 1))
        self.b_hi = Parameter(Tensor(hidden_size, 1))

        # forget gate
        self.w_if = Parameter(Tensor(hidden_size, input_size))
        self.w_hf = Parameter(Tensor(hidden_size, hidden_size))
        self.b_if = Parameter(Tensor(hidden_size, 1))
        self.b_hf = Parameter(Tensor(hidden_size, 1))

        # output gate
        self.w_io = Parameter(Tensor(hidden_size, input_size))
        self.w_ho = Parameter(Tensor(hidden_size, hidden_size))
        self.b_io = Parameter(Tensor(hidden_size, 1))
        self.b_ho = Parameter(Tensor(hidden_size, 1))

        # cell
        self.w_ig = Parameter(Tensor(hidden_size, input_size))
        self.w_hg = Parameter(Tensor(hidden_size, hidden_size))
        self.b_ig = Parameter(Tensor(hidden_size, 1))
        self.b_hg = Parameter(Tensor(hidden_size, 1))

        self.reset_weigths()

    def reset_weigths(self):
        """reset weights
        """
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs: Tensor, state: Tuple[Tensor]) \
        -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Forward
        Args:
            inputs: [1, 1, input_size]
            state: ([1, 1, hidden_size], [1, 1, hidden_size])
        """
#         seq_size, batch_size, _ = inputs.size()

        if state is None:
            h_t = torch.zeros(1, self.hidden_size).t()
            c_t = torch.zeros(1, self.hidden_size).t()
        else:
            (h, c) = state
            h_t = h.squeeze(0).t()
            c_t = c.squeeze(0).t()

        hidden_seq = []

        seq_size = 1
        for t in range(seq_size):
            x = inputs[:, t, :].t()
            # input gate
            i = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t +
                              self.b_hi)
            # forget gate
            f = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t +
                              self.b_hf)
            # cell
            g = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t
                           + self.b_hg)
            # output gate
            o = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t +
                              self.b_ho)

            c_next = f * c_t + i * g
            h_next = o * torch.tanh(c_next)
            c_next_t = c_next.t().unsqueeze(0)
            h_next_t = h_next.t().unsqueeze(0)
            hidden_seq.append(h_next_t)

        hidden_seq = torch.cat(hidden_seq, dim=0)
        return hidden_seq, (h_next_t, c_next_t)

def reset_weigths(model):
    """reset weights
    """
    for weight in model.parameters():
        init.constant_(weight, 0.5)
### test 
inputs = torch.ones(1, 1, 10)
h0 = torch.ones(1, 1, 20)
c0 = torch.ones(1, 1, 20)
print(h0.shape, h0)
print(c0.shape, c0)
print(inputs.shape, inputs)
# test naive_lstm with input_size=10, hidden_size=20
naive_lstm = NaiveLSTM(10, 20)
reset_weigths(naive_lstm)
output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))
print(hn1.shape, cn1.shape, output1.shape)
print(hn1)
print(cn1)
print(output1)

對比官方實(shí)現(xiàn):

# Use official lstm with input_size=10, hidden_size=20
lstm = nn.LSTM(10, 20)
reset_weigths(lstm)
output2, (hn2, cn2) = lstm(inputs, (h0, c0))
print(hn2.shape, cn2.shape, output2.shape)
print(hn2)
print(cn2)
print(output2)

可以看到與官方的實(shí)現(xiàn)有些許的不同,但是輸出的結(jié)果仍舊一致。

到此這篇關(guān)于Python使用pytorch動手實(shí)現(xiàn)LSTM模塊的文章就介紹到這了,更多相關(guān)Python實(shí)現(xiàn)LSTM模塊內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Boston數(shù)據(jù)集預(yù)測放假及應(yīng)用優(yōu)缺點(diǎn)評估

    Boston數(shù)據(jù)集預(yù)測放假及應(yīng)用優(yōu)缺點(diǎn)評估

    這篇文章主要為大家介紹了Boston數(shù)據(jù)集預(yù)測放假及應(yīng)用優(yōu)缺點(diǎn)評估,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-10-10
  • 查看Python安裝路徑以及安裝包路徑小技巧

    查看Python安裝路徑以及安裝包路徑小技巧

    這篇文章主要介紹了查看Python安裝路徑以及安裝包路徑小技巧,本文使用直接在命令行運(yùn)行Python代碼的方法檢測安裝路徑以及安裝包路徑,需要的朋友可以參考下
    2015-04-04
  • Python單鏈表的簡單實(shí)現(xiàn)方法

    Python單鏈表的簡單實(shí)現(xiàn)方法

    這篇文章主要介紹了Python單鏈表的簡單實(shí)現(xiàn)方法,包括定義所需的字段及具體實(shí)現(xiàn)代碼的分析,需要的朋友可以參考下
    2014-09-09
  • Python入門之實(shí)例方法、類方法和靜態(tài)方法的區(qū)別講解

    Python入門之實(shí)例方法、類方法和靜態(tài)方法的區(qū)別講解

    這篇文章主要介紹了Python入門之實(shí)例方法、類方法和靜態(tài)方法的區(qū)別講解,實(shí)例方法是在創(chuàng)建了類的實(shí)例之后才能被調(diào)用的方法,類方法是在不需要創(chuàng)建類的實(shí)例的情況下就可以調(diào)用的方法,最后,靜態(tài)方法是與類和類的實(shí)例都沒有綁定關(guān)系的方法,需要的朋友可以參考下
    2023-10-10
  • 詳解Python中的文件操作

    詳解Python中的文件操作

    這篇文章主要介紹了Python中文件操作的相關(guān)資料,幫助大家更好的理解和學(xué)習(xí)python,感興趣的朋友可以了解下
    2021-01-01
  • spyder常用快捷鍵(分享)

    spyder常用快捷鍵(分享)

    下面小編就為大家?guī)硪黄猻pyder常用快捷鍵(分享)。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2017-07-07
  • python如何在文件中部插入信息

    python如何在文件中部插入信息

    這篇文章主要介紹了python如何在文件中部插入信息問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2022-11-11
  • python實(shí)現(xiàn)多個(gè)視頻文件合成畫中畫效果

    python實(shí)現(xiàn)多個(gè)視頻文件合成畫中畫效果

    這篇文章主要為大家詳細(xì)介紹了python實(shí)現(xiàn)多個(gè)視頻文件合成畫中畫效果,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2021-08-08
  • python PrettyTable模塊的安裝與簡單應(yīng)用

    python PrettyTable模塊的安裝與簡單應(yīng)用

    prettyTable 是一款很簡潔但是功能強(qiáng)大的第三方模塊,主要是將輸入的數(shù)據(jù)轉(zhuǎn)化為格式化的形式來輸出,這篇文章主要介紹了python PrettyTable模塊的安裝與簡單應(yīng)用,感興趣的小伙伴們可以參考一下
    2019-01-01
  • python中的錯(cuò)誤處理

    python中的錯(cuò)誤處理

    異常是指程序中的例外,違例情況。異常機(jī)制是指程序出現(xiàn)錯(cuò)誤后,程序的處理方法。當(dāng)出現(xiàn)錯(cuò)誤后,程序的執(zhí)行流程發(fā)生改變,程序的控制權(quán)轉(zhuǎn)移到異常處理。
    2016-04-04

最新評論