pytorch如何實(shí)現(xiàn)多個(gè)矩陣拼接
pytorch多個(gè)矩陣拼接
問題描述
在處理數(shù)據(jù)的時(shí)候遇到一個(gè)for循環(huán)中生成多個(gè)【max_len*max_len】的二維矩陣,現(xiàn)需要將這些矩陣在第一維上進(jìn)行堆疊,形成一個(gè)新的【batch * max_len * max_len】三維矩陣
實(shí)現(xiàn)過程
a = torch.ones(3, 3) ?# 假設(shè)生成的矩陣形狀為3*3 c = [] ?# 定義一個(gè)空列表用于存儲(chǔ)矩陣 for i in range(3): ? ? a = a ? ? c.append(a.unsqueeze(0)) # 使用cat方法可之間實(shí)現(xiàn)該操作 c = torch.cat(c, dim=0) ? print(c.size())
輸出c的形狀:
torch.Size([3, 3, 3])
pytorch中torch.cat()矩陣拼接的用法
深度學(xué)習(xí)模型里的輸出的東西還是有點(diǎn)搞的。torch.cat()的用處還是蠻大的。
下面直接舉例子理解。
一維拼接
import torch a = torch.Tensor([1, 2, 3]) b = a * 2 c = torch.cat((a, b), dim=0) ?# dim=-1為取最后一維。這里只有一維-1和0是一樣的 print(a.shape) print(c.shape) print(c)
二維拼接
dim就是選擇哪一維進(jìn)行拼接,dim=-1就表示最后一維進(jìn)行拼接,這個(gè)也很好理解,索引-1一般都指最后一個(gè)字符
a = torch.Tensor([[1, 2]]) b = a * 2 c1 = torch.cat((a, b), dim=0) c2 = torch.cat((a, b), dim=1) ?# 這里第二維是最后一維,dim=-1和dim=1是一樣的 print("a:", a) print("a.shape:", a.shape) print("c1:", c1) print("c1.shape:", c1.shape) print("c2:", c2) print("c2.shape:", c2.shape)
當(dāng)你使用pytorch深度學(xué)習(xí)模型時(shí),隱藏層不止一層,最好將所有的隱藏層都利用起來,那么就需要進(jìn)行隱藏層的拼接了。
假設(shè)隱藏層h_n.shape為(2,3,4)表示有2個(gè)隱藏層,batch_size為3(3個(gè)樣本一起訓(xùn)練),隱藏層大小為4。由于隱藏層都包含了一定的信息,那么我們都利用起來應(yīng)該效果比較好(聽學(xué)長說很多論文都證明過了),那么每個(gè)樣本對(duì)應(yīng)的隱藏層應(yīng)該都拼接起來用即2*4的大小。這樣就需要用到拼接了。
h_n = torch.randn(2, 3, 4) ?# 假設(shè)隱藏層 # 下面三種寫法是一個(gè)意思 feature_map = torch.cat([h_n[i] for i in range(h_n.shape[0])], dim=-1) ?# 索引第i個(gè)整元素,元素里剩下的維度缺省是全取的意思 feature_map1 = torch.cat([h_n[i, :, :] for i in range(h_n.shape[0])], dim=-1) feature_map2 = torch.cat([h_n[i] for i in range(h_n.shape[0])], dim=1) print(feature_map.shape) print(feature_map1.shape) print(feature_map2.shape)
隱藏層拼接完之后就可以放進(jìn)全連接層然后出結(jié)果了。
由于LSTM的現(xiàn)在時(shí)刻的輸出是前一個(gè)時(shí)刻的隱藏層和現(xiàn)在時(shí)刻的輸入經(jīng)過softmax得到的,而現(xiàn)在時(shí)刻的隱藏層是 現(xiàn)在時(shí)刻的輸出*tanh(現(xiàn)在時(shí)刻的細(xì)胞狀態(tài))得到的,現(xiàn)在時(shí)刻的隱藏層也是包含了現(xiàn)在輸入的信息的,因此直接放入全連接然后出結(jié)果就好了,至于模型的輸出可以不用,直接用隱藏層也是可以的吧?;蛘哒f隱藏層就相當(dāng)于包含著各自特征信息,輸出層也是基于隱藏層來的,因此我們深度學(xué)習(xí)模型里直接用隱藏層就是在直接用那些特征吧(強(qiáng)行理解一波)
用模型的輸出或者模型隱藏層應(yīng)該都是可以得出結(jié)果的,目前對(duì)我來說,效果應(yīng)該都差不多。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python使用functools實(shí)現(xiàn)注解同步方法
這篇文章主要介紹了Python使用functools實(shí)現(xiàn)注解同步方法,非常不錯(cuò),具有參考借鑒價(jià)值,需要的朋友可以參考下2018-02-02Python實(shí)現(xiàn)拉格朗日插值法的示例詳解
插值法是一種數(shù)學(xué)方法,用于在已知數(shù)據(jù)點(diǎn)(離散數(shù)據(jù))之間插入數(shù)據(jù),以生成連續(xù)的函數(shù)曲線,而格朗日插值法是一種多項(xiàng)式插值法。本文就來用Python實(shí)現(xiàn)拉格朗日插值法,希望對(duì)大家有所幫助2023-02-02使用python tkinter實(shí)現(xiàn)各種個(gè)樣的撩妹鼠標(biāo)拖尾效果
這篇文章主要介紹了使用python tkinter實(shí)現(xiàn)各種個(gè)樣的撩妹鼠標(biāo)拖尾效果,本文通過實(shí)例代碼,給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-09-09Python易忽視知識(shí)點(diǎn)小結(jié)
這篇文章主要介紹了Python易忽視知識(shí)點(diǎn),實(shí)例分析了Python中容易被忽視的常見操作技巧,需要的朋友可以參考下2015-05-05python可迭代類型遍歷過程中數(shù)據(jù)改變會(huì)不會(huì)報(bào)錯(cuò)
這篇文章主要介紹了python可迭代類型遍歷過程中數(shù)據(jù)改變會(huì)不會(huì)報(bào)錯(cuò)問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-12-12python決策樹預(yù)測學(xué)生成績等級(jí)實(shí)現(xiàn)詳情
這篇文章主要為介紹了python決策樹預(yù)測學(xué)生成績等級(jí),使用決策樹完成學(xué)生成績等級(jí)預(yù)測,可選取部分或全部特征,分析參數(shù)對(duì)結(jié)果的影響,并進(jìn)行調(diào)參優(yōu)化,決策樹可視化進(jìn)行調(diào)參優(yōu)化分析2022-04-04Python編寫一個(gè)多線程的12306搶票程序的示例
對(duì)于很多人來說,搶購火車票人們成了一個(gè)令人頭疼的問題,本文主要介紹了Python編寫一個(gè)多線程的12306搶票程序的示例,具有一定的參考價(jià)值,感興趣的可以了解一下2023-09-09