欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch中torch.stack()函數(shù)的深入解析

 更新時(shí)間:2022年08月31日 08:54:36   作者:cv_lhp  
在pytorch中常見的拼接函數(shù)主要是兩個(gè),分別是:stack()和cat(),下面這篇文章主要給大家介紹了關(guān)于Pytorch中torch.stack()函數(shù)的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下

一. torch.stack()函數(shù)解析

1. 函數(shù)說明:

1.1 官網(wǎng)torch.stack(),函數(shù)定義及參數(shù)說明如下圖所示:

函數(shù)定義及參數(shù)說明

1.2 函數(shù)功能

沿一個(gè)新維度對(duì)輸入一系列張量進(jìn)行連接,序列中所有張量應(yīng)為相同形狀,stack 函數(shù)返回的結(jié)果會(huì)新增一個(gè)維度。也即是把多個(gè)2維的張量湊成一個(gè)3維的張量;多個(gè)3維的湊成一個(gè)4維的張量…以此類推,也就是在增加新的維度上面進(jìn)行堆疊。

1.3 參數(shù)列表

  • tensors :為一系列輸入張量,類型為turple和List
  • dim :新增維度的(下標(biāo))位置,當(dāng)dim = -1時(shí)默認(rèn)最后一個(gè)維度;范圍必須介于 0 到輸入張量的維數(shù)之間,默認(rèn)是dim=0,在第0維進(jìn)行連接
  • 返回值:輸出新增維度后的張量

2. 代碼舉例

2.1 dim = 0 : 在第0維進(jìn)行連接,相當(dāng)于在行上進(jìn)行組合(輸入張量為一維,輸出張量為兩維)

import torch
#二維輸入張量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=0)#在第0維進(jìn)行連接,相當(dāng)于在行上進(jìn)行組合(輸入張量為一維,輸出張量為兩維)
print(a)
print(b)
print(c)

輸出結(jié)果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1,  2,  3],
        [11, 22, 33]])

2.2 dim = 1 :在第1維進(jìn)行連接,相當(dāng)于在對(duì)應(yīng)行上面對(duì)列元素進(jìn)行組合(輸入張量為一維,輸出張量為兩維)

import torch
#二維輸入張量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=1)#在第1維進(jìn)行連接,相當(dāng)于在對(duì)應(yīng)行上面對(duì)列元素進(jìn)行組合(輸入張量為一維,輸出張量為兩維)
print(a)
print(b)
print(c)

輸出結(jié)果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 11],
        [ 2, 22],
        [ 3, 33]])

2.3 dim=0:表示在第0維進(jìn)行連接,相當(dāng)于在通道維度上進(jìn)行組合(輸入張量為兩維,輸出張量為三維),注意:此處輸入張量維度為二維,因此dim最大只能為2。

import torch
#二維輸入張量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b],dim=0)#在第0維進(jìn)行連接,相當(dāng)于在通道維度上進(jìn)行組合(輸入張量為兩維,輸出張量為三維)
print(a)
print(b)
print(c)

輸出結(jié)果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])

2.4 dim=1:表示在第1維進(jìn)行連接,相當(dāng)于對(duì)相應(yīng)通道中每個(gè)行進(jìn)行組合,注意:此處輸入張量維度為二維,因此dim最大只能為2。

import torch
#二維輸入張量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)#在第1維進(jìn)行連接,相當(dāng)于對(duì)相應(yīng)通道中每個(gè)行進(jìn)行組合
print(a)
print(b)
print(c)

輸出結(jié)果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [11, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])

2.5 dim=2:表示在第2維進(jìn)行連接,相當(dāng)于對(duì)相應(yīng)行中每個(gè)列元素進(jìn)行組合,注意:此處輸入張量維度為二維,因此dim最大只能為2。

import torch
#二維輸入張量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)#在第2維進(jìn)行連接,相當(dāng)于對(duì)相應(yīng)行中每個(gè)列元素進(jìn)行組合
print(a)
print(b)
print(c)

輸出結(jié)果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1, 11],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])

2.6 dim=3:表示在第3維進(jìn)行連接,相當(dāng)于對(duì)相應(yīng)行中每個(gè)列元素進(jìn)行組合(輸入維度大小為3維,因此dim=3最后一維始終代表為列),注意:此處輸入張量維度為三維,因此dim最大只能為3。

import torch
#三維輸入張量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 3)#表示在第3維進(jìn)行連接,相當(dāng)于對(duì)相應(yīng)行中每個(gè)列元素進(jìn)行組合(最后一維是第三維,始終代表為列)
print(a)
print(b)
print(c)

輸出結(jié)果如下所示:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
tensor([[[[  1,  11],
          [  2,  22],
          [  3,  33]],

         [[  4,  44],
          [  5,  55],
          [  6,  66]],

         [[  7,  77],
          [  8,  88],
          [  9,  99]]],


        [[[ 10, 110],
          [ 20, 220],
          [ 30, 330]],

         [[ 40, 440],
          [ 50, 550],
          [ 60, 660]],

         [[ 70, 770],
          [ 80, 880],
          [ 90, 990]]]])

2.7 dim=4 (錯(cuò)誤維度:因?yàn)榇颂庉斎霃埩烤S度為三維,所以dim最大只能為3,此處維度為4,因此會(huì)報(bào)錯(cuò))

import torch
#三維輸入張量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 4)
print(a)
print(b)
print(c)

輸出錯(cuò)誤:
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)

總結(jié)

到此這篇關(guān)于Pytorch中torch.stack()函數(shù)的文章就介紹到這了,更多相關(guān)Pytorch torch.stack()函數(shù)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評(píng)論