PyTorch張量操作指南(cat、stack、split與chunk)
在深度學(xué)習(xí)實踐中,張量的維度變換是數(shù)據(jù)處理和模型構(gòu)建的基礎(chǔ)技能。無論是多模態(tài)數(shù)據(jù)的融合(如圖像與文本),還是批處理數(shù)據(jù)的拆分重組,合理運用張量操作函數(shù)可顯著優(yōu)化計算流程。PyTorch提供的cat、stack、split和chunk正是解決此類問題的利器。以下將逐一解析其原理與應(yīng)用。
一、torch.cat: 沿指定維度拼接張量
功能描述
torch.cat
(concatenate)沿已有的某一維度連接多個形狀兼容的張量,生成更高維度的單一張量。要求除拼接維度外,其余維度的大小必須完全一致。
示例代碼
import torch a = torch.tensor([[1, 2], [3, 4]]) # 形狀 (2, 2) b = torch.tensor([[5, 6], [7, 8]]) # 在第0維拼接(垂直方向) c = torch.cat([a, b], dim=0) print(c) # 輸出: # tensor([[1, 2], # [3, 4], # [5, 6], # [7, 8]]) # 在第1維拼接(水平方向) d = torch.cat([a, b], dim=1) print(d) # 輸出: # tensor([[1, 2, 5, 6], # [3, 4, 7, 8]])
二、torch.stack: 創(chuàng)建新維度堆疊張量
功能描述
torch.stack
會將輸入張量沿新創(chuàng)建的維度進行堆疊,所有參與堆疊的張量必須具有完全相同的形狀。輸出張量的維度比原張量多一維。
示例代碼
a = torch.tensor([1, 2, 3]) b = torch.tensor([4, 5, 6]) # 沿第0維堆疊,生成二維張量 c = torch.stack([a, b], dim=0) print(c.shape) # torch.Size([2, 3]) print(c) # 輸出: # tensor([[1, 2, 3], # [4, 5, 6]]) # 沿第1維堆疊,生成二維張量 d = torch.stack([a, b], dim=1) print(d.shape) # torch.Size([3, 2]) print(d) # 輸出: # tensor([[1, 4], # [2, 5], # [3, 6]])
三、torch.split: 按尺寸分割張量
功能描述
torch.split
根據(jù)指定的尺寸將輸入張量分割為多個子張量。支持兩種參數(shù)形式:
- 整數(shù)列表:每個元素表示對應(yīng)分片的長度
- 整數(shù)N:等分為N個子張量(需總長度可被整除)
示例代碼
a = torch.arange(9) # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8]) # 按列表尺寸分割 [2,3,4] parts = torch.split(a, [2, 3, 4], dim=0) for part in parts: print(part) ''' 輸出: tensor([0, 1]) tensor([2, 3, 4]) tensor([5, 6, 7, 8]) ''' # 平均分割為3份 chunks = torch.split(a, 3, dim=0) print([c.shape for c in chunks]) # [torch.Size([3]), torch.Size([3]), torch.Size([3])]
四、torch.chunk: 按數(shù)量均分張量
功能描述
torch.chunk
將輸入張量沿指定維度均勻劃分為N份。若無法整除,剩余元素分配到前面的分片中。
示例代碼
a = torch.arange(10) # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) # 分成3份,默認在第0維操作 chunks = torch.chunk(a, chunks=3, dim=0) for i, chunk in enumerate(chunks): print(f"Chunk {i}: {chunk}") ''' 輸出: Chunk 0: tensor([0, 1, 2, 3]) Chunk 1: tensor([4, 5, 6]) Chunk 2: tensor([7, 8, 9]) ''' # 在第1維分割二維張量 b = a.reshape(2,5) chunks = torch.chunk(b, chunks=2, dim=1) print(chunks[0].shape) # torch.Size([2, 2]) print(chunks[1].shape) # torch.Size([2, 3])
綜合示例:圖像數(shù)據(jù)的分割與合并處理
以下是結(jié)合圖像數(shù)據(jù)的完整操作示例,模擬圖像預(yù)處理流程中的張量操作場景:
場景設(shè)定
假設(shè)我們有一批RGB圖像數(shù)據(jù)(尺寸為 3×256×256
),需要完成以下操作:
- 將圖像拆分為RGB三個通道
- 對每個通道進行獨立歸一化
- 合并處理后的通道
- 將多張圖像堆疊成批次
- 分割批次為訓(xùn)練/驗證集
代碼實現(xiàn)
import torch from torchvision import transforms from PIL import Image import matplotlib.pyplot as plt # 1. 加載示例圖像 (H, W, C) -> 轉(zhuǎn)換為 (C, H, W) image = Image.open('cat.jpg').convert('RGB') image = transforms.ToTensor()(image) # shape: torch.Size([3, 256, 256]) # 2. 使用split分離RGB通道 r_channel, g_channel, b_channel = torch.split(image, split_size_or_sections=1, dim=0) ''' 可視化原始通道 plt.figure(figsize=(12,4)) plt.subplot(131), plt.imshow(r_channel.squeeze().numpy(), cmap='Reds'), plt.title('Red') plt.subplot(132), plt.imshow(g_channel.squeeze().numpy(), cmap='Greens'), plt.title('Green') plt.subplot(133), plt.imshow(b_channel.squeeze().numpy(), cmap='Blues'), plt.title('Blue') plt.show() ''' # 3. 對每個通道進行歸一化(示例操作) def normalize(tensor): return (tensor - tensor.mean()) / tensor.std() r_norm = normalize(r_channel) g_norm = normalize(g_channel) b_norm = normalize(b_channel) # 4. 使用cat合并處理后的通道 normalized_img = torch.cat([r_norm, g_norm, b_norm], dim=0) '''觀察歸一化效果 plt.imshow(normalized_img.permute(1,2,0)) plt.title('Normalized Image') plt.show() ''' # 5. 創(chuàng)建模擬圖像批次 (假設(shè)有4張相同圖像) batch_images = torch.stack([image]*4, dim=0) # shape: (4, 3, 256, 256) # 6. 使用chunk分割批次為訓(xùn)練集/驗證集 train_set, val_set = torch.chunk(batch_images, chunks=2, dim=0) print(f"Train set size: {train_set.shape}") # torch.Size([2, 3, 256, 256]) print(f"Val set size: {val_set.shape}") # torch.Size([2, 3, 256, 256])
關(guān)鍵操作解析
步驟 | 函數(shù) | 作用 | 維度變化 |
---|---|---|---|
通道分離 | torch.split | 提取單獨顏色通道 | (3,256,256)→3個(1,256,256) |
數(shù)據(jù)合并 | torch.cat | 合并處理后的通道數(shù)據(jù) | 3個(1,256,256)→(3,256,256) |
批次構(gòu)建 | torch.stack | 將單張圖像復(fù)制為4張圖像的批次 | (3,256,256)→(4,3,256,256) |
批次劃分 | torch.chunk | 將批次按比例劃分為訓(xùn)練/驗證集 | (4,3,256,256)→2×(2,3,256,256) |
擴展應(yīng)用建議
- 數(shù)據(jù)增強:對split后的通道進行不同變換(如僅對R通道做對比度調(diào)整)
- 模型輸入:stack后的批次可直接輸入CNN網(wǎng)絡(luò)
- 分布式訓(xùn)練:利用chunk將數(shù)據(jù)分布到多個GPU處理
- 特征可視化:通過split提取中間層特征圖的單個通道進行分析
通過這個完整的圖像處理流程示例,可以清晰看到:
split
+cat
組合常用于特征處理管道stack
+chunk
組合是構(gòu)建批處理系統(tǒng)的關(guān)鍵工具- 這些操作在保持計算效率的同時提供了靈活的數(shù)據(jù)控制能力
總結(jié)與對比
函數(shù) | 核心作用 | 維度變化 | 輸入要求 |
---|---|---|---|
torch.cat | 沿現(xiàn)有維度拼接 | 不變 | 各張量形狀需匹配 |
torch.stack | 新建維度堆疊 | +1維 | 所有張量形狀完全相同 |
torch.split | 按尺寸分割 | 不變 | 需指定分割尺寸或份數(shù) |
torch.chunk | 按數(shù)量均分 | 不變 | 總長度需可分配 |
應(yīng)用建議:
- 當(dāng)需要合并同類數(shù)據(jù)且保留原始維度時用
cat
; - 若需擴展維度以表示批次或通道時用
stack
; - 對序列數(shù)據(jù)分段處理優(yōu)先考慮
split
; - 均勻劃分特征圖或張量時選擇
chunk
。
掌握這些工具后,您將能更靈活地操控張量維度,適應(yīng)復(fù)雜模型的構(gòu)建需求!
到此這篇關(guān)于PyTorch張量操作指南(cat、stack、split與chunk)的文章就介紹到這了,更多相關(guān)PyTorch張量操作內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
PyCharm如何設(shè)置Console控制臺輸出自動換行
這篇文章主要介紹了PyCharm如何設(shè)置Console控制臺輸出自動換行問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-05-05如何用Python對數(shù)學(xué)函數(shù)進行求值、求偏導(dǎo)
這篇文章主要介紹了如何用Python對數(shù)學(xué)函數(shù)進行求值、求偏導(dǎo)問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-05-05Python實戰(zhàn)之基于OpenCV的美顏掛件制作
在本文中,我們將學(xué)習(xí)如何創(chuàng)建有趣的基于Snapchat的增強現(xiàn)實,主要包括兩個實戰(zhàn)項目:在檢測到的人臉上的鼻子和嘴巴之間添加胡子掛件,在檢測到的人臉上添加眼鏡掛件。感興趣的童鞋可以看看哦2021-11-11Django+simpleui實現(xiàn)文件上傳預(yù)覽功能(詳細過程)
該文章詳細介紹了如何在Django框架中實現(xiàn)文件上傳、預(yù)覽和下載功能,并使用SimpleUI美化Django后臺界面,通過創(chuàng)建模型、表單、視圖和配置URL,實現(xiàn)了文件的存儲和管理,同時,文章還提到了配置媒體文件、創(chuàng)建模板以及在生產(chǎn)環(huán)境中的部署注意事項,感興趣的朋友一起看看吧2025-02-02簡單分析Python中用fork()函數(shù)生成的子進程
這篇文章主要介紹了Python中用fork()函數(shù)生成的子進程,分析子進程與父進程的執(zhí)行順序,需要的朋友可以參考下2015-05-05