pytorch中的reshape()、view()、nn.flatten()和flatten()使用
在使用pytorch定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)時(shí),經(jīng)常會(huì)看到類似如下的.view() / flatten()用法,這里對(duì)其用法做出講解與演示。
torch.reshape用法
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()調(diào)用,其作用是在不改變tensor元素?cái)?shù)目的情況下改變tensor的shape。
torch.reshape() 需要兩個(gè)參數(shù),一個(gè)是待被改變的張量tensor,一個(gè)是想要改變的形狀。
torch.reshape(input, shape) → Tensor
- input(Tensor)-要重塑的張量
- shape(python的元組:ints)-新形狀`
案例1.
輸入:
import torch a = torch.tensor([[0,1],[2,3]]) x = torch.reshape(a,(-1,)) print (x) b = torch.arange(4.) Y = torch.reshape(a,(2,2)) print(Y)
結(jié)果:
tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])
torch.view用法
view()的原理很簡單,其實(shí)就是把原先tensor中的數(shù)據(jù)進(jìn)行排列,排成一行,然后根據(jù)所給的view()中的參數(shù)從一行中按順序選擇組成最終的tensor。
view()可以有多個(gè)參數(shù),這取決于你想要得到的是幾維的tensor,一般設(shè)置兩個(gè)參數(shù),也是神經(jīng)網(wǎng)絡(luò)中常用的(一般在全連接之前),代表二維。
view(h,w),h代表行(想要變?yōu)閹仔校?dāng)不知道要變?yōu)閹仔校酪優(yōu)閹琢袝r(shí)可取-1;w代表的是列(想要變?yōu)閹琢校?,?dāng)不知道要變?yōu)閹琢?,但知道要變?yōu)閹仔袝r(shí)可取-1。
一、普通用法(手動(dòng)調(diào)整)
view()相當(dāng)于reshape、resize,重新調(diào)整Tensor的形狀。
案例2.
輸入
import torch a1 = torch.arange(0,16) print(a1)
輸出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
輸入
a2 = a1.view(8, 2) a3 = a1.view(2, 8) a4 = a1.view(4, 4) print(a2) print(a3) print(a4)
輸出
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
二、特殊用法:參數(shù)-1(自動(dòng)調(diào)整size)
view中一個(gè)參數(shù)定為-1,代表自動(dòng)調(diào)整這個(gè)維度上的元素個(gè)數(shù),以保證元素的總數(shù)不變。
輸入
import torch a1 = torch.arange(0,16) print(a1)
輸出
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
輸入
a2 = a1.view(-1, 16) a3 = a1.view(-1, 8) a4 = a1.view(-1, 4) a5 = a1.view(-1, 2) a6 = a1.view(4*4, -1) a7 = a1.view(1*4, -1) a8 = a1.view(2*4, -1) print(a2) print(a3) print(a4) print(a5) print(a6) print(a7) print(a8)
輸出
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
torch.nn.Flatten(start_dim=1,end_dim=-1)
start_dim與end_dim分別表示開始的維度和終止的維度,默認(rèn)值為1和-1,其中1表示第一維度,-1表示最后的維度。結(jié)合起來看意思就是從第一維度到最后一個(gè)維度全部給展平為張量。(注意:數(shù)據(jù)的維度是從0開始的,也就是存在第0維度,第一維度并不是真正意義上的第一個(gè))。
因?yàn)槠浔挥迷谏窠?jīng)網(wǎng)絡(luò)中,輸入為一批數(shù)據(jù),第 0 維為batch(輸入數(shù)據(jù)的個(gè)數(shù)),通常要把一個(gè)數(shù)據(jù)拉成一維,而不是將一批數(shù)據(jù)拉為一維。所以torch.nn.Flatten()默認(rèn)從第一維開始平坦化。
使用nn.Flatten(),使用默認(rèn)參數(shù)
官方給出的示例:
input = torch.randn(32, 1, 5, 5) # With default parameters m = nn.Flatten() output = m(input) output.size() #torch.Size([32, 25]) # With non-default parameters m = nn.Flatten(0, 2) output = m(input) output.size() #torch.Size([160, 5])
#開頭的代碼是注釋
整段代碼的意思是:給定一個(gè)維度為(32,1,5,5)的隨機(jī)數(shù)據(jù)。
1.先使用一次nn.Flatten(),使用默認(rèn)參數(shù):
m = nn.Flatten()
也就是說從第一維度展平到最后一個(gè)維度,數(shù)據(jù)的維度是從0開始的,第一維度實(shí)際上是數(shù)據(jù)的第二位置代表的維度,也就是樣例中的1。
因此進(jìn)行展平后的結(jié)果也就是[32,155]→[32,25]
2.接著再使用一次指定參數(shù)的nn.Flatten(),即
m = nn.Flatten(0,2)
也就是說從第0維度展平到第2維度,0~2,對(duì)應(yīng)的也就是前三個(gè)維度。
因此結(jié)果就是[3215,5]→[160,25]
torch.flatten
torch.flatten()函數(shù)經(jīng)常用于寫分類神經(jīng)網(wǎng)絡(luò)的時(shí)候,經(jīng)過最后一個(gè)卷積層之后,一般會(huì)再接一個(gè)自適應(yīng)的池化層,輸出一個(gè)BCHW的向量。
這時(shí)候就需要用到torch.flatten()函數(shù)將這個(gè)向量拉平成一個(gè)Bx的向量(其中,x = CHW),然后送入到FC層中。
語句結(jié)構(gòu)
torch.flatten(input, start_dim=0, end_dim=-1)
input: 一個(gè) tensor,即要被“攤平”的 tensor。
- start_dim: “攤平”的起始維度。
- end_dim: “攤平”的結(jié)束維度。
作用與 torch.nn.flatten 類似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是類,其默認(rèn)開始維度為第 0 維。
例1:
import torch data_pool = torch.randn(2,2,3,3) # 模擬經(jīng)過最后一個(gè)池化層或自適應(yīng)池化層之后的輸出,Batchsize*c*h*w print(data_pool) y=torch.flatten(data_pool,1) print(y)
輸出結(jié)果:
結(jié)果是一個(gè)B*x的向量。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)定時(shí)自動(dòng)關(guān)閉的tkinter窗口方法
今天小編就為大家分享一篇Python實(shí)現(xiàn)定時(shí)自動(dòng)關(guān)閉的tkinter窗口方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-02-02python網(wǎng)頁請求urllib2模塊簡單封裝代碼
這篇文章主要分享一個(gè)python網(wǎng)頁請求模塊urllib2模塊的簡單封裝代碼,有需要的朋友參考下2014-02-02圖解Python中淺拷貝copy()和深拷貝deepcopy()的區(qū)別
這篇文章主要介紹了Python中淺拷貝copy()和深拷貝deepcopy()的區(qū)別,淺拷貝和深拷貝想必大家在學(xué)習(xí)中遇到很多次,這也是面試中常常被問到的問題,本文就帶你詳細(xì)了解一下2023-05-05python常用函數(shù)random()函數(shù)詳解
這篇文章主要介紹了python常用函數(shù)random()函數(shù),本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2023-02-02Python深度學(xué)習(xí)pytorch神經(jīng)網(wǎng)絡(luò)塊的網(wǎng)絡(luò)之VGG
雖然AlexNet證明深層神經(jīng)網(wǎng)絡(luò)卓有成效,但它沒有提供一個(gè)通用的模板來指導(dǎo)后續(xù)的研究人員設(shè)計(jì)新的網(wǎng)絡(luò)。下面,我們將介紹一些常用于設(shè)計(jì)深層神經(jīng)網(wǎng)絡(luò)的啟發(fā)式概念2021-10-10Python內(nèi)存管理器如何實(shí)現(xiàn)池化技術(shù)
Python中的內(nèi)存管理是從三個(gè)方面來進(jìn)行的,一對(duì)象的引用計(jì)數(shù)機(jī)制,二垃圾回收機(jī)制,三內(nèi)存池機(jī)制,下面這篇文章主要給大家介紹了關(guān)于Python內(nèi)存管理器如何實(shí)現(xiàn)池化技術(shù)的相關(guān)資料,需要的朋友可以參考下2022-05-05Python使用eval函數(shù)執(zhí)行動(dòng)態(tài)標(biāo)表達(dá)式過程詳解
這篇文章主要介紹了Python使用eval函數(shù)執(zhí)行動(dòng)態(tài)標(biāo)表達(dá)式過程詳解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-10-10