Pytorch 的 LSTM 模型的示例教程
1. 代碼
完整的源代碼:
import torch from torch import nn # 定義一個(gè)LSTM模型 class LSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super(LSTM, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): # 初始化隱藏狀態(tài)h0, c0為全0向量 h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) # 將輸入x和隱藏狀態(tài)(h0, c0)傳入LSTM網(wǎng)絡(luò) out, _ = self.lstm(x, (h0, c0)) # 取最后一個(gè)時(shí)間步的輸出作為L(zhǎng)STM網(wǎng)絡(luò)的輸出 out = self.fc(out[:, -1, :]) return out # 定義LSTM超參數(shù) input_size = 10 # 輸入特征維度 hidden_size = 32 # 隱藏單元數(shù)量 num_layers = 2 # LSTM層數(shù) output_size = 2 # 輸出類別數(shù)量 # 構(gòu)建一個(gè)隨機(jī)輸入x和對(duì)應(yīng)標(biāo)簽y x = torch.randn(64, 5, 10) # [batch_size, sequence_length, input_size] y = torch.randint(0, 2, (64,)) # 二分類任務(wù),標(biāo)簽為0或1 # 創(chuàng)建LSTM模型,并將輸入x傳入模型計(jì)算預(yù)測(cè)輸出 lstm = LSTM(input_size, hidden_size, num_layers, output_size) pred = lstm(x) # [batch_size, output_size] # 定義損失函數(shù)和優(yōu)化器,并進(jìn)行模型訓(xùn)練 criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3) num_epochs = 100 for epoch in range(num_epochs): # 前向傳播計(jì)算損失函數(shù)值 pred = lstm(x) # 在每個(gè)epoch中重新計(jì)算預(yù)測(cè)輸出 loss = criterion(pred.squeeze(), y) # 反向傳播更新模型參數(shù) optimizer.zero_grad() loss.backward() optimizer.step() # 輸出每個(gè)epoch的訓(xùn)練損失 print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
2. 模型結(jié)構(gòu)分析
# 定義一個(gè)LSTM模型 class LSTM(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super(LSTM, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, output_size) def forward(self, x): # 初始化隱藏狀態(tài)h0, c0為全0向量 h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) # 將輸入x和隱藏狀態(tài)(h0, c0)傳入LSTM網(wǎng)絡(luò) out, _ = self.lstm(x, (h0, c0)) # 取最后一個(gè)時(shí)間步的輸出作為L(zhǎng)STM網(wǎng)絡(luò)的輸出 out = self.fc(out[:, -1, :]) return out
上述代碼定義了一個(gè)LSTM類,這個(gè)類可以用于完成一個(gè)基于LSTM的序列模型的搭建。
在初始化函數(shù)中,輸入的參數(shù)分別是輸入數(shù)據(jù)的特征維度(input_size),隱藏層的大?。╤idden_size),LSTM層數(shù)(num_layers)以及輸出數(shù)據(jù)的維度(output_size)。這里使用batch_first=True表示輸入數(shù)據(jù)的第一個(gè)維度是batch size,第二個(gè)維度是時(shí)間步長(zhǎng)和特征維度。
在forward函數(shù)中,首先初始化了LSTM網(wǎng)絡(luò)的隱藏狀態(tài)為全0向量,并且將其移動(dòng)到與輸入數(shù)據(jù)相同的設(shè)備上。然后調(diào)用了nn.LSTM函數(shù)進(jìn)行前向傳播操作,并且通過fc層將最后一個(gè)時(shí)間步的輸出映射為輸出的數(shù)據(jù),最后進(jìn)行了返回。
3. 代碼詳解
# 將輸入x和隱藏狀態(tài)(h0, c0)傳入LSTM網(wǎng)絡(luò) out, _ = self.lstm(x, (h0, c0))
這行代碼是利用 PyTorch 自帶的 LSTM 模塊處理輸入張量 x(形狀為 [batch_size, sequence_length, input_size])并得到 LSTM 層的輸出 out 和最終狀態(tài)。其中,h0 是 LSTM 層的初始隱藏狀態(tài),c0 是 LSTM 層的初始細(xì)胞狀態(tài)。
在代碼中,調(diào)用了 self.lstm(x, (h0, c0)) 函數(shù),該函數(shù)的返回值有兩個(gè):第一個(gè)返回值是 LSTM 層的輸出 out,其包含了所有時(shí)間步上的隱狀態(tài);第二個(gè)返回值是一個(gè)元組,包含了最后一個(gè)時(shí)間步的隱藏狀態(tài)和細(xì)胞狀態(tài),但我們用“_”丟棄了它。
因?yàn)閷?duì)于許多深度學(xué)習(xí)任務(wù)來說,只需要輸出序列的最后一個(gè)時(shí)間步的隱藏狀態(tài),而不需要每個(gè)時(shí)間步上的隱藏狀態(tài)。因此,這里我們只保留 LSTM 層的輸出 out,而忽略了 LSTM 層最后時(shí)間步的狀態(tài)。
最后,out 的形狀為 [batch_size, sequence_length, hidden_size],其中 hidden_size 是 LSTM 層輸出的隱藏狀態(tài)的維度大小。
x = torch.randn(64, 5, 10)
這行代碼創(chuàng)建了一個(gè)形狀為 (64, 5, 10) 的張量 x,它包含 64 個(gè)樣本,每個(gè)樣本具有 5 個(gè)特征維度和 10 個(gè)時(shí)間步。該張量的值是由均值為 0,標(biāo)準(zhǔn)差為 1 的正態(tài)分布隨機(jī)生成的。
torch.randn() 是 PyTorch 中生成服從標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)數(shù)的函數(shù)。它的輸入是張量的形狀,輸出是符合正態(tài)分布的張量。在本例中,形狀為 (64, 5, 10) 表示該張量包含 64 個(gè)樣本,每個(gè)樣本包含 5 個(gè)特征維度和 10 個(gè)時(shí)間步,每個(gè)元素都是服從標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)數(shù)。這種方式生成的隨機(jī)數(shù)可以用于初始化模型參數(shù)、生成噪音數(shù)據(jù)等許多深度學(xué)習(xí)應(yīng)用場(chǎng)景。
y = torch.randint(0, 2, (64,)) # 二分類任務(wù),標(biāo)簽為0或1
y = torch.randint(0, 2, (64,)) 是使用 PyTorch 庫中的 randint() 函數(shù)來生成一個(gè)64個(gè)元素的張量 y,張量的每個(gè)元素都是從區(qū)間 [0, 2) 中隨機(jī)生成的整數(shù)。
具體而言,torch.randint() 函數(shù)包含三個(gè)參數(shù),分別是 low、high 和 size。其中,low 和 high 分別表示隨機(jī)生成整數(shù)的區(qū)間為 [low, high),而 size 參數(shù)指定了生成的張量的形狀。
在上述代碼中,size=(64,) 表示生成的張量 y 的形狀為 64x1,即一個(gè)包含 64 個(gè)元素的一維張量,并且每個(gè)元素的值都在 [0, 2) 中隨機(jī)生成。這種形式的張量通常用于分類問題中的標(biāo)簽向量。在該任務(wù)中,一個(gè)標(biāo)簽通常由一個(gè)整數(shù)表示,因此可以采用使用 randint() 函數(shù)生成一個(gè)長(zhǎng)度為標(biāo)簽類別數(shù)的一維張量,其每個(gè)元素的取值為 0 或 1,表示對(duì)應(yīng)類別是否被選中。
# 創(chuàng)建LSTM模型,并將輸入x傳入模型計(jì)算預(yù)測(cè)輸出 lstm = LSTM(input_size, hidden_size, num_layers, output_size) pred = lstm(x) # [batch_size, output_size]
通過定義的LSTM類創(chuàng)建了一個(gè)LSTM模型,并將輸入x傳入模型進(jìn)行前向計(jì)算,得到了一個(gè)預(yù)測(cè)輸出pred,其形狀為[64, output_size],其中output_size是在LSTM初始化函數(shù)中指定的輸出數(shù)據(jù)的維度。
這段代碼演示了如何使用已經(jīng)構(gòu)建好的代碼搭建并訓(xùn)練一個(gè)基于LSTM的序列模型,并且展示了其中的一些關(guān)鍵步驟,包括數(shù)據(jù)輸入、模型創(chuàng)建以及前向計(jì)算。
到此這篇關(guān)于Pytorch 的 LSTM 模型的簡(jiǎn)單示例的文章就介紹到這了,更多相關(guān)Pytorch LSTM 模型內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python自動(dòng)截取需要區(qū)域,進(jìn)行圖像識(shí)別的方法
今天小編就為大家分享一篇python自動(dòng)截取需要區(qū)域,進(jìn)行圖像識(shí)別的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-05-05使用python檢測(cè)網(wǎng)頁文本內(nèi)容屏幕上的坐標(biāo)
在 Web 開發(fā)中,經(jīng)常需要對(duì)網(wǎng)頁上的文本內(nèi)容進(jìn)行處理和操作,有時(shí)候,我們可能需要知道某個(gè)特定文本在屏幕上的位置,以便進(jìn)行后續(xù)的操作,所以本文將介紹如何使用 Python 中的 Selenium 和 BeautifulSoup 庫來檢測(cè)網(wǎng)頁文本內(nèi)容在屏幕上的坐標(biāo),需要的朋友可以參考下2024-04-04Python使用atexit模塊實(shí)現(xiàn)Golang的defer功能
這篇文章主要為大家詳細(xì)介紹了Python如何使用atexit模塊實(shí)現(xiàn)Golang的defer功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-04-04Python深入分析@property裝飾器的應(yīng)用
這篇文章主要介紹了Python @property裝飾器的用法,在Python中,可以通過@property裝飾器將一個(gè)方法轉(zhuǎn)換為屬性,從而實(shí)現(xiàn)用于計(jì)算的屬性,下面文章圍繞主題展開更多相關(guān)詳情,感興趣的小伙伴可以參考一下2022-07-07python使用pandas庫導(dǎo)入并保存excel、csv格式文件數(shù)據(jù)
CSV格式文件很方便各種工具之間傳遞數(shù)據(jù),平時(shí)工作過程之中會(huì)將數(shù)據(jù)保存為CSV格式,這篇文章主要介紹了python使用pandas庫導(dǎo)入并保存excel、csv格式文件數(shù)據(jù)的相關(guān)資料,需要的朋友可以參考下2017-12-12詳解用Python進(jìn)行時(shí)間序列預(yù)測(cè)的7種方法
這篇文章主要介紹了詳解用Python進(jìn)行時(shí)間序列預(yù)測(cè)的7種方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-03-03