欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中的reshape()、view()、nn.flatten()和flatten()使用

 更新時(shí)間:2023年08月02日 10:40:22   作者:夢在黎明破曉時(shí)啊  
這篇文章主要介紹了pytorch中的reshape()、view()、nn.flatten()和flatten()使用,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

在使用pytorch定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)時(shí),經(jīng)常會(huì)看到類似如下的.view() / flatten()用法,這里對(duì)其用法做出講解與演示。

torch.reshape用法

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()調(diào)用,其作用是在不改變tensor元素?cái)?shù)目的情況下改變tensor的shape。

torch.reshape() 需要兩個(gè)參數(shù),一個(gè)是待被改變的張量tensor,一個(gè)是想要改變的形狀。

torch.reshape(input, shape) → Tensor

  • input(Tensor)-要重塑的張量
  • shape(python的元組:ints)-新形狀`

案例1.

輸入:

import torch
a = torch.tensor([[0,1],[2,3]])
x = torch.reshape(a,(-1,))
print (x)
b = torch.arange(4.)
Y = torch.reshape(a,(2,2))
print(Y)

結(jié)果:

tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])

torch.view用法

view()的原理很簡單,其實(shí)就是把原先tensor中的數(shù)據(jù)進(jìn)行排列,排成一行,然后根據(jù)所給的view()中的參數(shù)從一行中按順序選擇組成最終的tensor。

view()可以有多個(gè)參數(shù),這取決于你想要得到的是幾維的tensor,一般設(shè)置兩個(gè)參數(shù),也是神經(jīng)網(wǎng)絡(luò)中常用的(一般在全連接之前),代表二維。

view(h,w),h代表行(想要變?yōu)閹仔校?dāng)不知道要變?yōu)閹仔校酪優(yōu)閹琢袝r(shí)可取-1;w代表的是列(想要變?yōu)閹琢校?,?dāng)不知道要變?yōu)閹琢?,但知道要變?yōu)閹仔袝r(shí)可取-1。

一、普通用法(手動(dòng)調(diào)整)

view()相當(dāng)于reshape、resize,重新調(diào)整Tensor的形狀。

案例2.

輸入

import torch
a1 = torch.arange(0,16)
print(a1)

輸出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

輸入

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)
print(a2)
print(a3)
print(a4)

輸出

tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])

二、特殊用法:參數(shù)-1(自動(dòng)調(diào)整size)

view中一個(gè)參數(shù)定為-1,代表自動(dòng)調(diào)整這個(gè)維度上的元素個(gè)數(shù),以保證元素的總數(shù)不變。

輸入

import torch
a1 = torch.arange(0,16)
print(a1)

輸出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

輸入

a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)
print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)

輸出

tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])

torch.nn.Flatten(start_dim=1,end_dim=-1)

start_dim與end_dim分別表示開始的維度和終止的維度,默認(rèn)值為1和-1,其中1表示第一維度,-1表示最后的維度。結(jié)合起來看意思就是從第一維度到最后一個(gè)維度全部給展平為張量。(注意:數(shù)據(jù)的維度是從0開始的,也就是存在第0維度,第一維度并不是真正意義上的第一個(gè))。

因?yàn)槠浔挥迷谏窠?jīng)網(wǎng)絡(luò)中,輸入為一批數(shù)據(jù),第 0 維為batch(輸入數(shù)據(jù)的個(gè)數(shù)),通常要把一個(gè)數(shù)據(jù)拉成一維,而不是將一批數(shù)據(jù)拉為一維。所以torch.nn.Flatten()默認(rèn)從第一維開始平坦化。

使用nn.Flatten(),使用默認(rèn)參數(shù)

官方給出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#開頭的代碼是注釋

整段代碼的意思是:給定一個(gè)維度為(32,1,5,5)的隨機(jī)數(shù)據(jù)。

1.先使用一次nn.Flatten(),使用默認(rèn)參數(shù):

m = nn.Flatten()

也就是說從第一維度展平到最后一個(gè)維度,數(shù)據(jù)的維度是從0開始的,第一維度實(shí)際上是數(shù)據(jù)的第二位置代表的維度,也就是樣例中的1。

因此進(jìn)行展平后的結(jié)果也就是[32,155]→[32,25]

2.接著再使用一次指定參數(shù)的nn.Flatten(),即

m = nn.Flatten(0,2)

也就是說從第0維度展平到第2維度,0~2,對(duì)應(yīng)的也就是前三個(gè)維度。

因此結(jié)果就是[3215,5]→[160,25]

torch.flatten

torch.flatten()函數(shù)經(jīng)常用于寫分類神經(jīng)網(wǎng)絡(luò)的時(shí)候,經(jīng)過最后一個(gè)卷積層之后,一般會(huì)再接一個(gè)自適應(yīng)的池化層,輸出一個(gè)BCHW的向量。

這時(shí)候就需要用到torch.flatten()函數(shù)將這個(gè)向量拉平成一個(gè)Bx的向量(其中,x = CHW),然后送入到FC層中。

在這里插入圖片描述

語句結(jié)構(gòu)

 torch.flatten(input, start_dim=0, end_dim=-1)

input: 一個(gè) tensor,即要被“攤平”的 tensor。

  • start_dim: “攤平”的起始維度。
  • end_dim: “攤平”的結(jié)束維度。

作用與 torch.nn.flatten 類似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是類,其默認(rèn)開始維度為第 0 維。

例1:

import torch
data_pool = torch.randn(2,2,3,3) # 模擬經(jīng)過最后一個(gè)池化層或自適應(yīng)池化層之后的輸出,Batchsize*c*h*w
print(data_pool)
y=torch.flatten(data_pool,1)
print(y)

輸出結(jié)果:

在這里插入圖片描述

結(jié)果是一個(gè)B*x的向量。

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評(píng)論