python神經(jīng)網(wǎng)絡(luò)Keras實(shí)現(xiàn)GRU及其參數(shù)量
什么是GRU
GRU是LSTM的一個(gè)變種。
傳承了LSTM的門結(jié)構(gòu),但是將LSTM的三個(gè)門轉(zhuǎn)化成兩個(gè)門,分別是更新門和重置門。
1、GRU單元的輸入與輸出
下圖是每個(gè)GRU單元的結(jié)構(gòu)。
在n時(shí)刻,每個(gè)GRU單元的輸入有兩個(gè):
- 當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入值Xt;
- 上一時(shí)刻GRU的輸出值ht-1;
輸出有一個(gè):
當(dāng)前時(shí)刻GRU輸出值ht;
2、GRU的門結(jié)構(gòu)
GRU含有兩個(gè)門結(jié)構(gòu),分別是:
更新門zt和重置門rt:
更新門用于控制前一時(shí)刻的狀態(tài)信息被代入到當(dāng)前狀態(tài)的程度,更新門的值越大說(shuō)明前一時(shí)刻的狀態(tài)信息帶入越少,這一時(shí)刻的狀態(tài)信息帶入越多。
重置門用于控制忽略前一時(shí)刻的狀態(tài)信息的程度,重置門的值越小說(shuō)明忽略得越多。
3、GRU的參數(shù)量計(jì)算
a、更新門
更新門在圖中的標(biāo)號(hào)為zt,需要結(jié)合ht-1和Xt來(lái)決定上一時(shí)刻的輸出ht-1有多少得到保留,更新門的值越大說(shuō)明前一時(shí)刻的狀態(tài)信息保留越少,這一時(shí)刻的狀態(tài)信息保留越多。
結(jié)合公式我們可以知道:
zt由ht-1和Xt來(lái)決定。
當(dāng)更新門zt的值較大的時(shí)候,上一時(shí)刻的輸出ht-1保留較少,而這一時(shí)刻的狀態(tài)信息保留較多。
b、重置門
重置門在圖中的標(biāo)號(hào)為rt,需要結(jié)合ht-1和Xt來(lái)控制忽略前一時(shí)刻的狀態(tài)信息的程度,重置門的值越小說(shuō)明忽略得越多。
結(jié)合公式我們可以知道:
rt由ht-1和Xt來(lái)決定。
當(dāng)重置門rt的值較小的時(shí)候,上一時(shí)刻的輸出ht-1保留較少,說(shuō)明忽略得越多。
c、全部參數(shù)量
所以所有的門總參數(shù)量為:
在Keras中實(shí)現(xiàn)GRU
GRU一般需要輸入兩個(gè)參數(shù)。
一個(gè)是unit、一個(gè)是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神經(jīng)元的數(shù)量。
input_shape用于指定輸入的shape,分別指定TIME_STEPS和INPUT_SIZE。
實(shí)現(xiàn)代碼
import numpy as np from keras.models import Sequential from keras.layers import Input,Activation,Dense from keras.models import Model from keras.datasets import mnist from keras.layers.recurrent import GRU from keras.utils import np_utils from keras.optimizers import Adam TIME_STEPS = 28 INPUT_SIZE = 28 BATCH_SIZE = 50 index_start = 0 OUTPUT_SIZE = 10 CELL_SIZE = 75 LR = 1e-3 (X_train,Y_train),(X_test,Y_test) = mnist.load_data() X_train = X_train.reshape(-1,28,28)/255 X_test = X_test.reshape(-1,28,28)/255 Y_train = np_utils.to_categorical(Y_train,num_classes= 10) Y_test = np_utils.to_categorical(Y_test,num_classes= 10) inputs = Input(shape=[TIME_STEPS,INPUT_SIZE]) x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs) x = Dense(OUTPUT_SIZE)(x) x = Activation("softmax")(x) model = Model(inputs,x) adam = Adam(LR) model.summary() model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy']) for i in range(50000): X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:] Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:] index_start += BATCH_SIZE cost = model.train_on_batch(X_batch,Y_batch) if index_start >= X_train.shape[0]: index_start = 0 if i%100 == 0: cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50) print("accuracy:",accuracy)
實(shí)現(xiàn)效果:
10000/10000 [==============================] - 2s 231us/step accuracy: 0.16749999986961484 10000/10000 [==============================] - 2s 206us/step accuracy: 0.6134000015258789 10000/10000 [==============================] - 2s 214us/step accuracy: 0.7058000019192696 10000/10000 [==============================] - 2s 209us/step accuracy: 0.797899999320507
以上就是python神經(jīng)網(wǎng)絡(luò)Keras實(shí)現(xiàn)GRU及其參數(shù)量的詳細(xì)內(nèi)容,更多關(guān)于Keras實(shí)現(xiàn)GRU參數(shù)量的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python3+PyQt5圖形項(xiàng)的自定義和交互 python3實(shí)現(xiàn)page Designer應(yīng)用程序
這篇文章主要為大家詳細(xì)介紹了python3+PyQt5圖形項(xiàng)的自定義和交互,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-04-04python 爬取華為應(yīng)用市場(chǎng)評(píng)論
項(xiàng)目需要爬取評(píng)論數(shù)據(jù),在此做一個(gè)記錄,這里爬取的是web端的數(shù)據(jù),以后可能會(huì)考慮爬取android app中的數(shù)據(jù)。2021-05-05Python數(shù)據(jù)分析之雙色球統(tǒng)計(jì)兩個(gè)紅和藍(lán)球哪組合比例高的方法
這篇文章主要介紹了Python數(shù)據(jù)分析之雙色球統(tǒng)計(jì)兩個(gè)紅和藍(lán)球哪組合比例高的方法,涉及Python數(shù)值運(yùn)算及圖形繪制相關(guān)操作技巧,需要的朋友可以參考下2018-02-02Python3 socket即時(shí)通訊腳本實(shí)現(xiàn)代碼實(shí)例(threading多線程)
這篇文章主要介紹了Python3 socket即時(shí)通訊腳本實(shí)現(xiàn)代碼實(shí)例(threading多線程),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06Python運(yùn)維之獲取系統(tǒng)CPU信息的實(shí)現(xiàn)方法
今天小編就為大家分享一篇Python運(yùn)維之獲取系統(tǒng)CPU信息的實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-06-06python 循環(huán)讀取txt文檔 并轉(zhuǎn)換成csv的方法
今天小編就為大家分享一篇python 循環(huán)讀取txt文檔 并轉(zhuǎn)換成csv的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-10-10深度學(xué)習(xí)小工程練習(xí)之tensorflow垃圾分類詳解
這篇文章主要介紹了練習(xí)深度學(xué)習(xí)的一個(gè)小工程,代碼簡(jiǎn)單明確,用來(lái)作為學(xué)習(xí)深度學(xué)習(xí)的練習(xí)很適合,對(duì)于有需要的朋友可以參考下,希望大家可以體驗(yàn)到深度學(xué)習(xí)帶來(lái)的收獲2021-04-04Python sklearn KFold 生成交叉驗(yàn)證數(shù)據(jù)集的方法
今天小編就為大家分享一篇Python sklearn KFold 生成交叉驗(yàn)證數(shù)據(jù)集的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-12-12