pytorch中常用的乘法運算及相關的運算符(@和*)
前言
這里總結一下pytorch常用的乘法運算以及相關的運算符(@、*)。
總結放前面:
torch.mm : 用于兩個矩陣(不包括向量)的乘法。如維度為(l,m)和(m,n)相乘
torch.bmm : 用于帶batch的三維向量的乘法。如維度為(b,l,m)和(b,m,n)相乘
torch.mul : 用于兩個同維度矩陣的逐像素點相乘(點乘)。如維度為(l,m)和(l,m)相乘
torch.mv : 用于矩陣和向量之間的乘法(矩陣在前,向量在后)。如維度為(l,m)和(m)相乘,結果的維度為(l)。
torch.matmul : 用于兩個張量(后兩維滿足矩陣乘法的維度)相乘或者是矩陣與向量間的乘法,因為其具有廣播機制(broadcasting,自動補充維度)。如維度為(b,l,m)和(b,m,n);(l,m)和(b,m,n);(b,c,l,m)和(b,c,m,n);(l,m)和(m)相乘等?!酒渥饔冒瑃orch.mm、torch.bmm和torch.mv】
@運算符 : 其作用類似于torch.matmul。
*運算符 : 其作用類似于torch.mul。
1、torch.mm
import torch
a = torch.ones(1, 2)
print(a)
b = torch.ones(2, 3)
print(b)
output = torch.mm(a, b)
print(output)
print(output.size())
"""
tensor([[1., 1.]])
tensor([[1., 1., 1.],
[1., 1., 1.]])
tensor([[2., 2., 2.]])
torch.Size([1, 3])
"""
2、torch.bmm
a = torch.randn(2, 1, 2)
print(a)
b = torch.randn(2, 2, 3)
print(b)
output = torch.bmm(a, b)
print(output)
print(output.size())
"""
tensor([[[-0.1187, 0.2110]],
[[ 0.7463, -0.6136]]])
tensor([[[-0.1186, 1.5565, 1.3662],
[ 1.0199, 2.4644, 1.1630]],
[[-1.9483, -1.6258, -0.4654],
[-0.1424, 1.3892, 0.7559]]])
tensor([[[ 0.2293, 0.3352, 0.0832]],
[[-1.3666, -2.0657, -0.8111]]])
torch.Size([2, 1, 3])
"""
3、torch.mul
a = torch.ones(2, 3) * 2
print(a)
b = torch.randn(2, 3)
print(b)
output = torch.mul(a, b)
print(output)
print(output.size())
"""
tensor([[2., 2., 2.],
[2., 2., 2.]])
tensor([[-0.1187, 0.2110, 0.7463],
[-0.6136, -0.1186, 1.5565]])
tensor([[-0.2375, 0.4220, 1.4925],
[-1.2271, -0.2371, 3.1130]])
torch.Size([2, 3])
"""
4、torch.mv
mat = torch.randn(3, 4)
print(mat)
vec = torch.randn(4)
print(vec)
output = torch.mv(mat, vec)
print(output)
print(output.size())
print(torch.mm(mat, vec.unsqueeze(1)).squeeze(1))
"""
tensor([[-0.1187, 0.2110, 0.7463, -0.6136],
[-0.1186, 1.5565, 1.3662, 1.0199],
[ 2.4644, 1.1630, -1.9483, -1.6258]])
tensor([-0.4654, -0.1424, 1.3892, 0.7559])
tensor([ 0.5982, 2.5024, -5.2481])
torch.Size([3])
tensor([ 0.5982, 2.5024, -5.2481])
"""
5、torch.matmul
# 其作用包含torch.mm、torch.bmm和torch.mv。其他類似,不一一舉例。
a = torch.randn(2, 1, 2)
print(a)
b = torch.randn(2, 2, 3)
print(b)
output = torch.bmm(a, b)
print(output)
output1 = torch.matmul(a, b)
print(output1)
print(output1.size())
"""
tensor([[[-0.1187, 0.2110]],
[[ 0.7463, -0.6136]]])
tensor([[[-0.1186, 1.5565, 1.3662],
[ 1.0199, 2.4644, 1.1630]],
[[-1.9483, -1.6258, -0.4654],
[-0.1424, 1.3892, 0.7559]]])
tensor([[[ 0.2293, 0.3352, 0.0832]],
[[-1.3666, -2.0657, -0.8111]]])
tensor([[[ 0.2293, 0.3352, 0.0832]],
[[-1.3666, -2.0657, -0.8111]]])
torch.Size([2, 1, 3])
"""
# 維度為(b,l,m)和(b,m,n);(l,m)和(b,m,n);(b,c,l,m)和(b,c,m,n);(l,m)和(m)等 a = torch.randn(2, 3, 4) b = torch.randn(2, 4, 5) print(torch.matmul(a, b).size()) a = torch.randn(3, 4) b = torch.randn(2, 4, 5) print(torch.matmul(a, b).size()) a = torch.randn(2, 3, 3, 4) b = torch.randn(2, 3, 4, 5) print(torch.matmul(a, b).size()) a = torch.randn(2, 3) b = torch.randn(3) print(torch.matmul(a, b).size()) """ torch.Size([2, 3, 5]) torch.Size([2, 3, 5]) torch.Size([2, 3, 3, 5]) torch.Size([2]) """
6、@運算符
# @運算符:其作用類似于torch.matmul a = torch.randn(2, 3, 4) b = torch.randn(2, 4, 5) print(torch.matmul(a, b).size()) print((a @ b).size()) a = torch.randn(3, 4) b = torch.randn(2, 4, 5) print(torch.matmul(a, b).size()) print((a @ b).size()) a = torch.randn(2, 3, 3, 4) b = torch.randn(2, 3, 4, 5) print(torch.matmul(a, b).size()) print((a @ b).size()) a = torch.randn(2, 3) b = torch.randn(3) print(torch.matmul(a, b).size()) print((a @ b).size()) """ torch.Size([2, 3, 5]) torch.Size([2, 3, 5]) torch.Size([2, 3, 5]) torch.Size([2, 3, 5]) torch.Size([2, 3, 3, 5]) torch.Size([2, 3, 3, 5]) torch.Size([2]) torch.Size([2]) """
7、*運算符
# *運算符:其作用類似于torch.mul
a = torch.ones(2, 3) * 2
print(a)
b = torch.ones(2, 3) * 3
print(b)
output = torch.mul(a, b)
print(output)
print(output.size())
output1 = a * b
print(output1)
print(output1.size())
"""
tensor([[2., 2., 2.],
[2., 2., 2.]])
tensor([[3., 3., 3.],
[3., 3., 3.]])
tensor([[6., 6., 6.],
[6., 6., 6.]])
torch.Size([2, 3])
tensor([[6., 6., 6.],
[6., 6., 6.]])
torch.Size([2, 3])
"""
附:二維矩陣乘法
神經網絡中包含大量的 2D 張量矩陣乘法運算,而使用 torch.matmul 函數(shù)比較復雜,因此 PyTorch 提供了更為簡單方便的 torch.mm(input, other, out = None) 函數(shù)。下表是 torch.matmul 函數(shù)和 torch.mm 函數(shù)的簡單對比。

