深入理解Pytorch中的torch. matmul()
torch.matmul()
語(yǔ)法
torch.matmul(input, other, *, out=None) → Tensor
作用
兩個(gè)張量的矩陣乘積
行為取決于張量的維度,如下所示:
- 如果兩個(gè)張量都是一維的,則返回點(diǎn)積(標(biāo)量)。
- 如果兩個(gè)參數(shù)都是二維的,則返回矩陣-矩陣乘積。
- 如果第一個(gè)參數(shù)是一維的,第二個(gè)參數(shù)是二維的,為了矩陣乘法的目的,在它的維數(shù)前面加上一個(gè) 1。在矩陣相乘之后,前置維度被移除。
- 如果第一個(gè)參數(shù)是二維的,第二個(gè)參數(shù)是一維的,則返回矩陣向量積。
- 如果兩個(gè)參數(shù)至少為一維且至少一個(gè)參數(shù)為 N 維(其中 N > 2),則返回批處理矩陣乘法
- 如果第一個(gè)參數(shù)是一維的,則將 1 添加到其維度,以便批量矩陣相乘并在之后刪除。如果第二個(gè)參數(shù)是一維的,則將 1 附加到其維度以用于批量矩陣倍數(shù)并在之后刪除
- 非矩陣(即批次)維度是廣播的(因此必須是可廣播的)
- 例如,如果輸入是( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) 張量
- 另一個(gè)是 ( k × n × n ) (k \times n \times n)(k×n×n)張量,
- out 將是一個(gè) ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) 張量
請(qǐng)注意,廣播邏輯在確定輸入是否可廣播時(shí)僅查看批處理維度,而不是矩陣維度
例如
- 如果輸入是 ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) 張量
- 另一個(gè)是 ( k × m × p ) (k \times m \times p)(k×m×p) 張量
- 即使最后兩個(gè)維度(即矩陣維度)不同,這些輸入對(duì)于廣播也是有效的
- out 將是一個(gè) ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) 張量
該運(yùn)算符支持 TensorFloat32。
在某些 ROCm 設(shè)備上,當(dāng)使用 float16 輸入時(shí),此模塊將使用不同的向后精度
舉例
情形1: 一維 * 一維
如果兩個(gè)張量都是一維的,則返回點(diǎn)積(標(biāo)量)
tensor1 = torch.Tensor([1,2,3]) tensor2 =torch.Tensor([4,5,6]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
ans = 1 * 4 + 2 * 5 + 3 * 6 = 32
情形2: 二維 * 二維
如果兩個(gè)參數(shù)都是二維的,則返回矩陣-矩陣乘積
也就是 正常的矩陣乘法 (m * n) * (n * k) = (m * k)
tensor1 = torch.Tensor([[1,2,3],[1,2,3]]) tensor2 =torch.Tensor([[4,5],[4,5],[4,5]]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
情形3: 一維 * 二維
如果第一個(gè)參數(shù)是一維的,第二個(gè)參數(shù)是二維的,為了矩陣乘法的目的,在它的維數(shù)前面加上一個(gè) 1
在矩陣相乘之后,前置維度被移除
tensor1 = torch.Tensor([1,2,3]) # 注意這里是一維 tensor2 =torch.Tensor([[4,5],[4,5],[4,5]]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
tensor1 = torch.Tensor([1,2,3])
修改為 tensor1 = torch.Tensor([[1,2,3]])
發(fā)現(xiàn)一個(gè)結(jié)果是[24., 30.]
一個(gè)是[[24., 30.]]
所以,當(dāng)一維 * 二維時(shí), 開(kāi)始變成 1 * m(一維的維度),也就是一個(gè)二維, 再進(jìn)行正常的矩陣運(yùn)算,得到[[24., 30.]]
, 然后再去掉開(kāi)始增加的一個(gè)維度,得到[24., 30.]
想象為二維 * 二維(前置維度為1),最后結(jié)果去掉一個(gè)維度即可
情形4: 二維 * 一維
如果第一個(gè)參數(shù)是二維的,第二個(gè)參數(shù)是一維的,則返回矩陣向量積
tensor1 =torch.Tensor([[4,5,6],[7,8,9]]) tensor2 = torch.Tensor([1,2,3]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
理解為:
- 把第一個(gè)二維中,想象為多個(gè)行向量
- 第二個(gè)一維想象為一個(gè)列向量
- 行向量與列向量進(jìn)行矩陣乘法,得到一個(gè)標(biāo)量
- 再按照行堆疊起來(lái)即可
情形5:兩個(gè)參數(shù)至少為一維且至少一個(gè)參數(shù)為 N 維(其中 N > 2),則返回批處理矩陣乘法
第一個(gè)參數(shù)為N維,第二個(gè)參數(shù)為一維時(shí)
tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4) print(torch.matmul(tensor1, tensor2).size())
(4) 先添加一個(gè)維度 (4 * 1)
得到(10 * 3 * 4) *( 4 * 1) = (10 * 3 * 1)
再刪除最后一個(gè)維度(添加的那個(gè))
得到結(jié)果(10 * 3)
tensor1 = torch.randn(10,2, 3, 4) # tensor2 = torch.randn(4) print(torch.matmul(tensor1, tensor2).size())
(10 * 2 * 3 * 4) * (4 * 1) = (10 * 2 * 3) 【抵消4,刪1】
第一個(gè)參數(shù)為一維,第二個(gè)參數(shù)為二維時(shí)
tensor1 = torch.randn(4) tensor2 = torch.randn(10, 4, 3) print(torch.matmul(tensor1, tensor2).size())
tensor2 中第一個(gè)10理解為批次, 10個(gè)(4 * 3)
(1 * 4)與每個(gè)(4 * 3) 相乘得到(1,3),去除1,得到(3)
批次為10,得到(10,3)
tensor1 = torch.randn(4) tensor2 = torch.randn(10,2, 4, 3) print(torch.matmul(tensor1, tensor2).size())
這里批次理解為[10, 2]即可
tensor1 = torch.randn(4) tensor2 = torch.randn(10,4, 2,4,1) print(torch.matmul(tensor1, tensor2).size())
個(gè)人理解:當(dāng)一個(gè)參數(shù)為一維時(shí),它要去匹配另一個(gè)參數(shù)的最后兩個(gè)維度(二維 * 二維)
比如上面的例子就是(1 * 4) 匹配 (4,1), 批次為(10,4,2)
高維 * 高維時(shí)
注:這不太好理解 … 感覺(jué)就是要找準(zhǔn)批次,再進(jìn)行乘法(靠感覺(jué)了 哈哈 離譜)
參考 https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
到此這篇關(guān)于深入理解Pytorch中的torch. matmul()的文章就介紹到這了,更多相關(guān)Pytorch torch. matmul()內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python3基礎(chǔ)語(yǔ)法知識(shí)點(diǎn)總結(jié)
在本篇文章里小編給大家分享的是一篇關(guān)于Python3基礎(chǔ)語(yǔ)法知識(shí)點(diǎn)總結(jié)內(nèi)容,有興趣的朋友們可以學(xué)習(xí)下。2021-05-05解決python3.6用cx_Oracle庫(kù)連接Oracle的問(wèn)題
這篇文章主要介紹了解決python3.6用cx_Oracle庫(kù)連接Oracle的問(wèn)題,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-12-12利用python操作SQLite數(shù)據(jù)庫(kù)及文件操作詳解
這篇文章主要給大家介紹了關(guān)于利用python操作SQLite數(shù)據(jù)庫(kù)及文件操作的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧。2017-09-09Python中的Joblib庫(kù)使用學(xué)習(xí)總結(jié)
這篇文章主要介紹了Python中的Joblib庫(kù)使用學(xué)習(xí)總結(jié),Joblib是一組在Python中提供輕量級(jí)流水線(xiàn)的工具,Joblib已被優(yōu)化得很快速,很健壯了,特別是在大數(shù)據(jù)上,并對(duì)numpy數(shù)組進(jìn)行了特定的優(yōu)化,需要的朋友可以參考下2023-08-08keras實(shí)現(xiàn)VGG16 CIFAR10數(shù)據(jù)集方式
這篇文章主要介紹了keras實(shí)現(xiàn)VGG16 CIFAR10數(shù)據(jù)集方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-07-07Python入門(mén)教程(三)Python語(yǔ)法解析
這篇文章主要介紹了Python入門(mén)教程(三)Python語(yǔ)法解析,Python是一門(mén)非常強(qiáng)大好用的語(yǔ)言,也有著易上手的特性,本文為入門(mén)教程,需要的朋友可以參考下2023-04-04