Pytorch數(shù)據(jù)類型Tensor張量操作的實(shí)現(xiàn)
本文只簡(jiǎn)單介紹pytorch中的對(duì)于張量的各種操作,主要列舉介紹其大致用法和簡(jiǎn)單demo。后續(xù)更為詳細(xì)的介紹會(huì)進(jìn)行補(bǔ)充…
一.創(chuàng)建張量的方式
1.創(chuàng)建無(wú)初始化張量
- torch.empty(3, 4) 創(chuàng)建未初始化內(nèi)存的張量
2.創(chuàng)建隨機(jī)張量
- x = torch.rand(3, 4) 服從0~1間均勻分布
- x = torch.randn(3, 4) 服從(0,1)的正態(tài)分布
- x = torch.rand_like(y) 以rand方式隨機(jī)創(chuàng)建一個(gè)和y形狀相同的張量
- x = torch.randint(1, 10, [3, 3]) 創(chuàng)建元素介于[1,10)的形狀為(3,3)的隨機(jī)張量
3.創(chuàng)建初值為指定數(shù)值的張量
- x = torch.zeros(3, 4) 生成形狀為(3,4)的初值全為0的張量
- x = torch.full([3, 4], 6) 生成形狀為(3,4)的初值全為6的張量
- x = torch.eye(5, 5) 生成形狀為(5,5)的單位陣
4.從數(shù)據(jù)創(chuàng)建張量
- x = torch.tensor([1, 2, 3, 4, 5, 6]) 接收數(shù)據(jù)
- torch.Tensor(3, 4) 接收tensor的維度
5.生成等差數(shù)列張量
- x = torch.arange(0, 10) 生成[0,10)公差為1的等差數(shù)列張量
- x = torch.arange(0, 10, 3) 生成[0,10)公差為3的等差數(shù)列張量
二.改變張量形狀
view()與reshape()方法功能用法完全一致
通過(guò)傳入改變后每一個(gè)維度的大小來(lái)重塑張量的形狀:
x = x.view(2, 3) x = x.reshape(2, 3)
view和reshape操作的示例:
a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) b = a.reshape(2, 4) c = a.view(2, 4) print(b) print(c)
三.索引
y = x.index_select(0, torch.tensor([0, 2]))
第一個(gè)參數(shù)表示選擇的維度,第二個(gè)參數(shù)以tensor的形式傳入,選擇該維度中的指定索引index
x = torch.tensor([ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], ]) y = x.index_select(0, torch.tensor([0, 2])) print(y) y = x.index_select(1, torch.tensor([1, 2])) print(y)
根據(jù)掩碼獲得打平后的指定索引張量:
mask = x.ge(5) y = torch.masked_select(x, mask)
通過(guò)比較運(yùn)算獲得一個(gè)mask索引,然后將mask索引傳入masked_select方法來(lái)獲得打平后的新張量,具體示例如下:
x = torch.tensor([ [1, 2, 0, 2], [3, 6, 1, 9], [-1, 7, -8, 1], ]) mask = x.ge(5) y = torch.masked_select(x, mask) print(y) mask = x.gt(0) y = torch.masked_select(x, mask) print(y) mask = x.lt(1) y = torch.masked_select(x, mask) print(y)
四.維度變換
1.維度增加unsqueeze
unsqueeze操作可以讓張量在指定非負(fù)維度前插入新的維度,在負(fù)維度后插入新的維度,傳入?yún)?shù)n表示指定的維度,即n若大于等于0則在n前插入新的維度,若n小于0則在n后插入新的維度:
x.unsqueeze(n)
假設(shè)原張量x的shape為(4,3,28,28),使用x.unsqueeze(0) 在0維度前插入新的維度后,張量的shape變?yōu)?1,4,3,28,28)。原張量y的shape為(2),使用y.unsqueeze(1)在維度1前插入新的維度后,張量的shape變?yōu)?2,1)。代碼示例如下:
x = torch.randint(1, 10, [4, 3, 28, 28]) print(f"original shape: {x.shape}") x = x.unsqueeze(0) print(f"unsqueezed in dim 0: {x.shape}") print("----------------------------------") y = torch.tensor([3, 4]) print(f"original shape: {y.shape}") m = y.unsqueeze(1) print(f"unsqueezed in dim 1: {m.shape}\n{m}") n = y.unsqueeze(0) print(f"unsqueezed in dim 0: {n.shape}\n{n}")
運(yùn)行結(jié)果:
2.維度擴(kuò)展expand
x.expand(a, b, c, d) 操作將原來(lái)維度擴(kuò)展為(a,b ,c ,d),傳入n個(gè)參數(shù)a,b,c,d…表示維度擴(kuò)展后的形狀,其中當(dāng)傳入的維度上的參數(shù)為-1時(shí)表示該維度保持不變。
x.expand(a, b, c, d)
使用expand只能擴(kuò)張?jiān)瓉?lái)大小為1的維度,該維度擴(kuò)張為n后的張量將在該維度上將數(shù)據(jù)復(fù)制n次,將原shape為(1,3,1)的張量擴(kuò)展為shape為(2,3,4)的張量:
x = torch.randint(0, 2, [1, 3, 1]) y = x.expand(2, 3, 4) print(f"original tensor in dim(1,3,1):\n{x}") print(f"expanded tensor in dim(2,3,4):\n{y}")
運(yùn)行結(jié)果:
3.維度減少squeeze
x.squeeze()操作可以壓縮張量的維度,當(dāng)不傳入任何參數(shù)時(shí),squeeze()操作壓縮所有可以壓縮的維度,當(dāng)傳入指定參數(shù)時(shí),參數(shù)可以是負(fù)數(shù),將壓縮張量的指定維度。
x.squeeze() x.squeeze(n)
x = torch.tensor([1, 2, 3, 4, 5, 6]) y = x.unsqueeze(1).unsqueeze(2).unsqueeze(0) print(f"original shape : {y.shape}") print(f"squeezed in all dim: {y.squeeze().shape}") print(f"squeezed in dim 0: {y.squeeze(0).shape}") print(f"squeezed in dim 1: {y.squeeze(1).shape}")
運(yùn)行結(jié)果:
4.維度擴(kuò)展repeat
x.repeat(a,b,c,d) 在原來(lái)維度上分別拷貝a,b,c,d次
x.repeat(a, b, c, d)
原張量x的shape為(1,2,1),通過(guò)執(zhí)行repeat(2,1,2)操作后shape變?yōu)?2,2,2),再通過(guò)repeat(1,3,5)操作后shape變?yōu)?2,6,10):
x = torch.tensor([1, 2]).reshape(1, 2, 1) y = x.repeat(2, 1, 2) z = y.repeat(1, 3, 5) print(f"original tensor in dim(1,2,1): \n{x}") print(f"repeated tensor in dim(2,2,2): \n{y}") print(f"repeated tensor in dim(2,6,10): \n{z}")
五.維度交換
1.簡(jiǎn)單的二維轉(zhuǎn)置函數(shù)t:
x.t()
2.交換任意兩個(gè)維度transpose
x = torch.randint(1, 10, [2, 4, 3])y = x.transpose(0, 2)print(f"original tensor in shape(2,4,3):\n{<!--{cke_protected}{C}%3C!%2D%2D%20%2D%2D%3E-->x}")print(f"transposed tensor in shape(3,4,2):\n{<!--{cke_protected}{C}%3C!%2D%2D%20%2D%2D%3E-->y}")x = torch.randint(1, 10, [2, 4, 3]) y = x.transpose(0, 2) print(f"original tensor in shape(2,4,3):\n{x}") print(f"transposed tensor in shape(3,4,2):\n{y}")
3.重新排列原來(lái)的維度順序permute
permute操作用于重新排列維度順序,傳入的參數(shù)代表維度的索引,即dim a,dim b…
x.permute(a, b, c, d)
x.permute(1,2,0)的意義是將原來(lái)的1維度放到0維度的位置,將原來(lái)的2維度放到1維度的位置,將原來(lái)的0維度放到2維度的位置,以此重新排列維度順序:
x = torch.tensor([ [ [1, 2, 3, 1], [4, 5, 3, 6], [1, 1, 0, 1] ], [ [7, 8, 9, 1], [0, 2, 0, 3], [6, 5, 1, 8], ] ]) y = x.permute(1, 2, 0) print(f"original shape: {x.shape}") print(f"permuted shape: {y.shape}") print(f"permuted tensor:\n{y}")
六.張量合并
1.cat操作
代碼示例:
torch.cat([a,b], dim=0)
cat()函數(shù)中首先傳入一個(gè)列表[a, b, c…]表示要合并的張量集合,然后傳入一個(gè)維度dim=n,表示將這些張量在維度n上進(jìn)行合并操作。
注意concat操作合并的維度上兩個(gè)張量的維度大小可以不同,但是其余維度上必須具有相同的大小,例如(3,4,5)可以和(2,4,5)在0維度上concat合并為(5,4,5)。但是不能在1維度上合并,因?yàn)?維度上兩個(gè)張量的維度大小不同,分別為3和2。
x = torch.tensor([ [ [1, 2, 3, 1], [4, 5, 3, 6], [1, 1, 0, 1] ], [ [7, 8, 9, 1], [0, 2, 0, 3], [6, 5, 1, 8], ] ]) y = x.permute(1, 2, 0) print(f"original shape: {x.shape}") print(f"permuted shape: {y.shape}") print(f"permuted tensor:\n{y}")
運(yùn)行結(jié)果:
2.stack操作
stack操作在合并維度處創(chuàng)建一個(gè)新的維度。
代碼示例:
torch.stack([a, b], dim=0)
tensorA = torch.tensor([ [1, 2, 3], [4, 5, 6] ]) tensorB = torch.tensor([ [7, 8, 9], [3, 2, 1] ]) print(f"tensorA.shape:{tensorA.shape}") print(f"tensorB.shape:{tensorB.shape}") print("try to stack A with B in dim0:") tensorC = torch.stack([tensorA, tensorB], dim=0) print(f"tensorC.shape:{tensorC.shape}\n{tensorC}\n--------------------------") print("try to stack A with B in dim1:") tensorC = torch.stack([tensorA, tensorB], dim=1) print(f"tensorC.shape:{tensorC.shape}\n{tensorC}\n--------------------------") print("try to stack A with B in dim2:") tensorC = torch.stack([tensorA, tensorB], dim=2) print(f"tensorC.shape:{tensorC.shape}\n{tensorC}\n--------------------------") print("try to stack A with B in dim3:") tensorC = torch.stack([tensorA, tensorB], dim=3) print(f"tensorC.shape:{tensorC.shape}") print(tensorC)
運(yùn)行結(jié)果:
七.張量的分割
1.split操作
split操作是對(duì)張量在指定維度上將張量進(jìn)行分割,可以按給定長(zhǎng)度等分,也可以通過(guò)列表傳入分割方法。下面兩種分割方式結(jié)果是相同的,第一種方式是將張量x在維度0上按照每一份長(zhǎng)度為1進(jìn)行等分;第二種方式是按照長(zhǎng)度[1, 1, 1]的模式將張量x分成三份。
a, b, c = x.split(1, dim=0) a, b, c = x.split([1, 1, 1], dim=0)
x = torch.tensor([ [ [1, 2, 1, 3], [0, 1, 2, 1], [9, 8, 1, 2] ], [ [1, 2, 1, 2], [4, 2, 4, 4], [1, 0, 0, 0] ], [ [3, 3, 3, 1], [1, 0, 2, 3], [5, 1, 2, 5] ] ]) print(x.shape) a, b, c = x.split(1, dim=0) print(f"a.shape:{a.shape}\nb.shape:{b.shape}\nc.shape:{c.shape}") print("------------------------------------") a, b = x.split([1, 2], dim=0) print(f"a.shape:{a.shape}\nb.shape:{b.shape}")
2.chunk操作
chunk操作是對(duì)張量的某一維度按數(shù)量進(jìn)行分割,首先傳入第一個(gè)參數(shù)代表要分割成的份數(shù),第二個(gè)參數(shù)指定了在哪一個(gè)維度上分割,下面的API樣例代表將張量在維度0上分割為3個(gè)張:
a, b, c = x.chunk(3, dim=0)
對(duì)上例split中的張量x用chunk做分割的示例如下:
a, b, c = x.chunk(3, dim=1) print(a.shape) print(b.shape) print(c.shape) print("---------------------") a, b = x.chunk(2, dim=2) print(a.shape) print(b.shape)
到此這篇關(guān)于Pytorch數(shù)據(jù)類型Tensor張量操作的實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)Pytorch數(shù)據(jù)類型Tensor張量操作的實(shí)現(xiàn)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python防止程序超時(shí)的實(shí)現(xiàn)示例
因?yàn)槟硞€(gè)需求,需要在程序運(yùn)行的時(shí)候防止超時(shí),本文主要介紹了python防止程序超時(shí)的實(shí)現(xiàn)示例,具有一定的參考價(jià)值,感興趣的可以了解一下2023-08-08講解Python中的標(biāo)識(shí)運(yùn)算符
這篇文章主要介紹了講解Python中的標(biāo)識(shí)運(yùn)算符,是Python學(xué)習(xí)當(dāng)中的基礎(chǔ)知識(shí),需要的朋友可以參考下2015-05-05最大K個(gè)數(shù)問(wèn)題的Python版解法總結(jié)
這篇文章主要介紹了最大K個(gè)數(shù)問(wèn)題的Python版解法總結(jié),以最大K個(gè)數(shù)問(wèn)題為基礎(chǔ)的算法題目在面試和各大考試及競(jìng)賽中經(jīng)常出現(xiàn),需要的朋友可以參考下2016-06-06Python tkinter庫(kù)實(shí)現(xiàn)登錄注冊(cè)基本功能
Python自帶了tkinter模塊,實(shí)質(zhì)上是一種流行的面向?qū)ο蟮腉UI工具包 TK 的Python編程接口,提供了快速便利地創(chuàng)建GUI應(yīng)用程序的方法,下面這篇文章主要給大家介紹了關(guān)于tkinter庫(kù)制作一個(gè)簡(jiǎn)單的登錄注冊(cè)小程序,需要的朋友可以參考下2022-12-12詳解安裝mitmproxy以及遇到的坑和簡(jiǎn)單用法
mitmproxy 是一款工具,也可以說(shuō)是 python 的一個(gè)包,在命令行操作的工具。這篇文章主要介紹了詳解安裝mitmproxy以及遇到的坑和簡(jiǎn)單用法,感興趣的小伙伴們可以參考一下2019-01-01