欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch中的Broadcasting問題

 更新時間:2023年01月03日 09:53:28   作者:luputo  
這篇文章主要介紹了Pytorch中的Broadcasting問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

Numpy、Pytorch中的broadcasting

寫在前面

自己一直都不清楚numpy、pytorch里面不同維數的向量之間的element wise的計算究竟是按照什么規(guī)則來確認維數匹配和不匹配的情況的,比如

>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.]])

上面這種情況就會自動讓a和b的維數匹配,a加到了b的每一行上

>>> b = np.ones((5,4))
>>> a = np.arange(5)
>>> c = a + b
Traceback (most recent call last):
? File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (5,) (5,4)

這種情況就無法匹配,此時我們希望的是a能自動加到b的每一列上,但結果看來好像不行

雖然一直存在這種疑惑,但因為平時遇到的各種運算都比較簡單,遇到這種不是直接匹配的array的加法第一直覺就是去console里面試一試,報錯就換個姿勢再試一試,總歸問題可以快速地解決,但是最近在寫模型的時候,遇到了繞不過去的問題,所以去查了文檔,本文就以解決那個問題為目標,來解釋清楚pytorch(numpy也是一樣)中的broadcasting semantics的問題

問題描述

我有一個數據Tensor,維數是64 × 2048 64\times204864×2048,現在我想通過對這64 6464個2048 20482048維的向量做attention(也就是做一個加權和)來得到一個2048 20482048維的向量,因為模型的需要,我需要用五組不同的權值向量來計算出五個不同的加權結果,也就是我的計算結果應該是一個5 × 2048 5\times 20485×2048維的向量,因為在64 6464個向量上加權,所以一組權值向量是64 6464維,五組就是5 × 64 5\times 645×64維

嘗試解決

現在我手頭上有兩個Tensor,一個是數據Tensor(64 × 2048 64\times 204864×2048)另一個是權值Tensor(5 × 64 5\times 645×64),我GAN!直到我寫到了這里,我才發(fā)現這不是一個矩陣乘法就能解決的問題嘛+_+,當然,我想給自己正名,這里我簡化了一下問題所以才發(fā)現原來這么容易就解決了,而原來我在寫代碼的時候因為還要考慮batch_size等問題才云里霧里不知道咋辦,還好當時沒想出來,所以去查了文檔發(fā)現了新的東西,然后寫文章的時候想到也算是完滿了(不然也不會發(fā)現自己好澇)

以上都是題外話,現在,我們還是考慮用愚蠢的element wise的方法來解決,好在現在有兩種方法可以解決問題,所以我們可以用來相互檢驗一下,element wise的解決方法就是,我希望這5個64維的權值向量分別和這64個2048維的向量進行element wise的乘法,也就是第一個64維權值向量先對64個2048維向量加權得到一個2048維的向量,然后第二個64維權值向量先對64個2048維向量加權得到一個2048維的向量…,以此類推總共五個,最終得到五個64 × 2048 64×204864×2048維的向量,然后求和得到最后的5 × 2048 5×20485×2048維的向量

那么按照平常的習慣,我就去先試試pytorch能不能直接地理解我的想法

>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out = att * x
Traceback (most recent call last):
? File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 2

直接乘不行,因為維數是不匹配的,那怎樣的維數才算匹配呢?

BROADCASTING SEMANTICS

以下內容主要來源于自官方文檔

很多pytorch的運算是支持broadcasting semantics的,而簡單來說,如果運算支持broadcast,則參與運算的Tensor會自動進行擴展來使得運算符左右的Tensor維數匹配,而無需人手動地去拷貝其中的某個Tensor,這就類似于我們開頭的那個例子

>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.],
? ? ? ?[1., 2., 3., 4., 5.]])

我們無需讓a的維數和b一樣,因為numpy自動幫我們做了

這里的另一個重要的概念是broadcastable,如果兩個Tensor是broadcastable的,那么就可以對他倆使用支持broadcast的運算,比如直接加減乘除

而兩個向量要是broadcast的話,必須滿足以下兩個條件

  • 每個tensor至少是一維的
  • 兩個tensor的維數從后往前,對應的位置要么是相等的,要么其中一個是1,或者不存在

這是官方的例子解釋

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 相同維數的tensor一定是broadcastable的

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# 不是broadcastable的,因為每個tensor維數至少要是1

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( ?3,1,1)
# 是broadcastable的,因為從后往前看,一定要注意是從后往前看!
# 第一個維度都是1,相等,滿足第二個條件
# 第二個維度其中有一個是1,滿足第二個條件
# 第三個維度都是3,相等,滿足第二個條件
# 第四個維度其中有一個不存在,滿足第二個條件

