numpy中tensordot的用法
楔子
在numpy中有一個(gè)tensordot方法,尤其在做機(jī)器學(xué)習(xí)的時(shí)候會(huì)很有用。估計(jì)有人看到這個(gè)名字,會(huì)想到tensorflow,沒(méi)錯(cuò)tensorflow里面也有tensordot這個(gè)函數(shù)。這個(gè)函數(shù)它的作用就是,可以讓兩個(gè)不同維度的數(shù)組進(jìn)行相乘。我們來(lái)舉個(gè)例子:
import numpy as np
a = np.random.randint(0, 9, (3, 4))
b = np.random.randint(0, 9, (4, 5))
try:
print(a * b)
except Exception as e:
print(e) # operands could not be broadcast together with shapes (3,4) (4,5)
# 很明顯,a和b兩個(gè)數(shù)組的維度不一樣,沒(méi)辦法相乘
# 但是
print(np.tensordot(a, b, 1))
"""
[[32 32 28 28 52]
[10 25 40 38 78]
[56 7 28 0 42]]
"""
# 我們看到使用tensordot是可以的
下面我們來(lái)看看這個(gè)函數(shù)的用法
函數(shù)原型
@array_function_dispatch(_tensordot_dispatcher) def tensordot(a, b, axes=2):
我們看到這個(gè)函數(shù)接收三個(gè)參數(shù),前兩個(gè)就是numpy中數(shù)組,最后一個(gè)參數(shù)則是用于指定收縮的軸。它可以接收一個(gè)整型、列表、列表里面嵌套列表,具體代表什么含義我們下面舉例說(shuō)明。
理解axes
axes為整型
如果axes接收的是一個(gè)整型:m,那么表示指定數(shù)組a的后n個(gè)軸和數(shù)組b的前n個(gè)軸分別進(jìn)行內(nèi)積,就是對(duì)應(yīng)位置元素相乘、再整體求和。
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) # 顯然這兩個(gè)數(shù)組不能直接相乘,但是a和后兩個(gè)軸和b的前兩個(gè)軸是可以直接相乘的 # 因?yàn)樗鼈兌际?4, 5), 最后結(jié)果的shape為(3, 8) print(np.tensordot(a, b, 2).shape) # (3, 8)
而且這個(gè)axes默認(rèn)為2,所以它一般都是針對(duì)三維或者三維以上的數(shù)組
但是為了具體理解,后面我們會(huì)使用一維、二維數(shù)據(jù)具體舉例說(shuō)明?,F(xiàn)在先看axes取不同的值,會(huì)得到什么結(jié)果,先理解一下axes的含義。
import numpy as np
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
try:
print(np.tensordot(a, b, 1).shape)
except Exception as e:
print(e) # shape-mismatch for sum
# 結(jié)果報(bào)錯(cuò)了,很好理解,就是形狀不匹配嘛
# axes指定為1,表示a的后一個(gè)軸和b的前一個(gè)軸進(jìn)行內(nèi)積
# 但是一個(gè)是5一個(gè)是4,元素?zé)o法一一對(duì)應(yīng),所以報(bào)錯(cuò),提示shape-mismatch,形狀不匹配
# 這里我們把數(shù)組b的shape改一下,這樣a的后一個(gè)軸和b的前一個(gè)軸就匹配了,都是5
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((5, 4, 8))
print(np.tensordot(a, b, 1).shape) # (3, 4, 4, 8)
"""
這樣就能夠運(yùn)算了,我們說(shuō)指定收縮的軸,進(jìn)行內(nèi)積運(yùn)算得到的是一個(gè)值
所以這里的(3, 4, 5)和(5, 4, 8)變成了(3, 4, 4, 8)
而上一個(gè)例子是(3, 4, 5)和(4, 5, 8),然后axes=2
因?yàn)閍的后兩個(gè)軸和b的前兩個(gè)軸進(jìn)行內(nèi)積變成了一個(gè)具體的值,所以最終的維度就是(3, 8)
"""
如果axes為0的話,會(huì)有什么結(jié)果
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) print(np.tensordot(a, b, 0).shape) # (3, 4, 5, 4, 5, 8) print(np.tensordot(b, a, 0).shape) # (4, 5, 8, 3, 4, 5) """ np.tensordot(a, b, 0)等價(jià)于將a中的每一個(gè)元素都和b相乘 然后再將原來(lái)a中的對(duì)應(yīng)元素替換掉 """
上面的操作也可以使用愛(ài)因斯坦求和來(lái)實(shí)現(xiàn)
axes=0
import numpy as np
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
c1 = np.tensordot(a, b, 0)
c2 = np.einsum("ijk,xyz->ijkxyz", a, b)
print(c1.shape, c2.shape) # (3, 4, 5, 4, 5, 8) (3, 4, 5, 4, 5, 8)
print(np.all(c1 == c2)) # True
"""
生成的c1和c2是一樣的
"""
c3 = np.tensordot(b, a, 0)
c4 = np.einsum("ijk,xyz->xyzijk", a, b)
print(c3.shape, c4.shape) # (4, 5, 8, 3, 4, 5) (4, 5, 8, 3, 4, 5)
print(np.all(c3 == c4)) # True
"""
生成的c3和c4是一樣的
"""
那么它們的效率之間孰優(yōu)孰劣呢?我們?cè)趈upyter上測(cè)試一下
>>> %timeit c1 = np.tensordot(a, b, 0)
50.5 μs ± 206 ns per loop
>>> %timeit c2 = np.einsum("ijk,xyz->ijkxyz", a, b)
7.29 μs ± 242 ns per loop
可以看到愛(ài)因斯坦求和快了不少
axes=1
import numpy as np
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((5, 4, 8))
c1 = np.tensordot(a, b, 1)
c2 = np.einsum("ijk,kyz->ijyz", a, b)
print(c1.shape, c2.shape) # (3, 4, 4, 8) (3, 4, 4, 8)
print(np.all(c1 == c2)) # True
axes=2
import numpy as np
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
c1 = np.tensordot(a, b, 2)
c2 = np.einsum("ijk,jkz->iz", a, b)
print(c1.shape, c2.shape) # (3, 8) (3, 8)
print(np.all(c1 == c2)) # True
axes為列表
如果axes接收的是一個(gè)列表:[m, n],那么表示讓a的第m+1個(gè)(索引為m)軸和b的第n+1(索引為n)個(gè)軸進(jìn)行內(nèi)積。使用列表的方法最大的好處就是,可以指定任意位置的軸。
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) # 我們看到a的第二個(gè)維度(或者說(shuō)軸)和b的第一個(gè)維度都是4,所以它們是可以進(jìn)行內(nèi)積的 c1 = np.tensordot(a, b, [1, 0]) # 由于內(nèi)積的結(jié)果是一個(gè)標(biāo)量,所以(3, 4, 5)和(4, 5, 8)在tensordot之后的shape是(3, 5, 5, 8) # 相當(dāng)于把各自的4給扔掉了(因?yàn)樽兂闪藰?biāo)量),然后組合在一起 print(c1.shape) # (3, 5, 5, 8) # 同理a的最后一個(gè)維度和b的第二個(gè)維度也是可以內(nèi)積的 # 最后一個(gè)維度也可以使用-1,等于按照列表的索引來(lái)取對(duì)應(yīng)的維度 c2 = np.tensordot(a, b, [-1, 1]) print(c2.shape) # (3, 4, 4, 8)
上面的操作也可以使用愛(ài)因斯坦求和來(lái)實(shí)現(xiàn)
import numpy as np
a = np.arange(60).reshape((3, 4, 5))
b = np.arange(160).reshape((4, 5, 8))
c1 = np.tensordot(a, b, [1, 0])
c2 = np.einsum("ijk,jyz->ikyz", a, b)
print(c1.shape, c2.shape) # (3, 5, 5, 8) (3, 5, 5, 8)
print(np.all(c1 == c2)) # True
c3 = np.tensordot(a, b, [-1, 1])
c4 = np.einsum("ijk,akz->ijaz", a, b)
print(c3.shape, c4.shape) # (3, 4, 4, 8) (3, 4, 4, 8)
print(np.all(c3 == c4)) # True
axes為列表嵌套列表
如果axes接收的是一個(gè)嵌套列表的列表:[[m], [n]],等于說(shuō)可以選多個(gè)軸
import numpy as np a = np.arange(60).reshape((3, 4, 5)) b = np.arange(160).reshape((4, 5, 8)) # 我們想讓a的后兩個(gè)軸和b的前兩個(gè)軸內(nèi)積 c1 = np.tensordot(a, b, axes=2) c2 = np.tensordot(a, b, [[1,2], [0,1]]) print(c1.shape, c2.shape) # (3, 8) (3, 8) print(np.all(c1 == c2)) # True
但是使用列表進(jìn)行篩選還有一個(gè)好處,就是可以忽略順序
import numpy as np a = np.arange(60).reshape((4, 3, 5)) b = np.arange(160).reshape((4, 5, 8)) # 這個(gè)時(shí)候就無(wú)法給axes傳遞整型了 c3 = np.tensordot(a, b, [[0, 2], [0, 1]]) print(c3.shape) # (3, 8)
此外,使用列表篩選還有一個(gè)強(qiáng)大的功能,就是可以倒著取值
import numpy as np a = np.arange(60).reshape((4, 5, 3)) b = np.arange(160).reshape((5, 4, 8)) # 這個(gè)時(shí)候我們選擇前兩個(gè)軸,但是一個(gè)是(4, 5)一個(gè)是(5, 4),所以無(wú)法相乘 # 因此在選擇的時(shí)候需要倒著篩選: # [[0, 1], [1, 0]]-> (4, 5)和(4, 5) 或者 [[1, 0], [0, 1]] -> (5, 4)和(5, 4) c3 = np.tensordot(a, b, [[0, 1], [1, 0]]) print(c3.shape) # (3, 8)
最后同樣看看如何愛(ài)因斯坦求和來(lái)實(shí)現(xiàn)
import numpy as np
a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((4, 5, 8))
c1 = np.tensordot(a, b, [[0, 1], [0, 1]])
c2 = np.einsum("ijk,ijz->kz", a, b)
print(c1.shape, c2.shape) # (3, 8) (3, 8)
print(np.all(c1 == c2)) # True
a = np.arange(60).reshape((4, 5, 3))
b = np.arange(160).reshape((5, 4, 8))
c1 = np.tensordot(a, b, [[0, 1], [1, 0]])
c2 = np.einsum("ijk,jiz->kz", a, b)
print(c1.shape, c2.shape) # (3, 8) (3, 8)
print(np.all(c1 == c2)) # True
a = np.arange(60).reshape((4, 3, 5))
b = np.arange(160).reshape((5, 4, 8))
c1 = np.tensordot(a, b, [[0, 2], [1, 0]])
c2 = np.einsum("ijk,kiz->jz", a, b)
print(c1.shape, c2.shape) # (3, 8) (3, 8)
print(np.all(c1 == c2)) # True
以兩個(gè)一維數(shù)組為例
我們來(lái)通過(guò)打印具體的數(shù)組來(lái)看一下tensordot
import numpy as np
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
print(np.tensordot(a, b, axes=0))
"""
[[ 2 3 4]
[ 4 6 8]
[ 6 9 12]]
"""
print(np.einsum("i,j->ij", a, b))
"""
[[ 2 3 4]
[ 4 6 8]
[ 6 9 12]]
"""
# 我們axes=0,等于是a的每一個(gè)元素和相乘,然后再把原來(lái)a對(duì)應(yīng)的元素替換掉
# 所以是a中的1 2 3分別和b相乘,得到[2 3 4] [4 6 8] [6 9 12]、再替換掉1 2 3
# 所以結(jié)果是[[2 3 4] [4 6 8] [6 9 12]]
如果axes=1呢?
import numpy as np
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
print(np.tensordot(a, b, axes=1)) # 20
"""
選取a的前一個(gè)軸和b的后一個(gè)軸進(jìn)行內(nèi)積
而a和b只有一個(gè)軸,所以結(jié)果是一個(gè)標(biāo)量
"""
print(np.einsum("i,i->", a, b)) # 20
如果axes=2呢?首先我們說(shuō)axes等于一個(gè)整型,表示選取a的后n個(gè)軸,b的前n個(gè)軸,而一維數(shù)組它們只有一個(gè)軸
import numpy as np
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
try:
print(np.tensordot(a, b, axes=2)) # 20
except Exception as e:
print(e) # tuple index out of range
顯然索引越界了。
以一個(gè)一維數(shù)組和一個(gè)二維數(shù)組為例
我們通過(guò)一維數(shù)組和二維數(shù)組進(jìn)行tensordot來(lái)感受一下
axes=0
import numpy as np
a = np.array([1, 2, 3])
b = np.array([[2, 3, 4]])
print(np.tensordot(a, b, 0))
"""
[[[ 2 3 4]]
[[ 4 6 8]]
[[ 6 9 12]]]
"""
print(np.einsum("i,jk->ijk", a, b))
"""
[[[ 2 3 4]]
[[ 4 6 8]]
[[ 6 9 12]]]
"""
# 很好理解,就是1 2 3分別和[[2, 3, 4]]相乘再替換掉 1 2 3
print(np.tensordot(a, b, 0).shape) # (3, 1, 3)
##########################
print(np.tensordot(b, a, 0))
"""
[[[ 2 4 6]
[ 3 6 9]
[ 4 8 12]]]
"""
print(np.einsum("i,jk->jki", a, b))
"""
[[[ 2 4 6]
[ 3 6 9]
[ 4 8 12]]]
"""
# 很好理解,就是2 3 4分別和[1 2 3]相乘再替換掉 2 3 4
print(np.tensordot(b, a, 0).shape) # (1, 3, 3)
axes=1的話呢?
import numpy as np
a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])
try:
print(np.tensordot(a, b, 1))
except Exception as e:
print(e) # shape-mismatch for sum
# 我們注意到報(bào)錯(cuò)了,因?yàn)閍xes=1,表示取a的后一個(gè)軸和b的前1個(gè)軸
# a的shape是(3, 0),所以它的后一個(gè)軸和前一個(gè)軸對(duì)應(yīng)的數(shù)組長(zhǎng)度都是3
# 但是b的前一個(gè)軸對(duì)應(yīng)的數(shù)組長(zhǎng)度是2,不匹配所以報(bào)錯(cuò)
print(np.tensordot(b, a, 1)) # [20 32]
# 我們看到這個(gè)是可以的,因?yàn)檫@表示b的后一個(gè)軸,數(shù)組長(zhǎng)度為3,是匹配的
# 讓后一個(gè)軸的[2 3 4]、[4 5 6]分別和[1 2 3]進(jìn)行內(nèi)積,最終得到兩個(gè)標(biāo)量
try:
print(np.einsum("i,ij->ij", a, b))
except Exception as e:
print(e)
# operands could not be broadcast together with remapped shapes [original->remapped]: (3,)->(3,newaxis) (2,3)->(2,3)
# 同樣對(duì)于愛(ài)因斯坦求和也是無(wú)法這么做的,我們需要換個(gè)順序
print(np.einsum("i,ji->j", a, b)) # [20 32]
# 或者
print(np.einsum("j,ij->i", a, b)) # [20 32]
axes=2的話呢?
import numpy as np
a = np.array([1, 2, 3])
b = np.array([[2, 3, 4], [4, 5, 6]])
try:
print(np.tensordot(a, b, 2))
except Exception as e:
print(e) # tuple index out of range
# 我們注意到報(bào)錯(cuò)了,因?yàn)閍xes=2,表示取a的后兩個(gè)軸和b的前兩個(gè)軸
# 而a總共才1個(gè)軸,所以報(bào)錯(cuò)了
try:
print(np.tensordot(b, a, 2))
except Exception as e:
print(e) # shape-mismatch for sum
# 我們看到雖然也報(bào)錯(cuò)了,但是不是報(bào)索引越界。
# 因?yàn)樯厦姹硎救的前兩個(gè)軸,雖然a只有一個(gè),但是此時(shí)不會(huì)索引越界,只是就取一個(gè)。如果是取后兩個(gè)就會(huì)越界了
# 此時(shí)b是(2, 3),而a是(3,) 不匹配,可能有人覺(jué)得會(huì)發(fā)生廣播,但在這里不會(huì)
以兩個(gè)二維數(shù)組為例
我們?cè)偻ㄟ^(guò)兩個(gè)二維數(shù)組進(jìn)行tensordot來(lái)感受一下
axes=0
import numpy as np
a = np.array([[1, 2, 3]])
b = np.array([[2, 3, 4], [4, 5, 6]])
# a_shape: (1, 3) b_shape(3, 3)
print(np.tensordot(a, b, 0))
"""
[[[[ 2 3 4]
[ 4 5 6]]
[[ 4 6 8]
[ 8 10 12]]
[[ 6 9 12]
[12 15 18]]]]
"""
print(np.einsum("ij,xy->ijxy", a, b))
"""
[[[[ 2 3 4]
[ 4 5 6]]
[[ 4 6 8]
[ 8 10 12]]
[[ 6 9 12]
[12 15 18]]]]
"""
print(np.tensordot(a, b, 0).shape) # (1, 3, 2, 3)
#############
print(np.tensordot(b, a, 0))
"""
[[[[ 2 4 6]]
[[ 3 6 9]]
[[ 4 8 12]]]
[[[ 4 8 12]]
[[ 5 10 15]]
[[ 6 12 18]]]]
"""
print(np.einsum("ij,xy->xyij", a, b))
"""
[[[[ 2 4 6]]
[[ 3 6 9]]
[[ 4 8 12]]]
[[[ 4 8 12]]
[[ 5 10 15]]
[[ 6 12 18]]]]
"""
print(np.tensordot(b, a, 0).shape) # (2, 3, 1, 3)
axes=1
import numpy as np
a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3, 4], [4, 5, 6]])
# a_shape: (2, 2) b_shape(2, 3)
print(np.tensordot(a, b, 1))
"""
[[10 13 16]
[22 29 36]]
"""
print(np.einsum("ij,jk->ik", a, b))
"""
[[10 13 16]
[22 29 36]]
"""
# 仔細(xì)的你肯定發(fā)現(xiàn)了,此時(shí)就相當(dāng)于矩陣的點(diǎn)乘
print(a @ b)
"""
[[10 13 16]
[22 29 36]]
"""
axes=2
import numpy as np
a = np.array([[1, 2], [3, 4]])
b = np.array([[2, 3, 4], [4, 5, 6]])
# a_shape: (2, 2) b_shape(2, 3)
# 取后兩個(gè)軸顯然不行,因?yàn)?2, 2)和(2, 3)不匹配
try:
print(np.tensordot(a, b, 2))
except Exception as e:
print(e) # shape-mismatch for sum
a = np.array([[1, 2, 3], [2, 2, 2]])
b = np.array([[2, 3, 4], [4, 5, 6]])
print(np.tensordot(a, b, 2)) # 50
print(np.einsum("ij,ij->", a, b)) # 50
最后看即個(gè)愛(ài)因斯坦求和的例子,感受它和主角tensordot的區(qū)別,當(dāng)然如果不熟悉的愛(ài)因斯坦求和的話可以不用看
import numpy as np
a = np.random.randint(1, 9, (5, 3, 2, 3))
b = np.random.randint(1, 9, (3, 3, 2))
c1 = a @ b # 多維數(shù)組,默認(rèn)是對(duì)最后兩位進(jìn)行點(diǎn)乘
c2 = np.einsum("ijkm,jmn->ijkn", a, b)
print(np.all(c1 == c2)) # True
print(c2.shape) # (5, 3, 2, 2)
print(np.einsum("...km,...mn->...kn", a, b).shape) # (5, 3, 2, 2)
# 但如果是
c3 = np.einsum("ijkm,amn->ijkn", a, b)
print(c3.shape) # (5, 3, 2, 2)
# 由于符號(hào)不一樣,所以即使shape一致,但是兩個(gè)數(shù)組不一樣
print(np.all(c3 == c1)) # False
a = np.random.randint(1, 9, (5, 3, 3, 2))
b = np.random.randint(1, 9, (1, 3, 2))
print(np.einsum("ijmk,jmn->ijkn", a, b).shape) # (5, 3, 2, 2)
print(np.einsum("ijkm,jnm->ijkn", a, b).shape) # (5, 3, 3, 3)到此這篇關(guān)于numpy中tensordot的用法的文章就介紹到這了,更多相關(guān)numpy tensordot內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python編程中flask的簡(jiǎn)介與簡(jiǎn)單使用
今天小編就為大家分享一篇關(guān)于Python編程中flask的簡(jiǎn)介與簡(jiǎn)單使用,小編覺(jué)得內(nèi)容挺不錯(cuò)的,現(xiàn)在分享給大家,具有很好的參考價(jià)值,需要的朋友一起跟隨小編來(lái)看看吧2018-12-12
Python中10個(gè)常用的內(nèi)置函數(shù)詳解
這篇文章主要為大家介紹了Python常用的內(nèi)置函數(shù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助2021-12-12
PyQt5中QSpinBox計(jì)數(shù)器的實(shí)現(xiàn)
這篇文章主要介紹了PyQt5中QSpinBox計(jì)數(shù)器的實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2021-01-01
Pytorch中torch.argmax()函數(shù)使用及說(shuō)明
這篇文章主要介紹了Pytorch中torch.argmax()函數(shù)使用及說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-01-01
Python中列表元素轉(zhuǎn)為數(shù)字的方法分析
這篇文章主要介紹了Python中列表元素轉(zhuǎn)為數(shù)字的方法,結(jié)合實(shí)例形式對(duì)比分析了Python列表操作及數(shù)學(xué)運(yùn)算的相關(guān)技巧,需要的朋友可以參考下2016-06-06
解決keras,val_categorical_accuracy:,0.0000e+00問(wèn)題
這篇文章主要介紹了解決keras,val_categorical_accuracy:,0.0000e+00問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-07-07
Python爬蟲(chóng)抓取手機(jī)APP的傳輸數(shù)據(jù)
大多數(shù)APP里面返回的是json格式數(shù)據(jù),或者一堆加密過(guò)的數(shù)據(jù) 。這里以超級(jí)課程表APP為例,抓取超級(jí)課程表里用戶發(fā)的話題2016-01-01

