Python中的Broadcast機(jī)制
Python Broadcast機(jī)制
最近在用numpy的時(shí)候,里面的矩陣和向量之間各種乘法加法搞的我頭昏腦脹,整理下總結(jié)出來的規(guī)則
首先說明array型數(shù)據(jù)結(jié)構(gòu)有兩種類型,一種是一維的向量,比如用np.linspace(1,2,num=2)創(chuàng)建出的對(duì)象,shape為(2,);另外一種就是多維的矩陣,如np.zeros(1,2)創(chuàng)建出的對(duì)象,其shape為(1,2),這兩種類型是不一樣的。
矩陣之間的矩陣乘法
不必多說,就是按照正常的矩陣乘法規(guī)則來做
(N,M) (M,P) = (N,P)
矩陣之間按元素相乘、相加
這里開始就涉及到廣播(broadcast)的問題了。
其實(shí)也比較簡(jiǎn)單,兩個(gè)矩陣broadcast后的結(jié)果每一維都是兩個(gè)矩陣中最大的。
但broadcast必須滿足兩個(gè)規(guī)則,即要么相對(duì)應(yīng)的維數(shù)相等,要么其中有一個(gè)矩陣的維數(shù)是1。
那么問題來了,哪兩個(gè)維度是相對(duì)應(yīng)的維數(shù)呢?規(guī)則就是將矩陣的shape寫出來,然后按右對(duì)齊逐維對(duì)比。
通過以上方法,可以得出兩矩陣broadcast結(jié)果的維數(shù),而最后結(jié)果的計(jì)算方法就是先將兩個(gè)矩陣都broadcast到結(jié)果的維數(shù),然后再按照相同維度的矩陣對(duì)應(yīng)元素相乘、相加。
例子如下:
A ? ? ?(4d array): ?8 x 1 x 6 x 1 B ? ? ?(3d array): ? ? ?7 x 1 x 5 Result (4d array): ?8 x 7 x 6 x 5 A ? ? ?(2d array): ?5 x 4 B ? ? ?(1d array): ? ? ?1 Result (2d array): ?5 x 4 A ? ? ?(2d array): ?15 x 3 x 5 B ? ? ?(1d array): ?15 x 1 x 5 Result (2d array): ?15 x 3 x 5
矩陣和向量之間的矩陣乘法
這里也很簡(jiǎn)單,規(guī)則是
作左乘數(shù)的向量是行向量,作右乘數(shù)的向量是列向量。
這樣做的好處就是,結(jié)果矩陣一定也是個(gè)向量。這個(gè)規(guī)則也說明了向量不一定是行向量(雖然print出來看見的是一個(gè)行向量)
矩陣和向量之間的按元素乘法、加法
規(guī)則其實(shí)和“二”中說的是一樣的,只不過這里要注意的是,向量在這里永遠(yuǎn)當(dāng)作(1,N)來看,也就是是行向量,按照“二”中所說的broadcast的規(guī)則,向量的維度永遠(yuǎn)從右對(duì)齊,也就是只有最右邊有數(shù),也就說明和他進(jìn)行broadcast的矩陣,其最低維(也就是最右側(cè)的維度)要么是一維,要么就和向量的維度相同。
舉例子如下:
矩陣 (3d array) ? : 256 x 256 x 3 向量 (1d array) ? : ? ? ? ? ? ? 3 結(jié)果 (3d array) ? : 256 x 256 x 3
python broadcast機(jī)制的模擬實(shí)現(xiàn)
tensorflow的算術(shù)操作:mul/add/sub等op都支持broadcast機(jī)制,該機(jī)制支持不同維度的計(jì)算,但是在對(duì)維度進(jìn)行逆向比較時(shí)需要滿足以下要求:
- 1)二者維度相同
- 2)二者維度有一個(gè)為1
- 3)如果維度大小不一致,需要用1來對(duì)維度小的數(shù)據(jù)進(jìn)行擴(kuò)展,在進(jìn)行上述判斷;
如:a:[256,256,3]、b:[3]這樣的維度,需要先將b擴(kuò)展至與a一致,將b擴(kuò)展至[1,1,3],再對(duì)a、b數(shù)據(jù)進(jìn)行mul/add/sub等計(jì)算,最后輸出維度[256,256,3]
如果為了實(shí)現(xiàn)broadcast,可以進(jìn)行以下操作進(jìn)行模擬:
- 1)對(duì)維度大小不一致的數(shù)組進(jìn)行維度擴(kuò)展
- 2)獲取輸出維度,即broadcast的維度
- 3)進(jìn)行數(shù)據(jù)廣播
粗略代碼如下(這里以四維數(shù)據(jù)為例,進(jìn)行擴(kuò)展):
import tensorflow as tf import numpy as np if __name__ == "__main__": input0_shape = [1,1,3,1] input1_shape = [3] #維度擴(kuò)展 input_len = len(input0_shape) - len(input1_shape) for i in range(input_len): input1_shape.insert(0,1) print input1_shape #獲取broadcast shape broadcast_shape = [0] * len(input0_shape) for i in range(len(input0_shape)): broadcast_shape[i] = max(input0_shape[i],input1_shape[i]) print broadcast_shape data_a = np.random.random(input0_shape) #hwcn data_b = np.random.random(input1_shape) #h,w,c_out,c_in a = tf.placeholder("float") b = tf.placeholder("float") c = tf.add(a,b) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) out = sess.run(c, feed_dict={a: data_a,b:data_b}) #print data_a print data_b print out.shape #print out - data_a res_pre = out - data_b #獲取input0的擴(kuò)展結(jié)果,用于驗(yàn)證實(shí)際值 out_tf = res_pre.reshape(broadcast_shape[0]*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3]) data_b_tmp = data_a.reshape(input0_shape[0]*input0_shape[1]*input0_shape[2]*input0_shape[3]) print "out_tf" print out_tf f_dets = open("pre_data.dat", "w") for k in out_tf: b = float(k) a = '{:.10f}'.format(b) f_dets.write(str(a) + '\n') f_dets.close() out_res = [0]*broadcast_shape[0]*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3] #進(jìn)行數(shù)據(jù)擴(kuò)展 for i in range(broadcast_shape[0]): for j in range(broadcast_shape[1]): for k in range(broadcast_shape[2]): for m in range(broadcast_shape[3]): tmp_idx0 = i*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3] \ + j*broadcast_shape[2]*broadcast_shape[3] + k*broadcast_shape[3] + m ii = 0 jj = 0 kk = 0 mm = 0 if i >= input0_shape[0]: ii = input0_shape[0] -1 else: ii = i if j >= input0_shape[1]: jj = input0_shape[1] -1 else: jj = j if k >= input0_shape[2]: kk = input0_shape[2] -1 else: kk = k if m >= input0_shape[3]: mm = input0_shape[3] -1 else: mm = m tmp_idx1 = ii*input0_shape[1]*input0_shape[2]*input0_shape[3] \ + jj*input0_shape[2]*input0_shape[3] + kk*input0_shape[3] + mm #print mm out_res[tmp_idx0] = data_b_tmp[tmp_idx1] f_dets = open("aft_data.dat", "w") for k in out_res: b = float(k) a = '{:.10f}'.format(b) f_dets.write(str(a) + '\n') f_dets.close() #對(duì)比 print "compare" print out_res - out_tf
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
scipy.interpolate插值方法實(shí)例講解
這篇文章主要介紹了scipy.interpolate插值方法介紹,本文結(jié)合實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-12-12Python實(shí)現(xiàn)自定義異常堆棧信息的示例代碼
當(dāng)我們的程序報(bào)錯(cuò)時(shí),解釋器會(huì)將整個(gè)異常的堆棧信息全部輸出出來。解釋器會(huì)將異常產(chǎn)生的整個(gè)調(diào)用鏈都給打印出來,那么問題來了,我們能不能自定義這些報(bào)錯(cuò)信息呢?本文就來為大家詳細(xì)講講2022-07-07使用python實(shí)現(xiàn)數(shù)據(jù)篩查
一般數(shù)據(jù)篩查可以通過Python中的pandas庫(kù)來實(shí)現(xiàn),下面小編就來為大家介紹一下Python如何利用pandas實(shí)現(xiàn)數(shù)據(jù)篩查,感興趣的小伙伴可以一起學(xué)習(xí)一下2023-10-10Python讀取含url圖片鏈接的txt文檔方法小結(jié)
這篇文章主要為大家詳細(xì)介紹了三種Python讀取含url圖片鏈接的txt文檔方法,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2024-04-04Django 設(shè)置多環(huán)境配置文件載入問題
這篇文章主要介紹了Django 設(shè)置多環(huán)境配置文件載入問題,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-02-02python編程培訓(xùn) python培訓(xùn)靠譜嗎
現(xiàn)在大家都知道,比較火的編程語(yǔ)言就是python了,很多朋友都想學(xué)習(xí)python編程,想上一個(gè)好的python培訓(xùn)班,小編今天給大家全面分析一下關(guān)于python編程培訓(xùn)方面的問題,希望能給你答疑解惑。2018-01-01