欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch各種維度變換函數(shù)總結(jié)

 更新時間:2024年02月21日 11:37:42   作者:kangshitao  
本文對于PyTorch中的各種維度變換的函數(shù)進(jìn)行總結(jié),包括reshape()、view()、resize_()、transpose()、permute()、squeeze()、unsqeeze()、expand()、repeat()函數(shù)的介紹和對比,感興趣的可以了解一下

介紹

本文對于PyTorch中的各種維度變換的函數(shù)進(jìn)行總結(jié),包括reshape()、view()、resize_()transpose()、permute()squeeze()、unsqeeze()expand()、repeat()函數(shù)的介紹和對比。

contiguous

區(qū)分各個維度轉(zhuǎn)換函數(shù)的前提是需要了解contiguous。在PyTorch中,contiguous指的是Tensor底層一維數(shù)組的存儲順序和其元素順序一致。

Tensor是以一維數(shù)組的形式存儲的,C/C++使用行優(yōu)先(按行展開)的方式,Python中的Tensor底層實現(xiàn)使用的是C,因此PyThon中的Tensor也是按行展開存儲的,如果其存儲順序按行優(yōu)先展開的一維數(shù)組元素順序一致,就說這個Tensor是連續(xù)(contiguous)的。

形式化定義:

對于任意的d維張量 t,如果滿足對于所有的 i,第 i 維相鄰元素間隔=第 i + 1 維相鄰元素間隔 × 第 i + 1 維長度的乘積,則 t 是連續(xù)的:

  • stride[i] 表示第 i 維相鄰元素之間間隔的位數(shù),稱為步長,可通過 stride () 方法獲得。
  • size  [i] 表示固定其他維度時,第 i 維的元素數(shù)量,即第 i 維的長度,通過 size () 方法獲得。

Python中的多維張量按照行優(yōu)先展開的方式存儲,訪問矩陣中下一個元素是通過偏移來實現(xiàn)的,這個偏移量稱為步長(stride),比如python中,訪問2 × 3 矩陣的同一行中的相鄰元素,物理結(jié)構(gòu)需要偏移 1 個位置,即步長為 1 ,同一列中的兩個相鄰元素則步長為 3 。

舉例說明:

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>>t.stride(),t.stride(0),t.stride(1) # 返回t兩個維度的步長,第0維的步長,第1維的步長
((4,1),4,1)
# 第0維的步長,表示沿著列的兩個相鄰元素,比如‘0'和‘4'兩個元素的步長為4
>>>t.size(1)
4
# 對于i=0,滿足stride[0]=stride[1] * size[1]=1*4=4,那么t是連續(xù)的。

PyTorch提供了兩個關(guān)于contiguous的方法:

  • is_contiguous() : 判斷Tensor是否是連續(xù)的
  • contiguous() : 返回新的Tensor,重新開辟一塊內(nèi)存,并且是連續(xù)的

舉例說明(參考[1]):

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0,  4,  8],
        [ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11]])
>>>t.data_ptr() == t2.data_ptr()  # 返回兩個張量的首元素的內(nèi)存地址
True    	#說明底層數(shù)據(jù)是同一個一維數(shù)組
>>>t.is_contiguous(),t2.is_contiguous()  # t連續(xù),t2不連續(xù)
(True, False)

可以看到,t和t2共享內(nèi)存中的數(shù)據(jù)。如果對t2使用contiguous()方法,會開辟新的內(nèi)存空間:

>>>t3 = t2.contiguous()
>>>t3
tensor([[ 0,  4,  8],
        [ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11]])
>>>t3.data_ptr() == t2.data_ptr() # 底層數(shù)據(jù)不是同一個一維數(shù)組
False
>>>t3.is_contiguous()
True

關(guān)于contiguous的更深入的解釋可以參考[1].

view()/reshape()

view()

tensor.view()函數(shù)返回一個和tensor共享底層數(shù)據(jù),但不同形狀的tensor。使用view()函數(shù)的要求是tensor必須是contiguous的。

用法如下:

>>>t
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>>t2 = t.view(2,6)
>>>t2
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
>>>t.data_ptr() == t2.data_ptr()	# 二者的底層數(shù)據(jù)是同一個一維數(shù)組
True

reshape()