# 但是
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( ?3,1,1)
# 不是broadcastable的,因為從后往前看第三個維度是不match的 2!=3,且都不是1

如果x和y是broadcastable的,那么結果的tensor的size按照如下的規(guī)則計算

  • 如果兩者的維度不一樣,那么就自動增加1維(也就是unsqueeze)
  • 對于結果的每個維度,它取x和y在那一維上的最大值

官方的例子

>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( ?3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

此外,關于broadcast導致的就地(in-place)操作和梯度運算的兼容性等問題,可以自行參考官方文檔

解決問題

上面我們看到,要想兩個Tensor支持element wise的運算,需要它們是broadcastable的,而要想它們是broadcastable的,就需要它們的維度自后向前逐一匹配,回到我們原來的問題中,我們有兩個Tensor x(64 × 2048) att(5 × 64),為了讓它們broadcastable,我們只需要

>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> x.shape
torch.Size([10, 1, 64, 2048])
>>> att.shape
torch.Size([1, 5, 64, 1])
>>> out = x * att
>>> out.shape
torch.Size([10, 5, 64, 2048])

最后我們來驗證兩種方法是否結果相同

>>> import torch
>>> bs = 10?
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out1 = torch.matmul(att,x) ?# 直接矩陣相乘
>>> out.shape
torch.Size([10, 5, 2048])

>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> out2 = x * att ?# element wise的方法
>>> out2 = out2.sum(dim=2)

>>> test = torch.sum((out1-out2)<0.00001) ?# 浮點數有微小的誤差
>>> test
tensor(102400)
>>> out1.numel() ?# 最后表明兩個out向量是相等的
102400

Reference

[1] https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#module-numpy.doc.broadcasting

[2] https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

總結

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

  • python實現linux下抓包并存庫功能

    python實現linux下抓包并存庫功能

    這篇文章主要為大家詳細介紹了python實現linux下抓包并存庫功能,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-07-07
  • 深入了解Python iter() 方法的用法

    深入了解Python iter() 方法的用法

    這篇文章主要介紹了深入了解Python iter() 方法的知識,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2019-07-07
  • python中random.randint和random.randrange的區(qū)別詳解

    python中random.randint和random.randrange的區(qū)別詳解

    這篇文章主要介紹了python中random.randint和random.randrange的區(qū)別詳解,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2020-09-09
  • Python中流程控制的高級用法盤點

    Python中流程控制的高級用法盤點

    在這篇文章中我們將全面深入地介紹?Python?的控制流程,包括條件語句、循環(huán)結構和異常處理等關鍵部分,尤其會將列表解析、生成器、裝飾器等高級用法一網打盡,快跟隨小編學起來吧
    2023-05-05
  • python beautiful soup庫入門安裝教程

    python beautiful soup庫入門安裝教程

    Beautiful Soup是python的一個庫,最主要的功能是從網頁抓取數據。今天通過本文給大家分享python beautiful soup庫入門教程,需要的朋友參考下吧
    2021-08-08
  • 詳解Python虛擬機是如何實現閉包的

    詳解Python虛擬機是如何實現閉包的

    Python中的閉包是一個強大的概念,允許函數捕獲和訪問其周圍的作用域,即使這些作用域在函數執(zhí)行完畢后也能被訪問,這篇文章將著重討論Python虛擬機是如何實現閉包的,文中有相關的代碼示例供大家參考,具有一定的參考價值,需要的朋友可以參考下
    2023-12-12
  • 使用memory_profiler監(jiān)測python代碼運行時內存消耗方法

    使用memory_profiler監(jiān)測python代碼運行時內存消耗方法

    今天小編就為大家分享一篇使用memory_profiler監(jiān)測python代碼運行時內存消耗方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12
  • Python中字典的setdefault()方法教程

    Python中字典的setdefault()方法教程

    在學習python字典操作方法時,感覺setdefault()方法,比字典的其它基本操作方法更難理解的同學比較多,所以想著總結以下,下面這篇文章主要給大家介紹了Python中字典的setdefault()方法,需要的朋友可以參考借鑒,下面來一起看看吧。
    2017-02-02
  • Django獲取model中的字段名和字段的verbose_name方式

    Django獲取model中的字段名和字段的verbose_name方式

    這篇文章主要介紹了Django獲取model中的字段名和字段的verbose_name方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • python實現在內存中讀寫str和二進制數據代碼

    python實現在內存中讀寫str和二進制數據代碼

    這篇文章主要介紹了python實現在內存中讀寫str和二進制數據代碼,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04

最新評論