欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch深度學(xué)習(xí)LSTM從input輸入到Linear輸出

 更新時(shí)間:2022年05月11日 10:13:43   作者:Cyril_KI  
這篇文章主要為大家介紹了PyTorch深度學(xué)習(xí)LSTM從input輸入到Linear輸出深入理解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

LSTM介紹

關(guān)于LSTM的具體原理,可以參考:

http://www.dbjr.com.cn/article/178582.htm

http://www.dbjr.com.cn/article/178423.htm

系列文章:

PyTorch搭建雙向LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測(cè)

PyTorch搭建LSTM實(shí)現(xiàn)多變量多步長(zhǎng)時(shí)序負(fù)荷預(yù)測(cè)

PyTorch搭建LSTM實(shí)現(xiàn)多變量時(shí)序負(fù)荷預(yù)測(cè)

PyTorch搭建LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測(cè)

LSTM參數(shù)

關(guān)于nn.LSTM的參數(shù),官方文檔給出的解釋為:

總共有七個(gè)參數(shù),其中只有前三個(gè)是必須的。由于大家普遍使用PyTorch的DataLoader來(lái)形成批量數(shù)據(jù),因此batch_first也比較重要。LSTM的兩個(gè)常見(jiàn)的應(yīng)用場(chǎng)景為文本處理和時(shí)序預(yù)測(cè),因此下面對(duì)每個(gè)參數(shù)我都會(huì)從這兩個(gè)方面來(lái)進(jìn)行具體解釋。

  • input_size:在文本處理中,由于一個(gè)單詞沒(méi)法參與運(yùn)算,因此我們得通過(guò)Word2Vec來(lái)對(duì)單詞進(jìn)行嵌入表示,將每一個(gè)單詞表示成一個(gè)向量,此時(shí)input_size=embedding_size。
  • 比如每個(gè)句子中有五個(gè)單詞,每個(gè)單詞用一個(gè)100維向量來(lái)表示,那么這里input_size=100;
  • 在時(shí)間序列預(yù)測(cè)中,比如需要預(yù)測(cè)負(fù)荷,每一個(gè)負(fù)荷都是一個(gè)單獨(dú)的值,都可以直接參與運(yùn)算,因此并不需要將每一個(gè)負(fù)荷表示成一個(gè)向量,此時(shí)input_size=1。
  • 但如果我們使用多變量進(jìn)行預(yù)測(cè),比如我們利用前24小時(shí)每一時(shí)刻的[負(fù)荷、風(fēng)速、溫度、壓強(qiáng)、濕度、天氣、節(jié)假日信息]來(lái)預(yù)測(cè)下一時(shí)刻的負(fù)荷,那么此時(shí)input_size=7。
  • hidden_size:隱藏層節(jié)點(diǎn)個(gè)數(shù)。可以隨意設(shè)置。
  • num_layers:層數(shù)。nn.LSTMCell與nn.LSTM相比,num_layers默認(rèn)為1。
  • batch_first:默認(rèn)為False,意義見(jiàn)后文。

Inputs

關(guān)于LSTM的輸入,官方文檔給出的定義為:

可以看到,輸入由兩部分組成:input、(初始的隱狀態(tài)h_0,初始的單元狀態(tài)c_0)

其中input:

input(seq_len, batch_size, input_size)
  • seq_len:在文本處理中,如果一個(gè)句子有7個(gè)單詞,則seq_len=7;在時(shí)間序列預(yù)測(cè)中,假設(shè)我們用前24個(gè)小時(shí)的負(fù)荷來(lái)預(yù)測(cè)下一時(shí)刻負(fù)荷,則seq_len=24。
  • batch_size:一次性輸入LSTM中的樣本個(gè)數(shù)。在文本處理中,可以一次性輸入很多個(gè)句子;在時(shí)間序列預(yù)測(cè)中,也可以一次性輸入很多條數(shù)據(jù)。
  • input_size:見(jiàn)前文。

(h_0, c_0):

h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)

h_0和c_0的shape一致。

  • num_directions:如果是雙向LSTM,則num_directions=2;否則num_directions=1。
  • num_layers:見(jiàn)前文。
  • batch_size:見(jiàn)前文。
  • hidden_size:見(jiàn)前文。

Outputs

關(guān)于LSTM的輸出,官方文檔給出的定義為:

可以看到,輸出也由兩部分組成:otput、(隱狀態(tài)h_n,單元狀態(tài)c_n)

其中output的shape為:

output(seq_len, batch_size, num_directions * hidden_size)

h_n和c_n的shape保持不變,參數(shù)解釋見(jiàn)前文。

batch_first

如果在初始化LSTM時(shí)令batch_first=True,那么input和output的shape將由:

input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)

變?yōu)椋?/p>

input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)

即batch_size提前。

案例

簡(jiǎn)單搭建一個(gè)LSTM如下所示:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1 # 單向LSTM
        self.batch_size = batch_size
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)
    def forward(self, input_seq):
        h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        seq_len = input_seq.shape[1] # (5, 30)
        # input(batch_size, seq_len, input_size)
        input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
        output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
        pred = self.linear(output) # pred(150, 1)
        pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
        pred = pred[:, -1, :]  # (5, 1)
        return pred

其中定義模型的代碼為:

self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)

我們加上具體的數(shù)字:

self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

再看前向傳播:

def forward(self, input_seq):
    h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
    c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
    seq_len = input_seq.shape[1]  # (5, 30)
    # input(batch_size, seq_len, input_size)
    input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)
    # output(batch_size, seq_len, num_directions * hidden_size)
    output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)
    output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size)  # (5 * 30, 64)
    pred = self.linear(output) # (150, 1)
    pred = pred.view(self.batch_size, seq_len, -1)  # (5, 30, 1)
    pred = pred[:, -1, :]  # (5, 1)
    return pred

