欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

python神經(jīng)網(wǎng)絡(luò)Keras實(shí)現(xiàn)GRU及其參數(shù)量

 更新時(shí)間:2022年05月07日 10:34:27   作者:Bubbliiiing  
這篇文章主要為大家介紹了python神經(jīng)網(wǎng)絡(luò)Keras實(shí)現(xiàn)GRU及其參數(shù)量,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

什么是GRU

GRU是LSTM的一個(gè)變種。

傳承了LSTM的門結(jié)構(gòu),但是將LSTM的三個(gè)門轉(zhuǎn)化成兩個(gè)門,分別是更新門和重置門。

1、GRU單元的輸入與輸出

下圖是每個(gè)GRU單元的結(jié)構(gòu)。

在n時(shí)刻,每個(gè)GRU單元的輸入有兩個(gè):

  • 當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入值Xt;
  • 上一時(shí)刻GRU的輸出值ht-1;

輸出有一個(gè):

當(dāng)前時(shí)刻GRU輸出值ht;

2、GRU的門結(jié)構(gòu)

GRU含有兩個(gè)門結(jié)構(gòu),分別是:

更新門zt和重置門rt:

更新門用于控制前一時(shí)刻的狀態(tài)信息被代入到當(dāng)前狀態(tài)的程度,更新門的值越大說(shuō)明前一時(shí)刻的狀態(tài)信息帶入越少,這一時(shí)刻的狀態(tài)信息帶入越多。

重置門用于控制忽略前一時(shí)刻的狀態(tài)信息的程度,重置門的值越小說(shuō)明忽略得越多。

3、GRU的參數(shù)量計(jì)算

a、更新門

更新門在圖中的標(biāo)號(hào)為zt,需要結(jié)合ht-1和Xt來(lái)決定上一時(shí)刻的輸出ht-1有多少得到保留,更新門的值越大說(shuō)明前一時(shí)刻的狀態(tài)信息保留越少,這一時(shí)刻的狀態(tài)信息保留越多。

結(jié)合公式我們可以知道:

zt由ht-1和Xt來(lái)決定。

當(dāng)更新門zt的值較大的時(shí)候,上一時(shí)刻的輸出ht-1保留較少,而這一時(shí)刻的狀態(tài)信息保留較多。

b、重置門

重置門在圖中的標(biāo)號(hào)為rt,需要結(jié)合ht-1和Xt來(lái)控制忽略前一時(shí)刻的狀態(tài)信息的程度,重置門的值越小說(shuō)明忽略得越多。

結(jié)合公式我們可以知道:

rt由ht-1和Xt來(lái)決定。

當(dāng)重置門rt的值較小的時(shí)候,上一時(shí)刻的輸出ht-1保留較少,說(shuō)明忽略得越多。

c、全部參數(shù)量

所以所有的門總參數(shù)量為:

在Keras中實(shí)現(xiàn)GRU

GRU一般需要輸入兩個(gè)參數(shù)。

一個(gè)是unit、一個(gè)是input_shape。

LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))

unit用于指定神經(jīng)元的數(shù)量。

input_shape用于指定輸入的shape,分別指定TIME_STEPS和INPUT_SIZE。

實(shí)現(xiàn)代碼

import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
    X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
    Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
    index_start += BATCH_SIZE
    cost = model.train_on_batch(X_batch,Y_batch)
    if index_start >= X_train.shape[0]:
        index_start = 0
    if i%100 == 0:
        cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
        print("accuracy:",accuracy)

實(shí)現(xiàn)效果:

10000/10000 [==============================] - 2s 231us/step
accuracy: 0.16749999986961484
10000/10000 [==============================] - 2s 206us/step
accuracy: 0.6134000015258789
10000/10000 [==============================] - 2s 214us/step
accuracy: 0.7058000019192696
10000/10000 [==============================] - 2s 209us/step
accuracy: 0.797899999320507

以上就是python神經(jīng)網(wǎng)絡(luò)Keras實(shí)現(xiàn)GRU及其參數(shù)量的詳細(xì)內(nèi)容,更多關(guān)于Keras實(shí)現(xiàn)GRU參數(shù)量的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

最新評(píng)論