PyTorch 中mm和bmm函數(shù)的使用示例詳解
torch.mm 是 PyTorch 中用于 二維矩陣乘法(matrix-matrix multiplication) 的函數(shù),等價于數(shù)學(xué)中的 A × B 矩陣乘積。
一、函數(shù)定義
torch.mm(input, mat2) → Tensor
執(zhí)行的是兩個 2D Tensor(矩陣)的標(biāo)準(zhǔn)矩陣乘法。
input: 第一個二維張量,形狀為(n × m)mat2: 第二個二維張量,形狀為(m × p)- 返回:形狀為
(n × p)的張量
二、使用條件和注意事項
| 條件 | 說明 |
|---|---|
| 僅支持 2D 張量 | 一維或三維以上使用 torch.matmul 或 @ 操作符 |
| 維度要匹配 | 即 input.shape[1] == mat2.shape[0] |
| 不支持廣播 | 兩個矩陣維度不匹配會直接報錯 |
| 結(jié)果是普通矩陣乘積 | 不是逐元素乘法(Hadamard),即不是 * 或 torch.mul() |
三、示例代碼
示例 1:基本矩陣乘法
import torch A = torch.tensor([[1., 2.], [3., 4.]]) # 2x2 B = torch.tensor([[5., 6.], [7., 8.]]) # 2x2 C = torch.mm(A, B) print(C)
輸出:
tensor([[19., 22.],
[43., 50.]])
計算步驟:
C[0][0] = 1*5 + 2*7 = 19 C[0][1] = 1*6 + 2*8 = 22 ...
示例 2:不匹配維度導(dǎo)致報錯
A = torch.rand(2, 3) B = torch.rand(4, 2) C = torch.mm(A, B) # ? 會報錯
報錯:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)
示例 3:推薦寫法(推薦使用 @ 或 matmul)
A = torch.rand(3, 4) B = torch.rand(4, 5) C1 = torch.mm(A, B) C2 = A @ B # 推薦用法 C3 = torch.matmul(A, B) # 推薦用法
四、與其他乘法函數(shù)的比較
| 函數(shù)名 | 支持維度 | 運算類型 | 支持廣播 |
|---|---|---|---|
torch.mm | 僅限二維 | 矩陣乘法 | ? 不支持 |
torch.matmul | 1D, 2D, ND | 自動判斷點乘 / 矩陣乘 | ? 支持 |
torch.bmm | 批量二維乘法 | 3D Tensor batch × batch | ? 不支持 |
torch.mul | 任意維度 | 元素乘(Hadamard) | ? 支持 |
* 運算符 | 任意維度 | 元素乘 | ? 支持 |
@ 運算符 | ND(推薦用) | 矩陣乘法(和 matmul 一樣) | ? |
五、典型應(yīng)用場景
- 神經(jīng)網(wǎng)絡(luò)權(quán)重乘法:
output = torch.mm(W, x) - 點云 / 圖像變換:
x' = torch.mm(R, x) + t - 多層感知機中的矩陣計算
- 注意力機制中 QK^T 乘積
六、總結(jié):什么時候用 mm?
| 使用場景 | 用什么 |
|---|---|
| 僅二維矩陣乘法 | torch.mm |
| 高維或支持廣播乘法 | torch.matmul / @ |
| 批量矩陣乘法 (如 batch_size×3×3) | torch.bmm |
| 元素乘 | torch.mul or * |
在 PyTorch 中,torch.bmm 是 批量矩陣乘法(batch matrix multiplication) 的操作,專用于處理三維張量(batch of matrices)。它的主要作用是對一組矩陣成對進(jìn)行乘法,效率遠(yuǎn)高于手動循環(huán)計算。
一、torch.bmm 語法
torch.bmm(input, mat2, *, out=None) → Tensor
- input:
Tensor,形狀為(B, N, M) - mat2:
Tensor,形狀為(B, M, P) - 返回結(jié)果形狀為
(B, N, P)
這表示對 B 對 N×M 和 M×P 的矩陣進(jìn)行成對相乘。
二、示例演示
示例 1:基礎(chǔ)用法
import torch # 定義兩個 batch 矩陣 A = torch.randn(4, 2, 3) # shape: (B=4, N=2, M=3) B = torch.randn(4, 3, 5) # shape: (B=4, M=3, P=5) # 批量矩陣乘法 C = torch.bmm(A, B) # shape: (4, 2, 5) print(C.shape) # 輸出: torch.Size([4, 2, 5])
示例 2:手動循環(huán) vs bmm 效率對比
# 慢速手動方式 C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))]) # 等效于 bmm C_bmm = torch.bmm(A, B) print(torch.allclose(C_manual, C_bmm)) # True
三、注意事項
1. 維度必須是三維張量
- 否則會報錯:
RuntimeError: batch1 must be a 3D tensor
你可以通過 .unsqueeze() 手動調(diào)整維度:
a = torch.randn(2, 3) b = torch.randn(3, 4) # 升維 a_batch = a.unsqueeze(0) # (1, 2, 3) b_batch = b.unsqueeze(0) # (1, 3, 4) c = torch.bmm(a_batch, b_batch) # (1, 2, 4)
2. 維度必須滿足矩陣乘法規(guī)則
(B, N, M)×(B, M, P)→(B, N, P)- 若
M不一致會報錯:
RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension of batch1 tensor
3. bmm 不支持廣播(broadcasting)
- 必須顯式提供相同的 batch size。
- 如果只有一個矩陣固定,可以使用
.expand():
A = torch.randn(1, 2, 3) # 單個矩陣 B = torch.randn(4, 3, 5) # 4 個矩陣 # 擴展 A 以進(jìn)行 batch 乘法 A_expand = A.expand(4, -1, -1) C = torch.bmm(A_expand, B) # (4, 2, 5)
四、在實際應(yīng)用中的例子
在點云變換中:批量乘旋轉(zhuǎn)矩陣
# 假設(shè)有 B 個旋轉(zhuǎn)矩陣和點坐標(biāo) R = torch.randn(B, 3, 3) # 旋轉(zhuǎn)矩陣 points = torch.randn(B, 3, N) # 點云 # 先轉(zhuǎn)置點坐標(biāo)為 (B, N, 3) points_T = points.transpose(1, 2) # (B, N, 3) # 用 bmm 做點變換:每組點乘旋轉(zhuǎn) transformed = torch.bmm(points_T, R.transpose(1, 2)) # (B, N, 3)
五、總結(jié)
| 特性 | torch.bmm |
|---|---|
| 操作對象 | 三維張量(batch of matrices) |
| 核心規(guī)則 | (B, N, M) x (B, M, P) = (B, N, P) |
| 是否支持廣播 | ? 不支持,需要手動 .expand() |
與 matmul 區(qū)別 | matmul 支持更多廣播,bmm 更高效用于純批量矩陣乘法 |
| 應(yīng)用場景 | 批量線性變換、點云配準(zhǔn)、神經(jīng)網(wǎng)絡(luò)前向傳播等 |
到此這篇關(guān)于PyTorch 中mm和bmm函數(shù)的使用詳解的文章就介紹到這了,更多相關(guān)PyTorch mm和bmm函數(shù)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Pycharm中切換pytorch的環(huán)境和配置的教程詳解
這篇文章主要介紹了Pycharm中切換pytorch的環(huán)境和配置,本文給大家介紹的非常詳細(xì),對大家的工作或?qū)W習(xí)具有一定的參考借鑒價值,需要的朋友可以參考下2020-03-03
Python異步編程中asyncio.gather的并發(fā)控制詳解
在Python異步編程生態(tài)中,asyncio.gather是并發(fā)任務(wù)調(diào)度的核心工具,本文將通過實際場景和代碼示例,展示如何結(jié)合信號量機制實現(xiàn)精準(zhǔn)并發(fā)控制,希望對大家有所幫助2025-03-03
pycharm遠(yuǎn)程開發(fā)項目的實現(xiàn)步驟
這篇文章主要介紹了pycharm遠(yuǎn)程開發(fā)項目的實現(xiàn)步驟,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2019-01-01
Python自動化實戰(zhàn)之接口請求的實現(xiàn)
本文為大家重點介紹如何通過 python 編碼來實現(xiàn)我們的接口測試以及通過Pycharm的實際應(yīng)用編寫一個簡單接口測試,感興趣的可以了解一下2022-05-05
Python列表切片操作實例探究(提取復(fù)制反轉(zhuǎn))
在Python中,列表切片是處理列表數(shù)據(jù)非常強大且靈活的方法,本文將全面探討Python中列表切片的多種用法,包括提取子列表、復(fù)制列表、反轉(zhuǎn)列表等操作,結(jié)合豐富的示例代碼進(jìn)行詳細(xì)講解2024-01-01

