Pytorch從0實現(xiàn)Transformer的實踐
摘要
With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.
隨著時序預測的不斷發(fā)展,Transformer類模型憑借強大的優(yōu)勢,在CV、NLP領域逐漸取代傳統(tǒng)模型。其中Informer在長時序預測上遠超傳統(tǒng)的RNN模型,Swin Transformer在圖像識別上明顯強于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領域的必然要求。本文將用Pytorch框架,實現(xiàn)Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網(wǎng)絡。
一、構(gòu)造數(shù)據(jù)
1.1 句子長度
# 關于word embedding,以序列建模為例 # 輸入句子有兩個,第一個長度為2,第二個長度為4 src_len = torch.tensor([2, 4]).to(torch.int32) # 目標句子有兩個。第一個長度為4, 第二個長度為3 tgt_len = torch.tensor([4, 3]).to(torch.int32) print(src_len) print(tgt_len)
輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標句子(tgt_len)有兩個。第一個長度為4, 第二個長度為3
1.2 生成句子
用隨機數(shù)生成句子,用0填充空白位置,保持所有句子長度一致
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) print(src_seq) print(tgt_seq)
src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數(shù)字?在做中英文翻譯時,每個中文或英文對應的也是一個數(shù)字,只有這樣才便于處理。
1.3 生成字典
在該字典中,總共有8個字(行),每個字對應8維向量(做了簡化了的)。注意在實際應用中,應當有幾十萬個字,每個字可能有512個維度。
# 構(gòu)造word embedding src_embedding_table = nn.Embedding(9, model_dim) tgt_embedding_table = nn.Embedding(9, model_dim) # 輸入單詞的字典 print(src_embedding_table) # 目標單詞的字典 print(tgt_embedding_table)
字典中,需要留一個維度給class token,故是9行。
1.4 得到向量化的句子
通過字典取出1.2
中得到的句子
# 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
該階段總程序
import torch # 句子長度 src_len = torch.tensor([2, 4]).to(torch.int32) tgt_len = torch.tensor([4, 3]).to(torch.int32) # 構(gòu)造句子,用0填充空白處 src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) # 構(gòu)造字典 src_embedding_table = nn.Embedding(9, 8) tgt_embedding_table = nn.Embedding(9, 8) # 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統(tǒng)RNN的時序信息,增強了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)
2.1 計算括號內(nèi)的值
# 得到分子pos的值 pos_mat = torch.arange(4).reshape((-1, 1)) # 得到分母值 i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8) print(pos_mat) print(i_mat)
2.2 得到位置編碼
# 初始化位置編碼矩陣 pe_embedding_table = torch.zeros(4, 8) # 得到偶數(shù)行位置編碼 pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat) # 得到奇數(shù)行位置編碼 pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat) pe_embedding = nn.Embedding(4, 8) # 設置位置編碼不可更新參數(shù) pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False) print(pe_embedding.weight)
三、多頭注意力
3.1 self mask
有些位置是空白用0填充的,訓練時不希望被這些位置所影響,那么就需要用到self mask。self mask的原理是令這些位置的值為無窮小,經(jīng)過softmax后,這些值會變?yōu)?,不會再影響結(jié)果。
3.1.1 得到有效位置矩陣
# 得到有效位置矩陣 vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2) valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2)) print(valid_encoder_pos_matrix)
3.1.2 得到無效位置矩陣
invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool) print(mask_encoder_self_attention)
True
代表需要對該位置mask
3.1.3 得到mask矩陣
用極小數(shù)填充需要被mask的位置
# 初始化mask矩陣 score = torch.randn(2, max(src_len), max(src_len)) # 用極小數(shù)填充 mask_score = score.masked_fill(mask_encoder_self_attention, -1e9) print(mask_score)
算其softmat
mask_score_softmax = F.softmax(mask_score) print(mask_score_softmax)
可以看到,已經(jīng)達到預期效果
到此這篇關于Pytorch從0實現(xiàn)Transformer的實踐的文章就介紹到這了,更多相關Pytorch Transformer內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python實現(xiàn)根據(jù)指定端口探測服務器/模塊部署的方法
這篇文章主要介紹了Python根據(jù)指定端口探測服務器/模塊部署的方法,非常具有實用價值,需要的朋友可以參考下2014-08-08Python簡單實現(xiàn)圖片轉(zhuǎn)字符畫的實例項目
這篇文章主要介紹了Python簡單實現(xiàn)圖片轉(zhuǎn)字符畫的實例項目,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2021-04-04Python2 Selenium元素定位的實現(xiàn)(8種)
這篇文章主要介紹了Python2 Selenium元素定位的實現(xiàn),小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-02-02利用4行Python代碼監(jiān)測每一行程序的運行時間和空間消耗
這篇文章主要介紹了如何使用4行Python代碼監(jiān)測每一行程序的運行時間和空間消耗,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-04-04