深入理解Pytorch中的torch. matmul()
torch.matmul()
語法
torch.matmul(input, other, *, out=None) → Tensor
作用
兩個張量的矩陣乘積
行為取決于張量的維度,如下所示:
- 如果兩個張量都是一維的,則返回點積(標量)。
- 如果兩個參數(shù)都是二維的,則返回矩陣-矩陣乘積。
- 如果第一個參數(shù)是一維的,第二個參數(shù)是二維的,為了矩陣乘法的目的,在它的維數(shù)前面加上一個 1。在矩陣相乘之后,前置維度被移除。
- 如果第一個參數(shù)是二維的,第二個參數(shù)是一維的,則返回矩陣向量積。
- 如果兩個參數(shù)至少為一維且至少一個參數(shù)為 N 維(其中 N > 2),則返回批處理矩陣乘法
- 如果第一個參數(shù)是一維的,則將 1 添加到其維度,以便批量矩陣相乘并在之后刪除。如果第二個參數(shù)是一維的,則將 1 附加到其維度以用于批量矩陣倍數(shù)并在之后刪除
- 非矩陣(即批次)維度是廣播的(因此必須是可廣播的)
- 例如,如果輸入是( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) 張量
- 另一個是 ( k × n × n ) (k \times n \times n)(k×n×n)張量,
- out 將是一個 ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) 張量
請注意,廣播邏輯在確定輸入是否可廣播時僅查看批處理維度,而不是矩陣維度
例如
- 如果輸入是 ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) 張量
- 另一個是 ( k × m × p ) (k \times m \times p)(k×m×p) 張量
- 即使最后兩個維度(即矩陣維度)不同,這些輸入對于廣播也是有效的
- out 將是一個 ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) 張量
該運算符支持 TensorFloat32。
在某些 ROCm 設備上,當使用 float16 輸入時,此模塊將使用不同的向后精度
舉例
情形1: 一維 * 一維
如果兩個張量都是一維的,則返回點積(標量)
tensor1 = torch.Tensor([1,2,3]) tensor2 =torch.Tensor([4,5,6]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
ans = 1 * 4 + 2 * 5 + 3 * 6 = 32
情形2: 二維 * 二維
如果兩個參數(shù)都是二維的,則返回矩陣-矩陣乘積
也就是 正常的矩陣乘法 (m * n) * (n * k) = (m * k)
tensor1 = torch.Tensor([[1,2,3],[1,2,3]]) tensor2 =torch.Tensor([[4,5],[4,5],[4,5]]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
情形3: 一維 * 二維
如果第一個參數(shù)是一維的,第二個參數(shù)是二維的,為了矩陣乘法的目的,在它的維數(shù)前面加上一個 1
在矩陣相乘之后,前置維度被移除
tensor1 = torch.Tensor([1,2,3]) # 注意這里是一維 tensor2 =torch.Tensor([[4,5],[4,5],[4,5]]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
tensor1 = torch.Tensor([1,2,3])
修改為 tensor1 = torch.Tensor([[1,2,3]])
發(fā)現(xiàn)一個結果是[24., 30.]
一個是[[24., 30.]]
所以,當一維 * 二維時, 開始變成 1 * m(一維的維度),也就是一個二維, 再進行正常的矩陣運算,得到[[24., 30.]]
, 然后再去掉開始增加的一個維度,得到[24., 30.]
想象為二維 * 二維(前置維度為1),最后結果去掉一個維度即可
情形4: 二維 * 一維
如果第一個參數(shù)是二維的,第二個參數(shù)是一維的,則返回矩陣向量積
tensor1 =torch.Tensor([[4,5,6],[7,8,9]]) tensor2 = torch.Tensor([1,2,3]) ans = torch.matmul(tensor1, tensor2) print('tensor1 : ', tensor1) print('tensor2 : ', tensor2) print('ans :', ans) print('ans.size :', ans.size())
理解為:
- 把第一個二維中,想象為多個行向量
- 第二個一維想象為一個列向量
- 行向量與列向量進行矩陣乘法,得到一個標量
- 再按照行堆疊起來即可
情形5:兩個參數(shù)至少為一維且至少一個參數(shù)為 N 維(其中 N > 2),則返回批處理矩陣乘法
第一個參數(shù)為N維,第二個參數(shù)為一維時
tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(4) print(torch.matmul(tensor1, tensor2).size())
(4) 先添加一個維度 (4 * 1)
得到(10 * 3 * 4) *( 4 * 1) = (10 * 3 * 1)
再刪除最后一個維度(添加的那個)
得到結果(10 * 3)
tensor1 = torch.randn(10,2, 3, 4) # tensor2 = torch.randn(4) print(torch.matmul(tensor1, tensor2).size())
(10 * 2 * 3 * 4) * (4 * 1) = (10 * 2 * 3) 【抵消4,刪1】
第一個參數(shù)為一維,第二個參數(shù)為二維時
tensor1 = torch.randn(4) tensor2 = torch.randn(10, 4, 3) print(torch.matmul(tensor1, tensor2).size())
tensor2 中第一個10理解為批次, 10個(4 * 3)
(1 * 4)與每個(4 * 3) 相乘得到(1,3),去除1,得到(3)
批次為10,得到(10,3)
tensor1 = torch.randn(4) tensor2 = torch.randn(10,2, 4, 3) print(torch.matmul(tensor1, tensor2).size())
這里批次理解為[10, 2]即可
tensor1 = torch.randn(4) tensor2 = torch.randn(10,4, 2,4,1) print(torch.matmul(tensor1, tensor2).size())
個人理解:當一個參數(shù)為一維時,它要去匹配另一個參數(shù)的最后兩個維度(二維 * 二維)
比如上面的例子就是(1 * 4) 匹配 (4,1), 批次為(10,4,2)
高維 * 高維時
注:這不太好理解 … 感覺就是要找準批次,再進行乘法(靠感覺了 哈哈 離譜)
參考 https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
到此這篇關于深入理解Pytorch中的torch. matmul()的文章就介紹到這了,更多相關Pytorch torch. matmul()內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
解決python3.6用cx_Oracle庫連接Oracle的問題
這篇文章主要介紹了解決python3.6用cx_Oracle庫連接Oracle的問題,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-12-12利用python操作SQLite數(shù)據(jù)庫及文件操作詳解
這篇文章主要給大家介紹了關于利用python操作SQLite數(shù)據(jù)庫及文件操作的相關資料,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧。2017-09-09keras實現(xiàn)VGG16 CIFAR10數(shù)據(jù)集方式
這篇文章主要介紹了keras實現(xiàn)VGG16 CIFAR10數(shù)據(jù)集方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-07-07