torch.matmul 函數(shù)支持廣播,主要指的是當參與矩陣乘積運算的兩個張量中其中有一個是 1D 張量,torch.matmul 函數(shù)會將其廣播成 2D 張量參與運算,最后將廣播添加的維度刪除作為最終 torch.matmul 函數(shù)的返回結果。torch.mm 函數(shù)不支持廣播,相對應的輸入的兩個張量必須為 2D。
import torch input = torch.tensor([[1., 2.], [3., 4.]]) other = torch.tensor([[5., 6., 7.], [8., 9., 10.]]) result = torch.mm(input, other) print(result) # tensor([[21., 24., 27.], # [47., 54., 61.]])
總結
到此這篇關于pytorch中常用的乘法運算及相關的運算符(@和*)的文章就介紹到這了,更多相關pytorch常用乘法運算及運算符內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
Python實現(xiàn)抓取網頁生成Excel文件的方法示例
這篇文章主要介紹了Python實現(xiàn)抓取網頁生成Excel文件的方法,涉及PyQuery模塊的使用及Excel文件相關操作技巧,需要的朋友可以參考下2017-08-08
Python中的def __init__( )函數(shù)
這篇文章主要介紹了Python中的def __init__( )函數(shù),文章圍繞主題展開詳細的內容介紹,具有一定的參考價值,需要的朋友可以參考一下2022-09-09

