Python Attention注意力機(jī)制的原理及應(yīng)用詳解
前言
Attention機(jī)制是深度學(xué)習(xí)中的一種技術(shù),特別是在自然語言處理(NLP)和計(jì)算機(jī)視覺領(lǐng)域中得到了廣泛的應(yīng)用。它的核心思想是模仿人類的注意力機(jī)制,即人類在處理信息時(shí)會(huì)集中注意力在某些關(guān)鍵部分上,而忽略其他不那么重要的信息。在機(jī)器學(xué)習(xí)模型中,這可以幫助模型更好地捕捉到輸入數(shù)據(jù)中的關(guān)鍵信息。
一、Attention機(jī)制的基本原理
1.輸入表示
在自然語言處理(NLP)任務(wù)中,輸入數(shù)據(jù)通常是文本形式的,我們需要將這些文本轉(zhuǎn)換為模型可以處理的數(shù)值形式。這個(gè)過程稱為嵌入(Embedding)。嵌入層將每個(gè)單詞映射到一個(gè)高維空間中的向量,這些向量被稱為詞向量。詞向量能夠捕捉單詞的語義信息,并且可以被神經(jīng)網(wǎng)絡(luò)處理。
# 定義一個(gè)簡(jiǎn)單的嵌入層 class EmbeddingLayer(nn.Module): def __init__(self, vocab_size, embed_dim): super(EmbeddingLayer, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_dim) def forward(self, x): return self.embedding(x) # 隨機(jī)生成一個(gè)輸入序列 input_seq = torch.randint(0, vocab_size, (32, 50)) # (batch_size, seq_len) # 獲取輸入表示 input_repr = embedding_layer(input_seq)
在代碼中,我們定義了一個(gè)EmbeddingLayer
類,它包含一個(gè)nn.Embedding
層,用于將輸入的索引轉(zhuǎn)換為對(duì)應(yīng)的詞向量。然后,我們生成一個(gè)隨機(jī)的輸入序列input_seq
,它模擬了一個(gè)批量大小為32,序列長度為50的文本數(shù)據(jù)。通過嵌入層,我們將這些索引轉(zhuǎn)換為詞向量,得到輸入表示input_repr
。
2.計(jì)算注意力權(quán)重
注意力機(jī)制允許模型在處理序列數(shù)據(jù)時(shí),動(dòng)態(tài)地聚焦于當(dāng)前步驟最相關(guān)的信息。在自注意力(Self-Attention)中,每個(gè)元素都會(huì)計(jì)算與其他所有元素的關(guān)聯(lián)程度,這通過計(jì)算查詢(Q)、鍵(K)和值(V)的線性變換來實(shí)現(xiàn)。
class Attention(nn.Module): def __init__(self, embed_dim): super(Attention, self).__init__() self.query = nn.Linear(embed_dim, embed_dim) self.key = nn.Linear(embed_dim, embed_dim) self.value = nn.Linear(embed_dim, embed_dim) def forward(self, x): Q = self.query(x) K = self.key(x) V = self.value(x) attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(embed_dim) attention_weights = F.softmax(attention_scores, dim=-1) return attention_weights
在這段代碼中,我們定義了一個(gè)Attention
類,它包含三個(gè)線性層,分別用于計(jì)算Q、K和V。然后,我們通過矩陣乘法和softmax函數(shù)計(jì)算注意力權(quán)重,這些權(quán)重表示序列中每個(gè)元素對(duì)當(dāng)前元素的重要性。
3.加權(quán)求和
一旦我們有了注意力權(quán)重,我們就可以使用它們來加權(quán)求和序列中的元素,從而生成一個(gè)綜合了所有元素信息的表示。
def weighted_sum(attention_weights, input_repr): return torch.matmul(attention_weights, input_repr)
這個(gè)簡(jiǎn)單的函數(shù)weighted_sum
接受注意力權(quán)重和輸入表示作為輸入,然后通過矩陣乘法計(jì)算加權(quán)求和,得到一個(gè)綜合了序列中所有元素信息的新表示。
4.輸出
最后,我們使用一個(gè)輸出層將加權(quán)求和得到的表示轉(zhuǎn)換為最終的輸出,這可以是分類任務(wù)的類別概率,也可以是其他任務(wù)的預(yù)測(cè)結(jié)果。
class OutputLayer(nn.Module): def __init__(self, embed_dim, output_dim): super(OutputLayer, self).__init__() self.fc = nn.Linear(embed_dim, output_dim) def forward(self, x): return self.fc(x)
在這個(gè)代碼段中,我們定義了一個(gè)OutputLayer
類,它包含一個(gè)線性層,用于將模型的內(nèi)部表示映射到輸出空間。例如,在分類任務(wù)中,我們可以將嵌入維度的表示映射到類別數(shù)量的輸出空間,并通過softmax函數(shù)或其他激活函數(shù)得到最終的預(yù)測(cè)概率。
5.實(shí)例代碼
以下是使用Python和PyTorch實(shí)現(xiàn)上述內(nèi)容的示例代碼。這段代碼將展示如何使用一個(gè)簡(jiǎn)單的Transformer模型來處理文本數(shù)據(jù),包括輸入表示、計(jì)算注意力權(quán)重、加權(quán)求和以及輸出。
import torch import torch.nn as nn import torch.nn.functional as F class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, dropout=0.1): super(TransformerBlock, self).__init__() self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout) self.ffn = nn.Sequential( nn.Linear(embed_dim, 4 * embed_dim), nn.GELU(), nn.Linear(4 * embed_dim, embed_dim), ) self.norm1 = nn.LayerNorm(embed_dim) self.norm2 = nn.LayerNorm(embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): # 輸入表示 # x: (seq_len, batch_size, embed_dim) attn_output, _ = self.attn(x, x, x) # 自注意力,輸入和輸出都是x attn_output = self.dropout(attn_output) x = self.norm1(x + attn_output) # 加權(quán)求和和殘差連接 # 前饋網(wǎng)絡(luò) ffn_output = self.ffn(x) ffn_output = self.dropout(ffn_output) x = self.norm2(x + ffn_output) # 加權(quán)求和和殘差連接 return x class TextTransformer(nn.Module): def __init__(self, vocab_size, embed_dim, num_heads, num_layers, dropout=0.1): super(TextTransformer, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_dim) self.positional_encoding = nn.Parameter(torch.randn(1, 1, embed_dim)) self.encoder = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, dropout) for _ in range(num_layers)]) self.fc_out = nn.Linear(embed_dim, vocab_size) # 假設(shè)是分類任務(wù) def forward(self, x): # 輸入表示 embeds = self.embedding(x) # (batch_size, seq_len, embed_dim) embeds = embeds + self.positional_encoding[:, :embeds.size(1), :] # 添加位置編碼 embeds = embeds.transpose(0, 1) # (seq_len, batch_size, embed_dim) # 計(jì)算注意力權(quán)重和加權(quán)求和 out = self.encoder(embeds) # 輸出 out = out.transpose(0, 1) # (batch_size, seq_len, embed_dim) out = self.fc_out(out[:, -1, :]) # 假設(shè)只取序列的最后一個(gè)向量進(jìn)行分類 return out # 模型參數(shù) vocab_size = 10000 # 詞匯表大小 embed_dim = 256 # 嵌入層維度 num_heads = 8 # 注意力頭數(shù) num_layers = 6 # Transformer層數(shù) # 實(shí)例化模型 model = TextTransformer(vocab_size, embed_dim, num_heads, num_layers) # 隨機(jī)生成一個(gè)輸入序列 input_seq = torch.randint(0, vocab_size, (32, 100)) # (batch_size, seq_len) # 前向傳播 output = model(input_seq) print(output.shape) # 應(yīng)該輸出 (batch_size, vocab_size)
這段代碼首先定義了一個(gè)TransformerBlock
類,它包含了自注意力機(jī)制和前饋網(wǎng)絡(luò),然后定義了一個(gè)TextTransformer
類,它包含了嵌入層、位置編碼、編碼器和輸出層。在TextTransformer
的前向傳播中,我們首先將輸入序列轉(zhuǎn)換為嵌入表示,然后通過Transformer編碼器進(jìn)行處理,最后通過一個(gè)全連接層輸出結(jié)果。這個(gè)例子展示了如何使用Transformer模型處理文本數(shù)據(jù),并進(jìn)行分類任務(wù)。
二、Attention機(jī)制的類型
1.Soft Attention
這種類型的注意力機(jī)制會(huì)輸出一個(gè)概率分布,每個(gè)輸入元素都有一個(gè)對(duì)應(yīng)的權(quán)重,這些權(quán)重的和為1。Soft attention通??梢晕⒎郑虼丝梢杂糜谔荻认陆?。Soft Attention輸出一個(gè)概率分布,可以通過梯度下降進(jìn)行優(yōu)化。
import torch import torch.nn as nn import torch.nn.functional as F class SoftAttention(nn.Module): def __init__(self, embed_dim): super(SoftAttention, self).__init__() self.weight = nn.Parameter(torch.randn(embed_dim, 1)) def forward(self, x): # x: (batch_size, seq_len, embed_dim) scores = torch.matmul(x, self.weight).squeeze(-1) # (batch_size, seq_len) weights = F.softmax(scores, dim=-1) # Softmax to get probabilities return weights # 示例使用 embed_dim = 128 soft_attn = SoftAttention(embed_dim) input_seq = torch.randn(32, 50, embed_dim) # (batch_size, seq_len, embed_dim) attention_weights = soft_attn(input_seq) print("Soft Attention Weights:", attention_weights.sum(dim=1)) # 應(yīng)該接近于1
2.Hard Attention
與soft attention不同,hard attention會(huì)隨機(jī)或確定性地選擇一個(gè)輸入元素,并只關(guān)注這個(gè)元素。Hard attention通常不可微分,因此訓(xùn)練時(shí)可能需要使用強(qiáng)化學(xué)習(xí)或變分方法。Hard Attention隨機(jī)選擇一個(gè)輸入元素,這里我們使用一個(gè)簡(jiǎn)單的采樣策略。
import torch class HardAttention(nn.Module): def __init__(self, embed_dim): super(HardAttention, self).__init__() def forward(self, x): # x: (batch_size, seq_len, embed_dim) probs = torch.rand(x.size(0), x.size(1), device=x.device) _, idx = torch.topk(probs, k=1, dim=1) selected = torch.gather(x, 1, idx.unsqueeze(-1).expand(-1, -1, x.size(-1))) return selected.squeeze(1) # 示例使用 hard_attn = HardAttention(embed_dim) selected_elements = hard_attn(input_seq) print("Hard Attention Selected Elements:", selected_elements.shape) # (batch_size, embed_dim)
3.Self-Attention
即自注意力機(jī)制,這是一種特殊的注意力機(jī)制,它允許輸入序列中的元素相互之間計(jì)算注意力權(quán)重,這在Transformer模型中得到了廣泛應(yīng)用。Self-Attention允許輸入序列中的元素相互之間計(jì)算注意力權(quán)重。
class SelfAttention(nn.Module): def __init__(self, embed_dim): super(SelfAttention, self).__init__() self.query = nn.Linear(embed_dim, embed_dim) self.key = nn.Linear(embed_dim, embed_dim) self.value = nn.Linear(embed_dim, embed_dim) def forward(self, x): # x: (batch_size, seq_len, embed_dim) Q = self.query(x) K = self.key(x) V = self.value(x) attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(embed_dim) attention_weights = F.softmax(attention_scores, dim=-1) output = torch.matmul(attention_weights, V) return output, attention_weights # 示例使用 self_attn = SelfAttention(embed_dim) output, weights = self_attn(input_seq) print("Self Attention Output:", output.shape) # (batch_size, seq_len, embed_dim)
4.Multi-Head Attention
在Transformer模型中,為了捕捉不同子空間中的信息,會(huì)使用多頭注意力機(jī)制,即并行地運(yùn)行多個(gè)自注意力機(jī)制,然后將結(jié)果合并。Multi-Head Attention并行地運(yùn)行多個(gè)自注意力機(jī)制,然后將結(jié)果合并。
class MultiHeadAttention(nn.Module): def __init__(self, embed_dim, num_heads): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" self.query = nn.Linear(embed_dim, embed_dim) self.key = nn.Linear(embed_dim, embed_dim) self.value = nn.Linear(embed_dim, embed_dim) self.fc_out = nn.Linear(embed_dim, embed_dim) def forward(self, x): # x: (batch_size, seq_len, embed_dim) batch_size, seq_len, embed_dim = x.size() Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.head_dim) attention_weights = F.softmax(attention_scores, dim=-1) output = torch.matmul(attention_weights, V).transpose(1, 2).contiguous() output = output.view(batch_size, seq_len, embed_dim) output = self.fc_out(output) return output # 示例使用 num_heads = 8 multi_head_attn = MultiHeadAttention(embed_dim, num_heads) multi_head_output = multi_head_attn(input_seq) print("Multi-Head Attention Output:", multi_head_output.shape) # (batch_size, seq_len, embed_dim)
Soft Attention和Self-Attention可以直接用于梯度下降優(yōu)化,而Hard Attention由于其不可微分的特性,可能需要特殊的訓(xùn)練技巧。Multi-Head Attention則通過并行處理捕捉更豐富的信息。
三、Attention機(jī)制的應(yīng)用
1.機(jī)器翻譯
機(jī)器翻譯是注意力機(jī)制最著名的應(yīng)用之一。在這個(gè)任務(wù)中,模型需要將一種語言(源語言)的文本轉(zhuǎn)換為另一種語言(目標(biāo)語言)的文本。注意力機(jī)制在這里的作用是在生成目標(biāo)語言的每個(gè)單詞時(shí),動(dòng)態(tài)地聚焦于源語言中相關(guān)的部分,這有助于捕捉長距離依賴關(guān)系,并提高翻譯的準(zhǔn)確性和流暢性。
import torch import torch.nn as nn import torch.optim as optim class Encoder(nn.Module): def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout): super().__init__() self.embedding = nn.Embedding(input_dim, emb_dim) self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout) self.dropout = nn.Dropout(dropout) def forward(self, src): embedded = self.dropout(self.embedding(src)) outputs, (hidden, cell) = self.rnn(embedded) return hidden, cell class Attention(nn.Module): def __init__(self, enc_hid_dim, dec_hid_dim): super().__init__() self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim) self.v = nn.Linear(dec_hid_dim, 1, bias=False) def forward(self, hidden, encoder_outputs): hidden = hidden.repeat(encoder_outputs.shape[0], 1).transpose(0, 1) encoder_outputs = encoder_outputs.transpose(0, 1) attn_energies = self.score(hidden, encoder_outputs) return F.softmax(attn_energies, dim=-1) def score(self, hidden, encoder_outputs): energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2))) energy = self.v(energy).squeeze(2) return energy class Decoder(nn.Module): def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout): super().__init__() self.output_dim = output_dim self.embedding = nn.Embedding(output_dim, emb_dim) self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout) self.attention = Attention(hid_dim, hid_dim) self.fc_out = nn.Linear(hid_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, input, hidden, cell, encoder_outputs): input = input.unsqueeze(0) embedded = self.dropout(self.embedding(input)) attn_weights = self.attention(hidden, encoder_outputs) context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) rnn_input = torch.cat((embedded, context), dim=2) output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell)) output = output.squeeze(0) out = self.fc_out(output) return out, hidden, cell # 假設(shè)參數(shù) input_dim = 1000 # 源語言詞匯表大小 output_dim = 1000 # 目標(biāo)語言詞匯表大小 emb_dim = 256 # 嵌入層維度 hid_dim = 512 # 隱藏層維度 n_layers = 2 # LSTM層數(shù) dropout = 0.1 # Dropout # 實(shí)例化模型 encoder = Encoder(input_dim, emb_dim, hid_dim, n_layers, dropout) decoder = Decoder(output_dim, emb_dim, hid_dim, n_layers, dropout) # 假設(shè)輸入 src = torch.randint(0, input_dim, (10, 32)) # (seq_len, batch_size) input = torch.randint(0, output_dim, (1, 32)) # (seq_len, batch_size) # 前向傳播 hidden, cell = encoder(src) output, hidden, cell = decoder(input, hidden, cell, src) print("Translation Output:", output.shape) # (batch_size, output_dim)
在示例代碼中,我們定義了一個(gè)基于注意力的Seq2Seq模型,它由一個(gè)編碼器和一個(gè)解碼器組成。編碼器讀取源語言文本,并輸出一個(gè)上下文向量和隱藏狀態(tài)。解碼器則使用這個(gè)上下文向量來生成目標(biāo)語言文本,同時(shí)更新隱藏狀態(tài)。注意力機(jī)制通過計(jì)算源語言文本中每個(gè)單詞的重要性,并將這些信息合并到解碼器的每一步中,從而允許模型在生成每個(gè)單詞時(shí)“回顧”源語言文本的相關(guān)部分。
2.文本摘要
在自動(dòng)文本摘要任務(wù)中,模型需要從長文本中提取關(guān)鍵信息,并生成一個(gè)簡(jiǎn)短的摘要。注意力機(jī)制可以幫助模型識(shí)別哪些句子或短語對(duì)于理解全文內(nèi)容最為重要,從而在生成摘要時(shí)保留這些關(guān)鍵信息。
import torch import torch.nn as nn import torch.optim as optim class Encoder(nn.Module): def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout): super().__init__() self.embedding = nn.Embedding(input_dim, emb_dim) self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout) self.dropout = nn.Dropout(dropout) def forward(self, src): embedded = self.dropout(self.embedding(src)) outputs, (hidden, cell) = self.rnn(embedded) return hidden, cell class Attention(nn.Module): def __init__(self, enc_hid_dim, dec_hid_dim): super().__init__() self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim) self.v = nn.Linear(dec_hid_dim, 1, bias=False) def forward(self, hidden, encoder_outputs): hidden = hidden.repeat(encoder_outputs.shape[0], 1).transpose(0, 1) encoder_outputs = encoder_outputs.transpose(0, 1) attn_energies = self.score(hidden, encoder_outputs) return F.softmax(attn_energies, dim=-1) def score(self, hidden, encoder_outputs): energy = torch.tanh(self.attn(torch.cat([hidden, encoder_outputs], dim=2))) energy = self.v(energy).squeeze(2) return energy class Decoder(nn.Module): def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout): super().__init__() self.output_dim = output_dim self.embedding = nn.Embedding(output_dim, emb_dim) self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout) self.attention = Attention(hid_dim, hid_dim) self.fc_out = nn.Linear(hid_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, input, hidden, cell, encoder_outputs): input = input.unsqueeze(0) embedded = self.dropout(self.embedding(input)) attn_weights = self.attention(hidden, encoder_outputs) context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) rnn_input = torch.cat((embedded, context), dim=2) output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell)) output = output.squeeze(0) out = self.fc_out(output) return out, hidden, cell # 假設(shè)參數(shù) input_dim = 1000 # 源語言詞匯表大小 output_dim = 1000 # 目標(biāo)語言詞匯表大小 emb_dim = 256 # 嵌入層維度 hid_dim = 512 # 隱藏層維度 n_layers = 2 # LSTM層數(shù) dropout = 0.1 # Dropout # 實(shí)例化模型 encoder = Encoder(input_dim, emb_dim, hid_dim, n_layers, dropout) decoder = Decoder(output_dim, emb_dim, hid_dim, n_layers, dropout) # 假設(shè)輸入 src = torch.randint(0, input_dim, (10, 32)) # (seq_len, batch_size) input = torch.randint(0, output_dim, (1, 32)) # (seq_len, batch_size) # 前向傳播 hidden, cell = encoder(src) output, hidden, cell = decoder(input, hidden, cell, src) print("Translation Output:", output.shape) # (batch_size, output_dim)
雖然示例代碼沒有詳細(xì)展示,但可以想象,一個(gè)基于注意力的文本摘要模型會(huì)有一個(gè)編碼器來處理輸入文本,并生成一系列隱藏狀態(tài)。然后,一個(gè)解碼器會(huì)使用這些隱藏狀態(tài)和注意力權(quán)重來生成摘要,同時(shí)關(guān)注輸入文本中與當(dāng)前生成摘要最相關(guān)的部分。這樣,生成的摘要不僅包含了原文的核心信息,而且更加緊湊和連貫。
3.圖像識(shí)別
在圖像識(shí)別任務(wù)中,模型的目標(biāo)是識(shí)別圖像中的對(duì)象。注意力機(jī)制可以幫助模型集中注意力在圖像中的關(guān)鍵特征上,例如人臉的眼睛或汽車的輪子,這些特征對(duì)于識(shí)別任務(wù)至關(guān)重要。
import torchvision.models as models class AttentionCNN(nn.Module): def __init__(self): super().__init__() self.cnn = models.resnet18(pretrained=True) self.fc = nn.Linear(512, 1000) # 假設(shè)有1000個(gè)類別 def forward(self, x): x = self.cnn(x) # 假設(shè)我們添加一個(gè)簡(jiǎn)單的注意力層 attention_weights = torch.sigmoid(self.cnn.fc.weight) x = torch.sum(x * attention_weights, dim=1) x = self.fc(x) return x # 實(shí)例化模型 attention_cnn = AttentionCNN() # 假設(shè)輸入 input_image = torch.randn(32, 3, 224, 224) # (batch_size, channels, height, width) # 前向傳播 output = attention_cnn(input_image) print("Image Recognition Output:", output.shape) # (batch_size, num_classes)
在示例代碼中,我們定義了一個(gè)帶有簡(jiǎn)單注意力層的CNN模型。這個(gè)注意力層通過學(xué)習(xí)圖像中不同區(qū)域的重要性,為每個(gè)特征分配權(quán)重。這樣,模型就可以更加關(guān)注于對(duì)分類任務(wù)最重要的特征,而不是平等對(duì)待圖像中的所有像素。這種方法可以提高模型對(duì)圖像中關(guān)鍵信息的敏感性,從而提高識(shí)別的準(zhǔn)確性。
4.語音識(shí)別
語音識(shí)別是將語音信號(hào)轉(zhuǎn)換為文本的任務(wù)。在這個(gè)任務(wù)中,模型需要理解語音中的語義信息,并將其轉(zhuǎn)換為書面語言。注意力機(jī)制可以幫助模型在處理語音信號(hào)時(shí),關(guān)注那些攜帶重要信息的部分,例如特定的音素或單詞。
class SpeechRecognitionModel(nn.Module): def __init__(self, input_dim, emb_dim, hid_dim, output_dim, n_layers, dropout): super().__init__() self.rnn = nn.LSTM(input_dim, emb_dim, n_layers, dropout=dropout, batch_first=True) self.attention = Attention(emb_dim, emb_dim) self.fc_out = nn.Linear(emb_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): # x: (batch_size, seq_len, input_dim) outputs, (hidden, cell) = self.rnn(x) attn_weights = self.attention(hidden, outputs) context = torch.bmm(attn_weights, outputs) output = self.fc_out(context.squeeze(1)) return output # 假設(shè)參數(shù) input_dim = 128 # 特征維度 output_dim = 1000 # 詞匯表大小 # 實(shí)例化模型 speech_recognition = SpeechRecognitionModel(input_dim, emb_dim, hid_dim, output_dim, n_layers, dropout) # 假設(shè)輸入 speech_signal = torch.randn(32, 100, input_dim) # (batch_size, seq_len, input_dim) # 前向傳播 output = speech_recognition(speech_signal) print("Speech Recognition Output:", output.shape) # (batch_size, output_dim)
在示例代碼中,我們定義了一個(gè)基于注意力的RNN模型,用于處理語音信號(hào)。模型的RNN部分處理序列化的語音特征,而注意力機(jī)制則幫助模型在生成每個(gè)單詞時(shí),關(guān)注語音信號(hào)中最相關(guān)的部分。這樣,模型可以更準(zhǔn)確地捕捉到語音中的語義信息,并將其轉(zhuǎn)換為正確的文本輸出。
以上就是Python Attention注意力機(jī)制的原理及應(yīng)用詳解的詳細(xì)內(nèi)容,更多關(guān)于Python Attention注意力機(jī)制的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python plt 利用subplot 實(shí)現(xiàn)在一張畫布同時(shí)畫多張圖
這篇文章主要介紹了Python plt 利用subplot 實(shí)現(xiàn)在一張畫布同時(shí)畫多張圖,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2021-02-02由面試題加深對(duì)Django的認(rèn)識(shí)理解
這篇文章主要介紹了由面試題加深對(duì)Django的認(rèn)識(shí),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07Python使用PyQt5實(shí)現(xiàn)與DeepSeek聊天的圖形化小軟件
在?PyQt5?中,菜單欄(QMenuBar)、工具欄(QToolBar)和狀態(tài)欄(QStatusBar)是?QMainWindow?提供的標(biāo)準(zhǔn)控件,用于幫助用戶更好地與應(yīng)用程序交互,所以本文給大家介紹了Python使用PyQt5實(shí)現(xiàn)與DeepSeek聊天的圖形化小軟件,需要的朋友可以參考下2025-03-03淺談python的輸入輸出,注釋,基本數(shù)據(jù)類型
這篇文章主要介紹了python的輸入輸出,注釋,基本數(shù)據(jù)類型,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-04-04Python去除字符串中的標(biāo)點(diǎn)符號(hào)的最優(yōu)方式
在Python編程中,去除字符串標(biāo)點(diǎn)符號(hào)是一項(xiàng)常見任務(wù),關(guān)鍵在于文本分析和數(shù)據(jù)清洗,Python提供了多種方法,包括使用str.replace()、str.translate()結(jié)合str.maketrans(),以及使用正則表達(dá)式,另外,可以利用string模塊中的punctuation屬性快速實(shí)現(xiàn)2024-09-09linux環(huán)境部署清華大學(xué)大模型最新版 chatglm2-6b 圖文教程
這篇文章主要介紹了linux環(huán)境部署清華大學(xué)大模型最新版 chatglm2-6b ,結(jié)合實(shí)例形式詳細(xì)分析了Linux環(huán)境下chatglm2-6b部署相關(guān)操作步驟與注意事項(xiàng),需要的朋友可以參考下2023-07-07