pytorch下的unsqueeze和squeeze的用法說(shuō)明
#squeeze 函數(shù):從數(shù)組的形狀中刪除單維度條目,即把shape中為1的維度去掉
#unsqueeze() 是squeeze()的反向操作,增加一個(gè)維度,該維度維數(shù)為1,可以指定添加的維度。例如unsqueeze(a,1)表示在1這個(gè)維度進(jìn)行添加
import torch a=torch.rand(2,3,1) print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) print(a.size()) #torch.Size([2, 3, 1]) print(a.squeeze().size()) #torch.Size([2, 3]) print(a.squeeze(0).size()) #torch.Size([2, 3, 1]) print(a.squeeze(-1).size()) #torch.Size([2, 3]) print(a.size()) #torch.Size([2, 3, 1]) print(a.squeeze(-2).size()) #torch.Size([2, 3, 1]) print(a.squeeze(-3).size()) #torch.Size([2, 3, 1]) print(a.squeeze(1).size()) #torch.Size([2, 3, 1]) print(a.squeeze(2).size()) #torch.Size([2, 3]) print(a.squeeze(3).size()) #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3) print(a.unsqueeze().size()) #TypeError: unsqueeze() missing 1 required positional arguments: "dim" print(a.unsqueeze(-3).size()) #torch.Size([2, 1, 3, 1]) print(a.unsqueeze(-2).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(-1).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(0).size()) #torch.Size([1, 2, 3, 1]) print(a.unsqueeze(1).size()) #torch.Size([2, 1, 3, 1]) print(a.unsqueeze(2).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(3).size()) #torch.Size([2, 3, 1, 1]) print(torch.unsqueeze(a,3)) b=torch.rand(2,1,3,1) print(b.squeeze().size()) #torch.Size([2, 3])
補(bǔ)充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函數(shù)的總結(jié)
學(xué)習(xí)Bert模型的時(shí)候,需要使用到pytorch來(lái)進(jìn)行tensor的操作,由于對(duì)pytorch和tensor不熟悉,就把pytorch中常用的、有關(guān)tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函數(shù)做一個(gè)總結(jié),加深記憶。
1、unsqueeze()和squeeze()
torch.unsqueeze(input, dim,out=None) → Tensor
unsqueeze()的作用是用來(lái)增加給定tensor的維度的,unsqueeze(dim)就是在維度序號(hào)為dim的地方給tensor增加一維。例如:維度為torch.Size([768])的tensor要怎樣才能變?yōu)閠orch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代碼:
a=torch.randn(768) print(a.shape) # torch.Size([768]) a=a.unsqueeze(0) print(a.shape) #torch.Size([1, 768]) a = a.unsqueeze(2) print(a.shape) #torch.Size([1, 768, 1])
也可以直接使用鏈?zhǔn)骄幊蹋?/p>
a=torch.randn(768) print(a.shape) # torch.Size([768]) a=a.unsqueeze(1).unsqueeze(0) print(a.shape) #torch.Size([1, 768, 1])
tensor經(jīng)過(guò)unsqueeze()處理之后,總數(shù)據(jù)量不變;維度的擴(kuò)展類似于list不變直接在外面加幾層[]括號(hào)。
torch.squeeze(input, dim=None, out=None) → Tensor
squeeze()的作用就是壓縮維度,直接把維度為1的維給去掉。形式上表現(xiàn)為,去掉一層[]括號(hào)。
同時(shí),輸出的張量與原張量共享內(nèi)存,如果改變其中的一個(gè),另一個(gè)也會(huì)改變。
a=torch.randn(2,1,768) print(a) print(a.shape) #torch.Size([2, 1, 768]) a=a.squeeze() print(a) print(a.shape) #torch.Size([2, 768])
圖片中的維度信息就不一樣,紅框中的括號(hào)層數(shù)不同。
注意的是:squeeze()只能壓縮維度為1的維;其他大小的維不起作用。
a=torch.randn(2,768) print(a.shape) #torch.Size([2, 768]) a=a.squeeze() print(a.shape) #torch.Size([2, 768])
2、expand()
這個(gè)函數(shù)的作用就是對(duì)指定的維度進(jìn)行數(shù)值大小的改變。只能改變維大小為1的維,否則就會(huì)報(bào)錯(cuò)。不改變的維可以傳入-1或者原來(lái)的數(shù)值。
torch.Tensor.expand(*sizes) → Tensor
返回張量的一個(gè)新視圖,可以將張量的單個(gè)維度擴(kuò)大為更大的尺寸。
a=torch.randn(1,1,3,768) print(a) print(a.shape) #torch.Size([1, 1, 3, 768]) b=a.expand(2,-1,-1,-1) print(b) print(b.shape) #torch.Size([2, 1, 3, 768]) c=a.expand(2,1,3,768) print(c.shape) #torch.Size([2, 1, 3, 768])
可以看到b和c的維度是一樣的
第0維由1變?yōu)?,可以看到就直接把原來(lái)的tensor在該維度上復(fù)制了一下。
3、repeat()
repeat(*sizes)
沿著指定的維度,對(duì)原來(lái)的tensor進(jìn)行數(shù)據(jù)復(fù)制。這個(gè)函數(shù)和expand()還是有點(diǎn)區(qū)別的。expand()只能對(duì)維度為1的維進(jìn)行擴(kuò)大,而repeat()對(duì)所有的維度可以隨意操作。
a=torch.randn(2,1,768) print(a) print(a.shape) #torch.Size([2, 1, 768]) b=a.repeat(1,2,1) print(b) print(b.shape) #torch.Size([2, 2, 768]) c=a.repeat(3,3,3) print(c) print(c.shape) #torch.Size([6, 3, 2304])
b表示對(duì)a的對(duì)應(yīng)維度進(jìn)行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])
c表示對(duì)a的對(duì)應(yīng)維度進(jìn)行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])
a:
b
c
4、view()
tensor.view()這個(gè)函數(shù)有點(diǎn)類似reshape的功能,簡(jiǎn)單的理解就是:先把一個(gè)tensor轉(zhuǎn)換成一個(gè)一維的tensor,然后再組合成指定維度的tensor。例如:
word_embedding=torch.randn(16,3,768) print(word_embedding.shape) new_word_embedding=word_embedding.view(8,6,768) print(new_word_embedding.shape)
當(dāng)然這里指定的維度的乘積一定要和原來(lái)的tensor的維度乘積相等,不然會(huì)報(bào)錯(cuò)的。16*3*768=8*6*768
另外當(dāng)我們需要改變一個(gè)tensor的維度的時(shí)候,知道關(guān)鍵的維度,有不想手動(dòng)的去計(jì)算其他的維度值,就可以使用view(-1),pytorch就會(huì)自動(dòng)幫你計(jì)算出來(lái)。
word_embedding=torch.randn(16,3,768) print(word_embedding.shape) new_word_embedding=word_embedding.view(-1) print(new_word_embedding.shape) new_word_embedding=word_embedding.view(1,-1) print(new_word_embedding.shape) new_word_embedding=word_embedding.view(-1,768) print(new_word_embedding.shape)
結(jié)果如下:使用-1以后,就會(huì)自動(dòng)得到其他維度維。
需要特別注意的是:view(-1,-1)這樣的用法就會(huì)出錯(cuò)。也就是說(shuō)view()函數(shù)中只能出現(xiàn)單個(gè)-1。
5、cat()
cat(seq,dim,out=None),表示把兩個(gè)或者多個(gè)tensor拼接起來(lái)。
其中 seq表示要連接的兩個(gè)序列,以元組的形式給出,例如:seq=(a,b), a,b 為兩個(gè)可以連接的序列
dim 表示以哪個(gè)維度連接,dim=0, 橫向連接 dim=1,縱向連接
a=torch.randn(4,3) b=torch.randn(4,3) c=torch.cat((a,b),dim=0)#橫向拼接,增加行 torch.Size([8, 3]) print(c.shape) d=torch.cat((a,b),dim=1)#縱向拼接,增加列 torch.Size([4, 6]) print(d.shape)
還有一種寫法:cat(list,dim,out=None),其中l(wèi)ist中的元素為tensor。
tensors=[] for i in range(10): tensors.append(torch.randn(4,3)) a=torch.cat(tensors,dim=0) #torch.Size([40, 3]) print(a.shape) b=torch.cat(tensors,dim=1) #torch.Size([4, 30]) print(b.shape)
結(jié)果:
torch.Size([40, 3]) torch.Size([4, 30])
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。
相關(guān)文章
基于python實(shí)現(xiàn)查詢ip地址來(lái)源
這篇文章主要介紹了基于python實(shí)現(xiàn)查詢ip地址來(lái)源,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06widows下安裝pycurl并利用pycurl請(qǐng)求https地址的方法
今天小編就為大家分享一篇widows下安裝pycurl并利用pycurl請(qǐng)求https地址的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-10-10Python3遠(yuǎn)程監(jiān)控程序的實(shí)現(xiàn)方法
今天小編就為大家分享一篇Python3遠(yuǎn)程監(jiān)控程序的實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-07-07PyCharm安裝配置Qt Designer+PyUIC圖文教程
這篇文章主要介紹了PyCharm安裝配置Qt Designer+PyUIC圖文教程,本文通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-05-05一文詳解Python中Reduce函數(shù)輕松解決復(fù)雜數(shù)據(jù)聚合
這篇文章主要為大家介紹了Python中Reduce函數(shù)輕松解決復(fù)雜數(shù)據(jù)聚合示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-08-08推薦技術(shù)人員一款Python開(kāi)源庫(kù)(造數(shù)據(jù)神器)
今天小編給大家推薦一款Python開(kāi)源庫(kù),技術(shù)人必備的造數(shù)據(jù)神器!非常不錯(cuò),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2020-07-07詳解字符串在Python內(nèi)部是如何省內(nèi)存的
這篇文章主要介紹了詳解字符串在Python內(nèi)部是如何省內(nèi)存的,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-02-02C#中使用XPath定位HTML中的img標(biāo)簽的操作示例
隨著互聯(lián)網(wǎng)內(nèi)容的日益豐富,網(wǎng)頁(yè)數(shù)據(jù)的自動(dòng)化處理變得愈發(fā)重要,圖片作為網(wǎng)頁(yè)中的重要組成部分,其獲取和處理在許多應(yīng)用場(chǎng)景中都顯得至關(guān)重要,本文將詳細(xì)介紹如何在 C# 應(yīng)用程序中使用 XPath 定位 HTML 中的 img 標(biāo)簽,并實(shí)現(xiàn)圖片的下載,需要的朋友可以參考下2024-07-07Python+Opencv實(shí)現(xiàn)數(shù)字識(shí)別的示例代碼
這篇文章主要介紹了Python+Opencv實(shí)現(xiàn)數(shù)字識(shí)別的示例代碼,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-03-03