假設(shè)用前30個(gè)預(yù)測(cè)下一個(gè),則seq_len=30,batch_size=5,由于設(shè)置了batch_first=True,因此,輸入到LSTM中的input的shape應(yīng)該為:

input(batch_size, seq_len, input_size) = input(5, 30, 1)

但實(shí)際上,經(jīng)過(guò)DataLoader處理后的input_seq為:

input_seq(batch_size, seq_len) = input_seq(5, 30)

(5, 30)表示一共5條數(shù)據(jù),每條數(shù)據(jù)的維度都為30。為了匹配LSTM的輸入,我們需要對(duì)input_seq的shape進(jìn)行變換:

input_seq = input_seq.view(self.batch_size, seq_len, 1)  # (5, 30, 1)

然后將input_seq送入LSTM:

output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)

根據(jù)前文,output的shape為:

output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)

全連接層的定義為:

self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

因此,我們需要將output的第二維度變換為64(150, 64):

output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)

然后將output送入全連接層:

pred = self.linear(output) # pred(150, 1)

得到的預(yù)測(cè)值shape為(150, 1)。我們需要將其進(jìn)行還原,變成(5, 30, 1):

pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)

在用DataLoader處理了數(shù)據(jù)后,得到的input_seq和label的shape分別為:

input_seq(batch_size, seq_len) = input_seq(5, 30)label(batch_size, output_size) = label(5, 1)

由于輸出是輸入右移,我們只需要取pred第二維度(time)中的最后一個(gè)數(shù)據(jù):

pred = pred[:, -1, :] # (5, 1)

這樣,我們就得到了預(yù)測(cè)值,然后與label求loss,然后再反向更新參數(shù)即可。

時(shí)間序列預(yù)測(cè)的一個(gè)真實(shí)案例請(qǐng)見(jiàn):PyTorch搭建LSTM實(shí)現(xiàn)時(shí)間序列預(yù)測(cè)(負(fù)荷預(yù)測(cè))

以上就是PyTorch深度學(xué)習(xí)LSTM從input輸入到Linear輸出的詳細(xì)內(nèi)容,更多關(guān)于LSTM input輸入Linear輸出的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • Django rest framework基本介紹與代碼示例

    Django rest framework基本介紹與代碼示例

    這篇文章主要介紹了Django rest framework基本介紹與代碼示例,簡(jiǎn)單敘述了rest framework的一些用處,可選擇的相關(guān)軟件包,然后分享了一個(gè)簡(jiǎn)單的模型支持的API的例子,小編覺(jué)得還是挺不錯(cuò)的,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-01-01
  • 大家都說(shuō)好用的Python命令行庫(kù)click的使用

    大家都說(shuō)好用的Python命令行庫(kù)click的使用

    這篇文章主要介紹了大家都說(shuō)好用的Python命令行庫(kù)click的使用,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-11-11
  • Python實(shí)現(xiàn)改變與矩形橡膠的線條的顏色代碼示例

    Python實(shí)現(xiàn)改變與矩形橡膠的線條的顏色代碼示例

    這篇文章主要介紹了Python實(shí)現(xiàn)改變與矩形橡膠的線條的顏色代碼示例,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-01-01
  • Python利用3D引擎做一個(gè)太陽(yáng)系行星模擬器

    Python利用3D引擎做一個(gè)太陽(yáng)系行星模擬器

    Python有一個(gè)不錯(cuò)的3D引擎——Ursina。本文就來(lái)利用Ursina這一3D引擎做一個(gè)太陽(yáng)系行星模擬器,感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下
    2023-01-01
  • python打包exe文件并隱藏執(zhí)行CMD命令窗口問(wèn)題

    python打包exe文件并隱藏執(zhí)行CMD命令窗口問(wèn)題

    這篇文章主要介紹了python打包exe文件并隱藏執(zhí)行CMD命令窗口問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-01-01
  • Python專用方法與迭代機(jī)制實(shí)例分析

    Python專用方法與迭代機(jī)制實(shí)例分析

    這篇文章主要介紹了Python專用方法與迭代機(jī)制,包括類的私有方法、專有方法、模塊私有對(duì)象、迭代__iter__()方法的對(duì)象等,需要的朋友可以參考下
    2014-09-09
  • python計(jì)算機(jī)視覺(jué)OpenCV入門講解

    python計(jì)算機(jī)視覺(jué)OpenCV入門講解

    這篇文章主要介紹了python計(jì)算機(jī)視覺(jué)OpenCV入門講解,關(guān)于圖像處理的相關(guān)簡(jiǎn)單操作,包括讀入圖像、顯示圖像及圖像相關(guān)理論知識(shí)
    2022-06-06
  • 深入了解NumPy 高級(jí)索引

    深入了解NumPy 高級(jí)索引

    這篇文章主要介紹了NumPy 高級(jí)索引的相關(guān)資料,文中講解非常細(xì)致,代碼幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下
    2020-07-07
  • Python?Pygame實(shí)戰(zhàn)之打磚塊小游戲

    Python?Pygame實(shí)戰(zhàn)之打磚塊小游戲

    打磚塊最早是由雅達(dá)利公司開(kāi)發(fā)的一款獨(dú)立游戲,也是無(wú)數(shù)人的童年記憶。本文將利用Python中的Pygame模塊制作經(jīng)典的打磚塊游戲,需要的可以參考一下
    2022-02-02
  • 用Python中的turtle模塊畫圖兩只小羊方法

    用Python中的turtle模塊畫圖兩只小羊方法

    在本片文章里小編給大家分享了關(guān)于用Python中的turtle模塊畫圖兩只小羊的實(shí)例操作方法,需要的朋友們學(xué)習(xí)下。
    2019-04-04

最新評(píng)論