tensor.reshape()類似于tensor.contigous().view()操作,如果tensor是連續(xù)的,則reshape()操作和view()相同,返回指定形狀、共享底層數(shù)據(jù)的tensor;如果tensor是不連續(xù)的,則會開辟新的內(nèi)存空間,返回指定形狀的tensor,底層數(shù)據(jù)和原來的tensor是獨立的,相當(dāng)于先執(zhí)行contigous(),再執(zhí)行view()

如果不在意底層數(shù)據(jù)是否使用新的內(nèi)存,建議使用reshape()代替view().

resize_()

tensor.resize_()函數(shù),返回指定形狀的tensor,與reshape()view()不同的是,resize_()可以只截取tensor一部分?jǐn)?shù)據(jù),或者是元素個數(shù)大于原tensor也可以,會自動擴(kuò)展新的位置。

resize_()函數(shù)對于tensor的連續(xù)性無要求,且返回的值是共享的底層數(shù)據(jù)(同view()),也就是說只返回了指定形狀的索引,底層數(shù)據(jù)不變的。

transpose()/permute()

permute()transpose()還有t()是PyTorch中的轉(zhuǎn)置函數(shù),其中t()函數(shù)只適用于2維矩陣的轉(zhuǎn)置,是這三個函數(shù)里面最”弱”的。

transpose()

tensor.transpose(),返回tensor的指定維度的轉(zhuǎn)置,底層數(shù)據(jù)共享,與view()/reshape()不同的是,transpose()只能實現(xiàn)維度上的轉(zhuǎn)置,不能任意改變維度大小。

對于維度交換來說,view()/reshape()transpose()有很大的區(qū)別,一定不要混用!混用了以后雖然不會報錯,但是數(shù)據(jù)是亂的,血坑。

reshape()/view()transpose()的區(qū)別在于對于維度改變的方式不同,前者是在存儲順序的基礎(chǔ)上對維度進(jìn)行劃分,也就是說將存儲的一維數(shù)組根據(jù)shape大小重新劃分,而transpose()則是真正意義上的轉(zhuǎn)置,比如二維矩陣的轉(zhuǎn)置。

舉個例子:

>>>t
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
>>> t.transpose(0,1)	# 交換t的前兩個維度,即對t進(jìn)行轉(zhuǎn)置。
tensor([[ 0,  4,  8],
        [ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11]])
>>> a.reshape(4,3)     # 使用reshape()/view()的方法,雖然形狀一樣,但是數(shù)據(jù)排列完全不同
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

permute()

tensor.permute()函數(shù),以view的形式返回矩陣指定維度的轉(zhuǎn)置,和transpose()功能相同。

transpose()不同的是,permute()同時對多個維度進(jìn)行轉(zhuǎn)置,且參數(shù)是期望的維度的順序,而transpose()只能同時對兩個維度轉(zhuǎn)置,即參數(shù)只能是兩個,這兩個參數(shù)沒有順序,只代表了哪兩個維度進(jìn)行轉(zhuǎn)置。

舉個例子:

>>> t				# t的形狀為(2,3,2)
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]]])
>>> t.transpose(0,1)   # 使用transpose()將前兩個維度進(jìn)行轉(zhuǎn)置,返回(3,2,2)
tensor([[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]])
>>> t.permute(1,0,2)   # 使用permute()按照指定的維度序列對t轉(zhuǎn)置,返回(3,2,2)
tensor([[[ 0,  1],
         [ 6,  7]],

        [[ 2,  3],
         [ 8,  9]],

        [[ 4,  5],
         [10, 11]]])

squeeze()/unsqueeze()

squeeze()

tensor.squeeze()返回去除size為1的維度的tensor,默認(rèn)去除所有size=1的維度,也可以指定去除某一個size=1的維度,并返回去除后的結(jié)果。

舉個例子:

>>> t.shape 
torch.Size([3, 1, 4, 1])
>>> t.squeeze().shape  # 去除所有size=1的維度
torch.Size([3, 4])
>>> t.squeeze(1).shape  # 去除第1維
torch.Size([3, 4, 1])
>>> t.squeeze(0).shape  # 如果指定的維度size不等于1,則不執(zhí)行任何操作。
torch.Size([3, 1, 4, 1])

unsqueeze()

