解決Keras中Embedding層masking與Concatenate層不可調(diào)和的問題
問題描述
我在用Keras的Embedding層做nlp相關(guān)的實(shí)現(xiàn)時(shí),發(fā)現(xiàn)了一個(gè)神奇的問題,先上代碼:
a = Input(shape=[15]) # None*15 b = Input(shape=[30]) # None*30 emb_a = Embedding(10, 5, mask_zero=True)(a) # None*15*5 emb_b = Embedding(20, 5, mask_zero=False)(b) # None*30*5 cat = Concatenate(axis=1)([emb_a, emb_b]) # None*45*5 model = Model(inputs=[a, b], outputs=[cat]) print model.summary()
我有兩個(gè)Embedding層,當(dāng)其中一個(gè)設(shè)置mask_zero=True,而另一個(gè)為False時(shí),會(huì)報(bào)如下錯(cuò)誤。
ValueError: Dimension 0 in both shapes must be equal, but are 1 and 5.
Shapes are [1] and [5]. for 'concatenate_1/concat_1' (op: 'ConcatV2')
with input shapes: [?,15,1], [?,30,5], [] and with computed input tensors: input[2] = <1>.
什么意思呢?是說在concatenate時(shí)發(fā)現(xiàn)兩個(gè)矩陣的第三維一個(gè)是1,一個(gè)是5,這就很神奇了,加了個(gè)mask_zero=True還會(huì)改變矩陣維度的嗎?
尋找問題根源
為了檢驗(yàn)Embedding層輸出的正確性,我把代碼改成了:
a = Input(shape=[30]) ... cat = Concatenate(axis=2)([emb_a, emb_b])
運(yùn)行成功了,并且summary顯示兩個(gè)Embedding層輸出矩陣的第三維都是5。
這就很奇怪了,明明沒有改變維度,為什么會(huì)報(bào)那樣的錯(cuò)誤?
然后我仔細(xì)追溯了一下前面的各項(xiàng)error,發(fā)現(xiàn)這么一句:
File ".../keras/layers/merge.py", line 374, in compute_mask
concatenated = K.concatenate(masks, axis=self.axis)
難道是mask的拼接有問題?
于是我修改了/keras/layers/merge.py里的Concatenate類的compute_mask函數(shù)(sudo vim就可以修改),在返回前輸出一下masks:
def compute_mask(self, inputs, mask=None): ... for x in masks: print x return ...
Tensor("concatenate_1/ExpandDims:0", shape=(?, 30, 1), dtype=bool)
Tensor("concatenate_1/Cast:0", shape=(?, 30, 5), dtype=bool)
發(fā)現(xiàn)了!有一個(gè)叫concatenate_1/ExpandDims:0的mask它的第三維度是1!
那么這個(gè)ExpandDims是什么鬼,觀察一下compute_mask代碼,發(fā)現(xiàn)了:
... elif K.ndim(mask_i) < K.ndim(input_i): # Mask is smaller than the input, expand it masks.append(K.expand_dims(mask_i)) ...
意思是當(dāng)mask_i的維度比input_i的維度小時(shí),擴(kuò)展一維,這下知道第三維的1是怎么來的了,那么可以預(yù)計(jì)compute_mask函數(shù)輸入的mask尺寸應(yīng)該是(None, 30),輸出一下試試:
def compute_mask(self, inputs, mask=None): print mask ...
[<tf.Tensor 'embedding_1/NotEqual:0' shape=(?, 30) dtype=bool>, None]
果然如此,總結(jié)一下問題的所在:
Embedding層的輸出會(huì)比輸入多一維,但Embedding生成的mask的維度與輸入一致。在Concatenate中,沒有mask的Embedding輸出被分配一個(gè)與該輸出相同維度的全1的mask,比有mask的Embedding的mask多一維。
提出解決方案
那么,Embedding層的mask到底是如何起作用的呢?是直接在Embedding層中起作用,還是在后續(xù)的層中起作用呢?縱觀embeddings.py,mask_zero只在compute_mask函數(shù)被用到:
def compute_mask(self, inputs, mask=None): if not self.mask_zero: return None else: return K.not_equal(inputs, 0)
可見,Embedding層的mask是記錄了Embedding輸入中非零元素的位置,并且傳給后面的支持masking的層,在后面的層里起作用。
一種最簡(jiǎn)單的解決方案:
給所有參與Concatenate的Embedding層都設(shè)置mask_zero=True。
但是,我想到了一種更靈活的解決方案:
修改embedding.py的compute_mask函數(shù),使得輸出的mask從2維變成3維,且第三維等于output_dim。
import tensorflow as tf ... def compute_mask(self, inputs, mask=None): if not self.mask_zero: return None else: mask = K.repeat(K.not_equal(inputs, 0), self.output_dim) # [?,output_dim,n] mask = tf.transpose(mask, [0,2,1]) # [?,n,output_dim] return mask ...
驗(yàn)證解決方案
為了驗(yàn)證這個(gè)改動(dòng)是否正確,我需要設(shè)計(jì)幾個(gè)小實(shí)驗(yàn)。
實(shí)驗(yàn)一:mask的正確性
我把輸出的mask做了改動(dòng),不知道m(xù)ask是否是正確的。
如下所示,數(shù)據(jù)是一個(gè)帶有3個(gè)樣本、樣本長(zhǎng)度最長(zhǎng)為3的補(bǔ)零padding過的矩陣,我分別讓Embedding層的mask_zero為False和True(為True時(shí)input_dim=|va|+2所以是5)。然后分別將Embedding的輸出在axis=1用MySumLayer進(jìn)行求和。為了方便觀察,我用keras.initializers.ones()把Embedding層的權(quán)值全部初始化為1。
# data data = np.array([[1,0,0], [1,2,0], [1,2,3]]) init = keras.initializers.ones() # network a = Input(shape=[3]) # None*3 emb1 = Embedding(4, 5, embeddings_initializer=init, mask_zero=False)(a) # None*3*5 emb2 = Embedding(5, 5, embeddings_initializer=init, mask_zero=True)(a) # None*3*5 sum1 = MySumLayer(axis=1)(emb1) # None*5 sum2 = MySumLayer(axis=1)(emb2) # None*5 model = Model(inputs=[a], outputs=[sum1, sum2]) # prediciton out = model.predict(data) for x in out: print x
結(jié)果如下:
[[3. 3. 3. 3. 3.] [3. 3. 3. 3. 3.] [3. 3. 3. 3. 3.]] [[1. 1. 1. 1. 1.] [2. 2. 2. 2. 2.] [3. 3. 3. 3. 3.]]
這個(gè)結(jié)果是正確的,這里解釋一波:
(1)當(dāng)mask_True=False時(shí),輸入矩陣中的0也會(huì)被認(rèn)為是正確的index,從而從權(quán)值矩陣中抽出第0行作為該index的Embedding,而我的權(quán)值都是1,因此所有Embedding都是1,對(duì)axis=1求和,實(shí)際上是對(duì)word length這一軸求和,輸入的word length最長(zhǎng)為3,以致于輸出矩陣的元素都是3.
(2)當(dāng)mask_True=True時(shí),輸入矩陣中的0會(huì)被mask掉,而這個(gè)mask的操作是體現(xiàn)在MySumLayer中的,將輸入(3, 3, 5)與mask(3, 3, 5)逐元素相乘,再相加。第一個(gè)樣本只有一項(xiàng)非零,第二個(gè)有兩項(xiàng),第三個(gè)三項(xiàng),因此MySumLayer輸出的矩陣,各行元素分別是1,2,3.
另外附上MySumLayer的代碼,它的功能是指定一個(gè)axis將Tensor進(jìn)行求和:
from keras import backend as K from keras.engine.topology import Layer import tensorflow as tf class MySumLayer(Layer): def __init__(self, axis, **kwargs): self.supports_masking = True self.axis = axis super(MySumLayer, self).__init__(**kwargs) def compute_mask(self, input, input_mask=None): # do not pass the mask to the next layers return None def call(self, x, mask=None): if mask is not None: # mask (batch, time) mask = K.cast(mask, K.floatx()) if K.ndim(x)!=K.ndim(mask): mask = K.repeat(mask, x.shape[-1]) mask = tf.transpose(mask, [0,2,1]) x = x * mask return K.sum(x, axis=self.axis) else: return K.sum(x, axis=self.axis) def compute_output_shape(self, input_shape): # remove temporal dimension if self.axis==1: return input_shape[0], input_shape[2] if self.axis==2: return input_shape[0], input_shape[1]
實(shí)驗(yàn)二:一個(gè)mask_zero=True和一個(gè)mask_zero=False的Embedding是否能夠拼接
a = Input(shape=[3]) # None*3 b = Input(shape=[4]) # None*4 emba = Embedding(4, 5, embeddings_initializer=init, mask_zero=False)(a) # None*3*5 embb = Embedding(6, 5, embeddings_initializer=init, mask_zero=True)(b) # None*4*5 cat = Concatenate(axis=1)([emba, embb]) # None*7*5 model = Model(inputs=[a,b], outputs=[cat]) print model.summary()
沒有報(bào)錯(cuò)!而且輸出的shape正是(None, 7, 5)。
實(shí)驗(yàn)三:兩個(gè)mask_zero=True的Embedding拼接是否會(huì)報(bào)錯(cuò)
a = Input(shape=[3]) # None*3 b = Input(shape=[4]) # None*4 emba = Embedding(4, 5, embeddings_initializer=init, mask_zero=True)(a) # None*3*5 embb = Embedding(6, 5, embeddings_initializer=init, mask_zero=True)(b) # None*4*5 cat = Concatenate(axis=1)([emba, embb]) # None*7*5 model = Model(inputs=[a,b], outputs=[cat]) print model.summary()
沒有報(bào)錯(cuò)!
實(shí)驗(yàn)四:兩個(gè)mask_zero=True的Embedding拼接結(jié)果是否正確
如下所示,第一個(gè)矩陣是一個(gè)帶有4個(gè)樣本、樣本長(zhǎng)度最長(zhǎng)為3的補(bǔ)零padding過的矩陣,第二個(gè)矩陣是一個(gè)帶有4個(gè)樣本、樣本長(zhǎng)度最長(zhǎng)為4的補(bǔ)零padding過的矩陣。為什么這里要求樣本個(gè)數(shù)一致呢,因?yàn)橐话銇碚f需要這種拼接操作的都是同一批樣本的不同特征。兩者的Embedding都設(shè)置mask_zero=True,在axis=1拼接后,用MySumLayer在axis=1加起來。
# data data1 = np.array([[1,0,0], [1,2,0], [1,2,3], [1,2,3]]) data2 = np.array([[1,0,0,0], [1,2,0,0], [1,2,3,0], [1,2,3,4]]) init = keras.initializers.ones() # network a = Input(shape=[3]) # None*3 b = Input(shape=[4]) # None*4 emba = Embedding(4, 5, embeddings_initializer=init, mask_zero=True)(a) # None*3*5 embb = Embedding(6, 5, embeddings_initializer=init, mask_zero=True)(b) # None*3*5 cat = Concatenate(axis=1)([emba, embb]) su = MySumLayer(axis=1)(cat) model = Model(inputs=[a,b], outputs=[su]) # prediction print model.predict([data1, data2])
輸出如下
[[2. 2. 2. 2. 2.] [4. 4. 4. 4. 4.] [6. 6. 6. 6. 6.] [7. 7. 7. 7. 7.]]
這個(gè)結(jié)果是正確的,解釋一波,其實(shí)兩個(gè)矩陣橫向拼接起來是下面這樣的,4個(gè)樣本分別有2、4、6、7個(gè)非零index,而Embedding層權(quán)值都是1,所以最終輸出的就是上面這個(gè)樣子。
# index 1 0 0 1 0 0 0 1 2 0 1 2 0 0 1 2 3 1 2 3 0 1 2 3 1 2 3 4
至此,問題成功解決了。
以上這篇解決Keras中Embedding層masking與Concatenate層不可調(diào)和的問題就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python基礎(chǔ)入門詳解(文件輸入/輸出 內(nèi)建類型 字典操作使用方法)
這篇文章主要介紹了python基礎(chǔ)入門,包括文件輸入/輸出、內(nèi)建類型、字典操作等使用方法2013-12-12tensorflow pb to tflite 精度下降詳解
這篇文章主要介紹了tensorflow pb to tflite 精度下降詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-05-05python實(shí)現(xiàn)封裝得到virustotal掃描結(jié)果
這篇文章主要介紹了python實(shí)現(xiàn)封裝得到virustotal掃描結(jié)果的方法,是比較實(shí)用的技巧,可將掃描結(jié)果寫入數(shù)據(jù)庫,需要的朋友可以參考下2014-10-10Python寫的一個(gè)簡(jiǎn)單監(jiān)控系統(tǒng)
這篇文章主要介紹了Python寫的一個(gè)簡(jiǎn)單監(jiān)控系統(tǒng),本文講解了詳細(xì)的編碼步驟,并給給出相應(yīng)的實(shí)現(xiàn)代碼,需要的朋友可以參考下2015-06-06Django中l(wèi)ogin_required裝飾器的深入介紹
這篇文章主要給大家介紹了關(guān)于Django中l(wèi)ogin_required裝飾器的使用方法,并給大家進(jìn)行了實(shí)例借鑒,利用@login_required實(shí)現(xiàn)Django用戶登陸訪問限制,文中通過示例代碼介紹的非常詳細(xì),需要的朋友可以參考借鑒,下面來一起看看吧。2017-11-11利用Python實(shí)現(xiàn)多種風(fēng)格的照片處理
這篇文章主要為大家詳細(xì)介紹了如何利用Python一鍵實(shí)現(xiàn)多種風(fēng)格的照片處理并制作可視化GUI界面,文中的示例代碼講解詳細(xì),感興趣的可以了解一下2022-07-07python自定義函數(shù)實(shí)現(xiàn)一個(gè)數(shù)的三次方計(jì)算方法
今天小編就為大家分享一篇python自定義函數(shù)實(shí)現(xiàn)一個(gè)數(shù)的三次方計(jì)算方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-01-01python tkinter制作用戶登錄界面的簡(jiǎn)單實(shí)現(xiàn)
這篇文章主要介紹了python tkinter制作用戶登錄界面的簡(jiǎn)單實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2021-04-04Python嵌套列表轉(zhuǎn)一維的方法(壓平嵌套列表)
今天小編就為大家分享一篇Python嵌套列表轉(zhuǎn)一維的方法(壓平嵌套列表),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-07-07