Python中的Broadcast機(jī)制
Python Broadcast機(jī)制
最近在用numpy的時候,里面的矩陣和向量之間各種乘法加法搞的我頭昏腦脹,整理下總結(jié)出來的規(guī)則
首先說明array型數(shù)據(jù)結(jié)構(gòu)有兩種類型,一種是一維的向量,比如用np.linspace(1,2,num=2)創(chuàng)建出的對象,shape為(2,);另外一種就是多維的矩陣,如np.zeros(1,2)創(chuàng)建出的對象,其shape為(1,2),這兩種類型是不一樣的。
矩陣之間的矩陣乘法
不必多說,就是按照正常的矩陣乘法規(guī)則來做
(N,M) (M,P) = (N,P)
矩陣之間按元素相乘、相加
這里開始就涉及到廣播(broadcast)的問題了。
其實也比較簡單,兩個矩陣broadcast后的結(jié)果每一維都是兩個矩陣中最大的。
但broadcast必須滿足兩個規(guī)則,即要么相對應(yīng)的維數(shù)相等,要么其中有一個矩陣的維數(shù)是1。
那么問題來了,哪兩個維度是相對應(yīng)的維數(shù)呢?規(guī)則就是將矩陣的shape寫出來,然后按右對齊逐維對比。
通過以上方法,可以得出兩矩陣broadcast結(jié)果的維數(shù),而最后結(jié)果的計算方法就是先將兩個矩陣都broadcast到結(jié)果的維數(shù),然后再按照相同維度的矩陣對應(yīng)元素相乘、相加。
例子如下:
A ? ? ?(4d array): ?8 x 1 x 6 x 1 B ? ? ?(3d array): ? ? ?7 x 1 x 5 Result (4d array): ?8 x 7 x 6 x 5 A ? ? ?(2d array): ?5 x 4 B ? ? ?(1d array): ? ? ?1 Result (2d array): ?5 x 4 A ? ? ?(2d array): ?15 x 3 x 5 B ? ? ?(1d array): ?15 x 1 x 5 Result (2d array): ?15 x 3 x 5
矩陣和向量之間的矩陣乘法
這里也很簡單,規(guī)則是
作左乘數(shù)的向量是行向量,作右乘數(shù)的向量是列向量。
這樣做的好處就是,結(jié)果矩陣一定也是個向量。這個規(guī)則也說明了向量不一定是行向量(雖然print出來看見的是一個行向量)
矩陣和向量之間的按元素乘法、加法
規(guī)則其實和“二”中說的是一樣的,只不過這里要注意的是,向量在這里永遠(yuǎn)當(dāng)作(1,N)來看,也就是是行向量,按照“二”中所說的broadcast的規(guī)則,向量的維度永遠(yuǎn)從右對齊,也就是只有最右邊有數(shù),也就說明和他進(jìn)行broadcast的矩陣,其最低維(也就是最右側(cè)的維度)要么是一維,要么就和向量的維度相同。
舉例子如下:
矩陣 (3d array) ? : 256 x 256 x 3 向量 (1d array) ? : ? ? ? ? ? ? 3 結(jié)果 (3d array) ? : 256 x 256 x 3
python broadcast機(jī)制的模擬實現(xiàn)
tensorflow的算術(shù)操作:mul/add/sub等op都支持broadcast機(jī)制,該機(jī)制支持不同維度的計算,但是在對維度進(jìn)行逆向比較時需要滿足以下要求:
- 1)二者維度相同
- 2)二者維度有一個為1
- 3)如果維度大小不一致,需要用1來對維度小的數(shù)據(jù)進(jìn)行擴(kuò)展,在進(jìn)行上述判斷;
如:a:[256,256,3]、b:[3]這樣的維度,需要先將b擴(kuò)展至與a一致,將b擴(kuò)展至[1,1,3],再對a、b數(shù)據(jù)進(jìn)行mul/add/sub等計算,最后輸出維度[256,256,3]
如果為了實現(xiàn)broadcast,可以進(jìn)行以下操作進(jìn)行模擬:
- 1)對維度大小不一致的數(shù)組進(jìn)行維度擴(kuò)展
- 2)獲取輸出維度,即broadcast的維度
- 3)進(jìn)行數(shù)據(jù)廣播
粗略代碼如下(這里以四維數(shù)據(jù)為例,進(jìn)行擴(kuò)展):
import tensorflow as tf
import numpy as np
if __name__ == "__main__":
input0_shape = [1,1,3,1]
input1_shape = [3]
#維度擴(kuò)展
input_len = len(input0_shape) - len(input1_shape)
for i in range(input_len):
input1_shape.insert(0,1)
print input1_shape
#獲取broadcast shape
broadcast_shape = [0] * len(input0_shape)
for i in range(len(input0_shape)):
broadcast_shape[i] = max(input0_shape[i],input1_shape[i])
print broadcast_shape
data_a = np.random.random(input0_shape) #hwcn
data_b = np.random.random(input1_shape) #h,w,c_out,c_in
a = tf.placeholder("float")
b = tf.placeholder("float")
c = tf.add(a,b)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
out = sess.run(c, feed_dict={a: data_a,b:data_b})
#print data_a
print data_b
print out.shape
#print out - data_a
res_pre = out - data_b #獲取input0的擴(kuò)展結(jié)果,用于驗證實際值
out_tf = res_pre.reshape(broadcast_shape[0]*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3])
data_b_tmp = data_a.reshape(input0_shape[0]*input0_shape[1]*input0_shape[2]*input0_shape[3])
print "out_tf"
print out_tf
f_dets = open("pre_data.dat", "w")
for k in out_tf:
b = float(k)
a = '{:.10f}'.format(b)
f_dets.write(str(a) + '\n')
f_dets.close()
out_res = [0]*broadcast_shape[0]*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3]
#進(jìn)行數(shù)據(jù)擴(kuò)展
for i in range(broadcast_shape[0]):
for j in range(broadcast_shape[1]):
for k in range(broadcast_shape[2]):
for m in range(broadcast_shape[3]):
tmp_idx0 = i*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3] \
+ j*broadcast_shape[2]*broadcast_shape[3] + k*broadcast_shape[3] + m
ii = 0
jj = 0
kk = 0
mm = 0
if i >= input0_shape[0]:
ii = input0_shape[0] -1
else:
ii = i
if j >= input0_shape[1]:
jj = input0_shape[1] -1
else:
jj = j
if k >= input0_shape[2]:
kk = input0_shape[2] -1
else:
kk = k
if m >= input0_shape[3]:
mm = input0_shape[3] -1
else:
mm = m
tmp_idx1 = ii*input0_shape[1]*input0_shape[2]*input0_shape[3] \
+ jj*input0_shape[2]*input0_shape[3] + kk*input0_shape[3] + mm
#print mm
out_res[tmp_idx0] = data_b_tmp[tmp_idx1]
f_dets = open("aft_data.dat", "w")
for k in out_res:
b = float(k)
a = '{:.10f}'.format(b)
f_dets.write(str(a) + '\n')
f_dets.close()
#對比
print "compare"
print out_res - out_tf總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python讀取含url圖片鏈接的txt文檔方法小結(jié)
這篇文章主要為大家詳細(xì)介紹了三種Python讀取含url圖片鏈接的txt文檔方法,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-04-04
Django 設(shè)置多環(huán)境配置文件載入問題
這篇文章主要介紹了Django 設(shè)置多環(huán)境配置文件載入問題,本文通過實例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價值,需要的朋友可以參考下2020-02-02
python編程培訓(xùn) python培訓(xùn)靠譜嗎
現(xiàn)在大家都知道,比較火的編程語言就是python了,很多朋友都想學(xué)習(xí)python編程,想上一個好的python培訓(xùn)班,小編今天給大家全面分析一下關(guān)于python編程培訓(xùn)方面的問題,希望能給你答疑解惑。2018-01-01