tensor.unsqueeze()squeeze()相反,是在tensor插入新的維度,插入的維度size=1,用于維度擴(kuò)展。

舉個例子:

>>> t.shape
torch.Size([3, 1, 4, 1])
>>> t.unsqueeze(1).shape   # 在指定的位置上插入新的維度,size=1
torch.Size([3, 1, 1, 4, 1]) 
>>> t.unsqueeze(-1).shape  # 參數(shù)為-1時表示在最后一維添加新的維度,size=1
torch.Size([3, 1, 4, 1, 1])
>>> t.unsqueeze(4).shape   # 和dim=-1等價
torch.Size([3, 1, 4, 1, 1])

expand()/repeat()

expand()

tensor.expand()的功能是擴(kuò)展tensor中的size為1的維度,且只能擴(kuò)展size=1的維度。以view的形式返回tensor,即不改變原來的tensor,只是以視圖的形式返回數(shù)據(jù)。

舉個例子:

>>> t
tensor([[[0, 1, 2],
         [3, 4, 5]]])
>>> t.shape
torch.Size([1, 2, 3])
>>> t.expand(3,2,3)  # 將第0維擴(kuò)展為3,可見其將第0維復(fù)制了3次
tensor([[[0, 1, 2],
         [3, 4, 5]],

        [[0, 1, 2],
         [3, 4, 5]],

        [[0, 1, 2],
         [3, 4, 5]]])
>>> t.expand(3,-1,-1) # dim=-1表示固定這個維度,效果是一樣的,這樣寫更方便
tensor([[[0, 1, 2],
         [3, 4, 5]],

        [[0, 1, 2],
         [3, 4, 5]],

        [[0, 1, 2],
         [3, 4, 5]]])
>>> t.expand(3,2,3).storage()    # expand不擴(kuò)展新的內(nèi)存空間
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]

repeat()

tensor.repeat()用于維度復(fù)制,可以將size為任意大小的維度復(fù)制為n倍,和expand()不同的是,repeat()會分配新的存儲空間,是真正的復(fù)制數(shù)據(jù)。

舉個例子:

>>> t
tensor([[0, 1, 2],
        [3, 4, 5]])
>>> t.shape
torch.Size([2, 3])
>>> t.repeat(2,3)  # 將兩個維度分別復(fù)制2、3倍
tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5, 3, 4, 5],
        [0, 1, 2, 0, 1, 2, 0, 1, 2],
        [3, 4, 5, 3, 4, 5, 3, 4, 5]])
>>> t.repeat(2,3).storage()   # repeat()是真正的復(fù)制,會分配新的空間
 0
 1
 2
 0
 1
 2
 0
 1
 2
 3
 4
 5
 ......
 3
 4
 5
[torch.LongStorage of size 36]

如果維度size=1的時候,repeat()expand()的作用是一樣的,但是expand()不會分配新的內(nèi)存,所以優(yōu)先使用expand()函數(shù)。

總結(jié)

  • view()/reshape()兩個函數(shù)用于將tensor變換為任意形狀,本質(zhì)是將所有的元素重新分配
  • t()/transpose()/permute()用于維度的轉(zhuǎn)置,轉(zhuǎn)置和reshape()操作是有區(qū)別的,注意區(qū)分。
  • squeeze()/unsqueeze()用于壓縮/擴(kuò)展維度,僅在維度的個數(shù)上去除/添加,且去除/添加的維度size=1。
  • expand()/repeat()用于數(shù)據(jù)的復(fù)制,對一個或多個維度上的數(shù)據(jù)進(jìn)行復(fù)制。
  • 以上提到的函數(shù)僅有兩種會分配新的內(nèi)存空間:reshape()操作處理非連續(xù)的tensor時,返回tensor的copy數(shù)據(jù)會分配新的內(nèi)存;repeat()操作會分配新的內(nèi)存空間。其余的操作都是返回的視圖,底層數(shù)據(jù)是共享的,僅在索引上重新分配。

Reference

1. PyTorch中的contiguous

2. stackoverflow-pytorch-contiguous

3. PyTorch官方文檔

到此這篇關(guān)于Pytorch各種維度變換函數(shù)總結(jié)的文章就介紹到這了,更多相關(guān)Pytorch 維度變換內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評論