pytorch中部分矩陣乘法和數(shù)組乘法的小結
一、torch.mul
該乘法可簡單理解為矩陣各位相乘,一個常見的例子為向量點乘,源碼定義為torch.mul(input,other,out=None)。其中other可以為一個數(shù)也可以為一個張量,other為數(shù)即張量的數(shù)乘。
該函數(shù)可觸發(fā)廣播機制(broadcast)。只要mat1與other滿足broadcast條件,就可可以進行逐元素相乘 。
tensor1 = 2*torch.ones(1,4) tensor2 = 3*torch.ones(4,1) print(torch.mul(tensor1, tensor2)) #輸出結果為: tensor([[6., 6., 6., 6.], [6., 6., 6., 6.], [6., 6., 6., 6.], [6., 6., 6., 6.]])
# 生成指定張量 c = torch.Tensor([[1, 2, 3], [4, 5 ,6]]) print(c.shape) # 2*3 print(c) # 生成隨機張量 d = torch.randn(2,2,3) print(d) print(d.shape) # 2*2*3 mul = torch.mul(c, d) # c會自動broadcast和d進行匹配 print(mul.shape) # 2*2*3 print(mul)
二、torch.mm
該函數(shù)一般只能用來計算兩個二維矩陣的矩陣乘法,而且不支持broadcast操作。該函數(shù)源碼定義為torch.mm(input,mat2,out=None) ,參數(shù)與返回值均為tensor形式。
a=torch.ones(4,3) b=2*torch.ones(3,2) c=torch.empty(4,2) torch.mm(a,b,out=c) print(torch.mm(a,b)) print( c ) #輸出結果為 tensor([[6., 6.], [6., 6.], [6., 6.], [6., 6.]]) tensor([[6., 6.], [6., 6.], [6., 6.], [6., 6.]])
三、torch.matmul
這個矩陣乘法是在torch.mm的基礎上增加了廣播機制,源碼定義為torch.matmul(input,other,out=None)。
其基本運算規(guī)則如下:
如果兩個參數(shù)都為一維,則等價于torch.mul,需要注意的是:此時的out不接受任何參數(shù)
如果兩個張量都為二維且符合矩陣相乘規(guī)則,或第一個參數(shù)為一維(長度為m,這里等價為大小為1* m),第二個參數(shù)為二維(大小為m* n)則運算等價于torch.mm
如果第一個參數(shù)為二維(大小m* n),第二個參數(shù)為一維(長度為n),這里第二個參數(shù)會進行轉置成為n* 1的列向量,隨后進行矩陣相乘,將得到的結果再進行轉置,最終返回一個大小為1* m的向量
tensor1 = torch.tensor([[1,1,1,1],[2,2,2,2],[3,3,3,3]],dtype=torch.float32) tensor2 = torch.ones(4) print(tensor1.size()) print(tensor2.size()) print(torch.matmul(tensor1, tensor2).shape) #輸出結果為: torch.Size([3, 4]) torch.Size([4]) torch.Size([3])
還有一種情況就是任意一個參數(shù)至少為3維, 當前面的維度相同且最后兩個維度符合二維矩陣運算規(guī)則可進行計算,例如第一參數(shù)的大小為a* b * c * m,第二個參數(shù)的大小為a* b* m* d,則返回一個大小為a* b* c * d的張量,可觸發(fā)廣播機制。
tensor1 = torch.ones(1,4,3,2) tensor2 = torch.ones(2,6) print(torch.matmul(tensor1, tensor2).size()) #輸出結果為: torch.Size([1, 4, 3, 6])
四、三維帶Batch矩陣乘法 torch.bmm()
torch.bmm(bmat1,bmat2), 其中bmat1(B×n×m),bmat2(B×m×d)輸出out的維度是B×n×d,該函數(shù)兩個輸入必須三維矩陣中的第一維要要相同,不支持broadCast操作。
五、torch中tensor數(shù)組的廣播計算
首先定義兩個張量,x的形狀是[1,2,1],y的形狀是[1,2,2]。
當x與y相乘時,由于x.size(2)不等于y.size(2),x會被擴展為[1,2,2]形狀,然后再與張量y進行乘法運算。
x = torch.rand(1,2,1) y = torch.rand(1,2,2)
到此這篇關于pytorch中部分矩陣乘法和數(shù)組乘法的小結的文章就介紹到這了,更多相關pytorch 矩陣乘法和數(shù)組乘法內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python利用PsUtil實現(xiàn)實時監(jiān)控系統(tǒng)狀態(tài)
PSUtil是一個跨平臺的Python庫,用于檢索有關正在運行的進程和系統(tǒng)利用率(CPU,內存,磁盤,網(wǎng)絡,傳感器)的信息。本文就來用PsUtil實現(xiàn)實時監(jiān)控系統(tǒng)狀態(tài),感興趣的可以跟隨小編一起學習一下2023-04-04配置python的編程環(huán)境之Anaconda + VSCode的教程
這篇文章主要介紹了配置python的編程環(huán)境之Anaconda + VSCode的教程,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下2020-03-03在RedHat系Linux上部署Python的Celery框架的教程
這篇文章主要介紹了在RedHat系Linux上部署Python的Celery框架的教程, Celery是一個并行分布框架,擁有良好的I/O性能,需要的朋友可以參考下2015-04-04Python 中獲取數(shù)組的子數(shù)組示例詳解
在 Python 中獲取一個數(shù)組的子數(shù)組時,可以使用切片操作,使用切片操作來獲取一個數(shù)組的一段連續(xù)的子數(shù)組,并且還可以使用一些方便的語法來簡化代碼,這篇文章主要介紹了如何在 Python 中獲取數(shù)組的子數(shù)組,需要的朋友可以參考下2023-05-05