Keras自定義實現(xiàn)帶masking的meanpooling層方式
Keras確實是一大神器,代碼可以寫得非常簡潔,但是最近在寫LSTM和DeepFM的時候,遇到了一個問題:樣本的長度不一樣。對不定長序列的一種預(yù)處理方法是,首先對數(shù)據(jù)進行padding補0,然后引入keras的Masking層,它能自動對0值進行過濾。
問題在于keras的某些層不支持Masking層處理過的輸入數(shù)據(jù),例如Flatten、AveragePooling1D等等,而其中meanpooling是我需要的一個運算。例如LSTM對每一個序列的輸出長度都等于該序列的長度,那么均值運算就只應(yīng)該除以序列長度,而不是padding后的最長長度。
例如下面這個 3x4 大小的張量,經(jīng)過補零padding的。我希望做axis=1的meanpooling,則第一行應(yīng)該是 (10+20)/2,第二行應(yīng)該是 (10+20+30)/3,第三行應(yīng)該是 (10+20+30+40)/4。
Keras如何自定義層
在 Keras2.0 版本中(如果你使用的是舊版本請更新),自定義一個層的方法參考這里。具體地,你只要實現(xiàn)三個方法即可。
build(input_shape) : 這是你定義層參數(shù)的地方。這個方法必須設(shè)self.built = True,可以通過調(diào)用super([Layer], self).build()完成。如果這個層沒有需要訓練的參數(shù),可以不定義。
call(x) : 這里是編寫層的功能邏輯的地方。你只需要關(guān)注傳入call的第一個參數(shù):輸入張量,除非你希望你的層支持masking。
compute_output_shape(input_shape) : 如果你的層更改了輸入張量的形狀,你應(yīng)該在這里定義形狀變化的邏輯,這讓Keras能夠自動推斷各層的形狀。
下面是一個簡單的例子:
from keras import backend as K from keras.engine.topology import Layer import numpy as np class MyLayer(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): # Create a trainable weight variable for this layer. self.kernel = self.add_weight(name='kernel', shape=(input_shape[1], self.output_dim), initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) # Be sure to call this somewhere! def call(self, x): return K.dot(x, self.kernel) def compute_output_shape(self, input_shape): return (input_shape[0], self.output_dim)
Keras自定義層如何允許masking
觀察了一些支持masking的層,發(fā)現(xiàn)他們對masking的支持體現(xiàn)在兩方面。
在 __init__ 方法中設(shè)置 supports_masking=True。
實現(xiàn)一個compute_mask方法,用于將mask傳到下一層。
部分層會在call中調(diào)用傳入的mask。
自定義實現(xiàn)帶masking的meanpooling
假設(shè)輸入是3d的。首先,在__init__方法中設(shè)置self.supports_masking = True,然后在call中實現(xiàn)相應(yīng)的計算。
from keras import backend as K from keras.engine.topology import Layer import tensorflow as tf class MyMeanPool(Layer): def __init__(self, axis, **kwargs): self.supports_masking = True self.axis = axis super(MyMeanPool, self).__init__(**kwargs) def compute_mask(self, input, input_mask=None): # need not to pass the mask to next layers return None def call(self, x, mask=None): if mask is not None: mask = K.repeat(mask, x.shape[-1]) mask = tf.transpose(mask, [0,2,1]) mask = K.cast(mask, K.floatx()) x = x * mask return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis) else: return K.mean(x, axis=self.axis) def compute_output_shape(self, input_shape): output_shape = [] for i in range(len(input_shape)): if i!=self.axis: output_shape.append(input_shape[i]) return tuple(output_shape)
使用舉例:
from keras.layers import Input, Masking from keras.models import Model from MyMeanPooling import MyMeanPool data = [[[10,10],[0, 0 ],[0, 0 ],[0, 0 ]], [[10,10],[20,20],[0, 0 ],[0, 0 ]], [[10,10],[20,20],[30,30],[0, 0 ]], [[10,10],[20,20],[30,30],[40,40]]] A = Input(shape=[4,2]) # None * 4 * 2 mA = Masking()(A) out = MyMeanPool(axis=1)(mA) model = Model(inputs=[A], outputs=[out]) print model.summary() print model.predict(data)
結(jié)果如下,每一行對應(yīng)一個樣本的結(jié)果,例如第一個樣本只有第一個時刻有值,輸出結(jié)果是[10. 10. ],是正確的。
[[10. 10.] [15. 15.] [20. 20.] [25. 25.]]
在DeepFM中,每個樣本都是由ID構(gòu)成的,多值field往往會導(dǎo)致樣本長度不一的情況,例如interest這樣的field,同一個樣本可能在該field中有多項取值,畢竟每個人的興趣點不止一項。
采取padding的方法將每個field的特征補長到最長的長度,則數(shù)據(jù)尺寸是 [batch_size, max_timestep],經(jīng)過Embedding為每個樣本的每個特征ID配一個latent vector,數(shù)據(jù)尺寸將變?yōu)?[batch_size, max_timestep,latent_dim]。
我們希望每一個field的Embedding之后的尺寸為[batch_size, latent_dim],然后進行concat操作橫向拼接,所以這里就可以使用自定義的MeanPool層了。希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
PyCharm-錯誤-找不到指定文件python.exe的解決方法
今天小編就為大家分享一篇PyCharm-錯誤-找不到指定文件python.exe的解決方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07Python中try excpet BaseException(異常處理捕獲)的使用
本文主要介紹了Python中try excpet BaseException(異常處理捕獲)的使用,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-03-03macOS M1(Apple Silicon)安裝配置Conda環(huán)境的具體實現(xiàn)
由于常用的Anaconda和Miniconda現(xiàn)在都沒有提供M1處理器支持的conda環(huán)境,以下是conda-forge提供的miniforge,感興趣的可以了解一下2021-08-08python?PyQt5中QButtonGroup的詳細用法解析與應(yīng)用實戰(zhàn)記錄
在PyQt5中,QButtonGroup是一個用于管理按鈕互斥性和信號槽連接的類,它可以將多個按鈕劃分為一個組,管理按鈕的選中狀態(tài)和ID,本文詳細介紹了QButtonGroup的創(chuàng)建、使用方法和實際應(yīng)用案例,適合需要在PyQt5項目中高效管理按鈕組的開發(fā)者2024-10-10Python2和Python3之間的str處理方式導(dǎo)致亂碼的講解
今天小編就為大家分享一篇關(guān)于Python2和Python3之間的str處理方式導(dǎo)致亂碼的講解,小編覺得內(nèi)容挺不錯的,現(xiàn)在分享給大家,具有很好的參考價值,需要的朋友一起跟隨小編來看看吧2019-01-01Python協(xié)程操作之gevent(yield阻塞,greenlet),協(xié)程實現(xiàn)多任務(wù)(有規(guī)律的交替協(xié)作執(zhí)行)用法詳解
這篇文章主要介紹了Python協(xié)程操作之gevent(yield阻塞,greenlet),協(xié)程實現(xiàn)多任務(wù)(有規(guī)律的交替協(xié)作執(zhí)行)用法,結(jié)合實例形式較為詳細的分析了協(xié)程的功能、原理及gevent、greenlet實現(xiàn)協(xié)程,以及協(xié)程實現(xiàn)多任務(wù)相關(guān)操作技巧,需要的朋友可以參考下2019-10-10