Pytorch各種維度變換函數(shù)總結(jié)
介紹
本文對于PyTorch中的各種維度變換的函數(shù)進(jìn)行總結(jié),包括reshape()
、view()
、resize_()
、transpose()
、permute()
、squeeze()
、unsqeeze()
、expand()
、repeat()
函數(shù)的介紹和對比。
contiguous
區(qū)分各個維度轉(zhuǎn)換函數(shù)的前提是需要了解contiguous。在PyTorch中,contiguous指的是Tensor底層一維數(shù)組的存儲順序和其元素順序一致。
Tensor是以一維數(shù)組的形式存儲的,C/C++使用行優(yōu)先(按行展開)的方式,Python中的Tensor底層實現(xiàn)使用的是C,因此PyThon中的Tensor也是按行展開存儲的,如果其存儲順序和按行優(yōu)先展開的一維數(shù)組元素順序一致,就說這個Tensor是連續(xù)(contiguous)的。
形式化定義:
對于任意的d維張量 t,如果滿足對于所有的 i,第 i 維相鄰元素間隔=第 i + 1 維相鄰元素間隔 × 第 i + 1 維長度的乘積,則 t 是連續(xù)的:
- stride[i] 表示第 i 維相鄰元素之間間隔的位數(shù),稱為步長,可通過 stride () 方法獲得。
- size [i] 表示固定其他維度時,第 i 維的元素數(shù)量,即第 i 維的長度,通過 size () 方法獲得。
Python中的多維張量按照行優(yōu)先展開的方式存儲,訪問矩陣中下一個元素是通過偏移來實現(xiàn)的,這個偏移量稱為步長(stride),比如python中,訪問2 × 3 矩陣的同一行中的相鄰元素,物理結(jié)構(gòu)需要偏移 1 個位置,即步長為 1 ,同一列中的兩個相鄰元素則步長為 3 。
舉例說明:
>>>t = torch.arange(12).reshape(3,4) >>>t tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) >>>t.stride(),t.stride(0),t.stride(1) # 返回t兩個維度的步長,第0維的步長,第1維的步長 ((4,1),4,1) # 第0維的步長,表示沿著列的兩個相鄰元素,比如‘0'和‘4'兩個元素的步長為4 >>>t.size(1) 4 # 對于i=0,滿足stride[0]=stride[1] * size[1]=1*4=4,那么t是連續(xù)的。
PyTorch提供了兩個關(guān)于contiguous的方法:
is_contiguous()
: 判斷Tensor是否是連續(xù)的contiguous()
: 返回新的Tensor,重新開辟一塊內(nèi)存,并且是連續(xù)的
舉例說明(參考[1]):
>>>t = torch.arange(12).reshape(3,4) >>>t tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) >>>t2 = t.transpose(0,1) >>>t2 tensor([[ 0, 4, 8], [ 1, 5, 9], [ 2, 6, 10], [ 3, 7, 11]]) >>>t.data_ptr() == t2.data_ptr() # 返回兩個張量的首元素的內(nèi)存地址 True #說明底層數(shù)據(jù)是同一個一維數(shù)組 >>>t.is_contiguous(),t2.is_contiguous() # t連續(xù),t2不連續(xù) (True, False)
可以看到,t和t2共享內(nèi)存中的數(shù)據(jù)。如果對t2使用contiguous()
方法,會開辟新的內(nèi)存空間:
>>>t3 = t2.contiguous() >>>t3 tensor([[ 0, 4, 8], [ 1, 5, 9], [ 2, 6, 10], [ 3, 7, 11]]) >>>t3.data_ptr() == t2.data_ptr() # 底層數(shù)據(jù)不是同一個一維數(shù)組 False >>>t3.is_contiguous() True
關(guān)于contiguous的更深入的解釋可以參考[1].
view()/reshape()
view()
tensor.view()函數(shù)返回一個和tensor共享底層數(shù)據(jù),但不同形狀的tensor。使用view()
函數(shù)的要求是tensor必須是contiguous的。
用法如下:
>>>t tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) >>>t2 = t.view(2,6) >>>t2 tensor([[ 0, 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10, 11]]) >>>t.data_ptr() == t2.data_ptr() # 二者的底層數(shù)據(jù)是同一個一維數(shù)組 True
reshape()
tensor.reshape()類似于tensor.contigous().view()
操作,如果tensor是連續(xù)的,則reshape()操作和view()相同,返回指定形狀、共享底層數(shù)據(jù)的tensor;如果tensor是不連續(xù)的,則會開辟新的內(nèi)存空間,返回指定形狀的tensor,底層數(shù)據(jù)和原來的tensor是獨立的,相當(dāng)于先執(zhí)行contigous()
,再執(zhí)行view()
。
如果不在意底層數(shù)據(jù)是否使用新的內(nèi)存,建議使用
reshape()
代替view()
.
resize_()
tensor.resize_()函數(shù),返回指定形狀的tensor,與reshape()
和view()
不同的是,resize_()
可以只截取tensor一部分?jǐn)?shù)據(jù),或者是元素個數(shù)大于原tensor也可以,會自動擴(kuò)展新的位置。
resize_()
函數(shù)對于tensor的連續(xù)性無要求,且返回的值是共享的底層數(shù)據(jù)(同view()
),也就是說只返回了指定形狀的索引,底層數(shù)據(jù)不變的。
transpose()/permute()
permute()
和transpose()
還有t()
是PyTorch中的轉(zhuǎn)置函數(shù),其中t()
函數(shù)只適用于2維矩陣的轉(zhuǎn)置,是這三個函數(shù)里面最”弱”的。
transpose()
tensor.transpose(),返回tensor的指定維度的轉(zhuǎn)置,底層數(shù)據(jù)共享,與view()/reshape()
不同的是,transpose()
只能實現(xiàn)維度上的轉(zhuǎn)置,不能任意改變維度大小。
對于維度交換來說,view()/reshape()
和transpose()
有很大的區(qū)別,一定不要混用!混用了以后雖然不會報錯,但是數(shù)據(jù)是亂的,血坑。
reshape()/view()
和transpose()
的區(qū)別在于對于維度改變的方式不同,前者是在存儲順序的基礎(chǔ)上對維度進(jìn)行劃分,也就是說將存儲的一維數(shù)組根據(jù)shape大小重新劃分,而transpose()
則是真正意義上的轉(zhuǎn)置,比如二維矩陣的轉(zhuǎn)置。
舉個例子:
>>>t tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) >>> t.transpose(0,1) # 交換t的前兩個維度,即對t進(jìn)行轉(zhuǎn)置。 tensor([[ 0, 4, 8], [ 1, 5, 9], [ 2, 6, 10], [ 3, 7, 11]]) >>> a.reshape(4,3) # 使用reshape()/view()的方法,雖然形狀一樣,但是數(shù)據(jù)排列完全不同 tensor([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]])
permute()
tensor.permute()函數(shù),以view的形式返回矩陣指定維度的轉(zhuǎn)置,和transpose()
功能相同。
與transpose()
不同的是,permute()
同時對多個維度進(jìn)行轉(zhuǎn)置,且參數(shù)是期望的維度的順序,而transpose()
只能同時對兩個維度轉(zhuǎn)置,即參數(shù)只能是兩個,這兩個參數(shù)沒有順序,只代表了哪兩個維度進(jìn)行轉(zhuǎn)置。
舉個例子:
>>> t # t的形狀為(2,3,2) tensor([[[ 0, 1], [ 2, 3], [ 4, 5]], [[ 6, 7], [ 8, 9], [10, 11]]]) >>> t.transpose(0,1) # 使用transpose()將前兩個維度進(jìn)行轉(zhuǎn)置,返回(3,2,2) tensor([[[ 0, 1], [ 6, 7]], [[ 2, 3], [ 8, 9]], [[ 4, 5], [10, 11]]]) >>> t.permute(1,0,2) # 使用permute()按照指定的維度序列對t轉(zhuǎn)置,返回(3,2,2) tensor([[[ 0, 1], [ 6, 7]], [[ 2, 3], [ 8, 9]], [[ 4, 5], [10, 11]]])
squeeze()/unsqueeze()
squeeze()
tensor.squeeze()返回去除size為1的維度的tensor,默認(rèn)去除所有size=1的維度,也可以指定去除某一個size=1的維度,并返回去除后的結(jié)果。
舉個例子:
>>> t.shape torch.Size([3, 1, 4, 1]) >>> t.squeeze().shape # 去除所有size=1的維度 torch.Size([3, 4]) >>> t.squeeze(1).shape # 去除第1維 torch.Size([3, 4, 1]) >>> t.squeeze(0).shape # 如果指定的維度size不等于1,則不執(zhí)行任何操作。 torch.Size([3, 1, 4, 1])
unsqueeze()
tensor.unsqueeze()與squeeze()
相反,是在tensor插入新的維度,插入的維度size=1,用于維度擴(kuò)展。
舉個例子:
>>> t.shape torch.Size([3, 1, 4, 1]) >>> t.unsqueeze(1).shape # 在指定的位置上插入新的維度,size=1 torch.Size([3, 1, 1, 4, 1]) >>> t.unsqueeze(-1).shape # 參數(shù)為-1時表示在最后一維添加新的維度,size=1 torch.Size([3, 1, 4, 1, 1]) >>> t.unsqueeze(4).shape # 和dim=-1等價 torch.Size([3, 1, 4, 1, 1])
expand()/repeat()
expand()
tensor.expand()的功能是擴(kuò)展tensor中的size為1的維度,且只能擴(kuò)展size=1的維度。以view的形式返回tensor,即不改變原來的tensor,只是以視圖的形式返回數(shù)據(jù)。
舉個例子:
>>> t tensor([[[0, 1, 2], [3, 4, 5]]]) >>> t.shape torch.Size([1, 2, 3]) >>> t.expand(3,2,3) # 將第0維擴(kuò)展為3,可見其將第0維復(fù)制了3次 tensor([[[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]]]) >>> t.expand(3,-1,-1) # dim=-1表示固定這個維度,效果是一樣的,這樣寫更方便 tensor([[[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]], [[0, 1, 2], [3, 4, 5]]]) >>> t.expand(3,2,3).storage() # expand不擴(kuò)展新的內(nèi)存空間 0 1 2 3 4 5 [torch.LongStorage of size 6]
repeat()
tensor.repeat()用于維度復(fù)制,可以將size為任意大小的維度復(fù)制為n倍,和expand()
不同的是,repeat()
會分配新的存儲空間,是真正的復(fù)制數(shù)據(jù)。
舉個例子:
>>> t tensor([[0, 1, 2], [3, 4, 5]]) >>> t.shape torch.Size([2, 3]) >>> t.repeat(2,3) # 將兩個維度分別復(fù)制2、3倍 tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2], [3, 4, 5, 3, 4, 5, 3, 4, 5], [0, 1, 2, 0, 1, 2, 0, 1, 2], [3, 4, 5, 3, 4, 5, 3, 4, 5]]) >>> t.repeat(2,3).storage() # repeat()是真正的復(fù)制,會分配新的空間 0 1 2 0 1 2 0 1 2 3 4 5 ...... 3 4 5 [torch.LongStorage of size 36]
如果維度size=1的時候,repeat()
和expand()
的作用是一樣的,但是expand()
不會分配新的內(nèi)存,所以優(yōu)先使用expand()
函數(shù)。
總結(jié)
view()/reshape()
兩個函數(shù)用于將tensor變換為任意形狀,本質(zhì)是將所有的元素重新分配。t()/transpose()/permute()
用于維度的轉(zhuǎn)置,轉(zhuǎn)置和reshape()
操作是有區(qū)別的,注意區(qū)分。squeeze()/unsqueeze()
用于壓縮/擴(kuò)展維度,僅在維度的個數(shù)上去除/添加,且去除/添加的維度size=1。expand()/repeat()
用于數(shù)據(jù)的復(fù)制,對一個或多個維度上的數(shù)據(jù)進(jìn)行復(fù)制。- 以上提到的函數(shù)僅有兩種會分配新的內(nèi)存空間:
reshape()
操作處理非連續(xù)的tensor時,返回tensor的copy數(shù)據(jù)會分配新的內(nèi)存;repeat()
操作會分配新的內(nèi)存空間。其余的操作都是返回的視圖,底層數(shù)據(jù)是共享的,僅在索引上重新分配。
Reference
2. stackoverflow-pytorch-contiguous
3. PyTorch官方文檔
到此這篇關(guān)于Pytorch各種維度變換函數(shù)總結(jié)的文章就介紹到這了,更多相關(guān)Pytorch 維度變換內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python flask框架請求體數(shù)據(jù)、文件上傳、請求頭信息獲取方式詳解
這篇文章主要介紹了Python flask框架請求體數(shù)據(jù)、文件上傳、請求頭信息獲取方式詳解,本文通過實例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友參考下吧2024-03-03python list格式數(shù)據(jù)excel導(dǎo)出方法
今天小編就為大家分享一篇python list格式數(shù)據(jù)excel導(dǎo)出方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10python調(diào)用matplotlib模塊繪制柱狀圖
這篇文章主要為大家介紹了python調(diào)用matplotlib模塊繪制柱狀圖,文中示例代碼介紹的非常詳細(xì),具有一定的參考價值,感興趣的小伙伴們可以參考一下2019-10-10Python Tornado 實現(xiàn)SSE服務(wù)端主動推送方案
SSE是Server-Sent Events 的簡稱,是一種服務(wù)器端到客戶端(瀏覽器)的單項消息推送,本文主要探索兩個方面的實踐一個是客戶端發(fā)送請求,服務(wù)端的返回是分多次進(jìn)行傳輸?shù)?直到傳輸完成,這種情況下請求結(jié)束后,考慮關(guān)閉SSE,所以這種連接可以認(rèn)為是暫時的,感興趣的朋友一起看看吧2024-01-01matlab、python中矩陣的互相導(dǎo)入導(dǎo)出方式
這篇文章主要介紹了matlab、python中矩陣的互相導(dǎo)入導(dǎo)出方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-06-06新手該如何學(xué)python怎么學(xué)好python?
怎么學(xué)好python?怎么靈活應(yīng)用python?2008-10-10Python實現(xiàn)監(jiān)控屏幕界面內(nèi)容變化并發(fā)送通知
這篇文章主要為大家詳細(xì)介紹了如何利用Python實現(xiàn)實時監(jiān)控屏幕上的信息是否發(fā)生變化并發(fā)送通知,文中的示例代碼講解詳細(xì),感興趣的可以了解一下2023-04-04Python列表元組字典集合存儲結(jié)構(gòu)詳解
本文詳細(xì)介紹了Python中列表、元組、字典和集合等數(shù)據(jù)結(jié)構(gòu)的定義、操作和用法,包括數(shù)據(jù)類型的相互嵌套、常用操作方法、循環(huán)遍歷等2025-02-02