Python實現(xiàn)雙向RNN與堆疊的雙向RNN的示例代碼
1、雙向RNN
雙向RNN(Bidirectional RNN)的結(jié)構(gòu)如下圖所示。
雙向的 RNN 是同時考慮“過去”和“未來”的信息。上圖是一個序列長度為 4 的雙向RNN 結(jié)構(gòu)。
雙向RNN就像是我們做閱讀理解的時候從頭向后讀一遍文章,然后又從后往前讀一遍文章,然后再做題。有可能從后往前再讀一遍文章的時候會有新的不一樣的理解,最后模型可能會得到更好的結(jié)果。
2、堆疊的雙向RNN
堆疊的雙向RNN(Stacked Bidirectional RNN)的結(jié)構(gòu)如上圖所示。上圖是一個堆疊了3個隱藏層的RNN網(wǎng)絡(luò)。
注意,這里的堆疊的雙向RNN并不是只有雙向的RNN才可以堆疊,其實任意的RNN都可以堆疊,如SimpleRNN、LSTM和GRU這些循環(huán)神經(jīng)網(wǎng)絡(luò)也可以進(jìn)行堆疊。
堆疊指的是在RNN的結(jié)構(gòu)中疊加多層,類似于BP神經(jīng)網(wǎng)絡(luò)中可以疊加多層,增加網(wǎng)絡(luò)的非線性。
3、雙向LSTM實現(xiàn)MNIST數(shù)據(jù)集分類
import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense from tensorflow.keras.layers import LSTM,Dropout,Bidirectional from tensorflow.keras.optimizers import Adam import matplotlib.pyplot as plt # 載入數(shù)據(jù)集 mnist = tf.keras.datasets.mnist # 載入數(shù)據(jù),數(shù)據(jù)載入的時候就已經(jīng)劃分好訓(xùn)練集和測試集 # 訓(xùn)練集數(shù)據(jù)x_train的數(shù)據(jù)形狀為(60000,28,28) # 訓(xùn)練集標(biāo)簽y_train的數(shù)據(jù)形狀為(60000) # 測試集數(shù)據(jù)x_test的數(shù)據(jù)形狀為(10000,28,28) # 測試集標(biāo)簽y_test的數(shù)據(jù)形狀為(10000) (x_train, y_train), (x_test, y_test) = mnist.load_data() # 對訓(xùn)練集和測試集的數(shù)據(jù)進(jìn)行歸一化處理,有助于提升模型訓(xùn)練速度 x_train, x_test = x_train / 255.0, x_test / 255.0 # 把訓(xùn)練集和測試集的標(biāo)簽轉(zhuǎn)為獨熱編碼 y_train = tf.keras.utils.to_categorical(y_train,num_classes=10) y_test = tf.keras.utils.to_categorical(y_test,num_classes=10) # 數(shù)據(jù)大小-一行有28個像素 input_size = 28 # 序列長度-一共有28行 time_steps = 28 # 隱藏層memory block個數(shù) cell_size = 50 # 創(chuàng)建模型 # 循環(huán)神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)輸入必須是3維數(shù)據(jù) # 數(shù)據(jù)格式為(數(shù)據(jù)數(shù)量,序列長度,數(shù)據(jù)大小) # 載入的mnist數(shù)據(jù)的格式剛好符合要求 # 注意這里的input_shape設(shè)置模型數(shù)據(jù)輸入時不需要設(shè)置數(shù)據(jù)的數(shù)量 model = Sequential([ Bidirectional(LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True)), Dropout(0.2), Bidirectional(LSTM(cell_size)), Dropout(0.2), # 50個memory block輸出的50個值跟輸出層10個神經(jīng)元全連接 Dense(10,activation=tf.keras.activations.softmax) ]) # 循環(huán)神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)輸入必須是3維數(shù)據(jù) # 數(shù)據(jù)格式為(數(shù)據(jù)數(shù)量,序列長度,數(shù)據(jù)大小) # 載入的mnist數(shù)據(jù)的格式剛好符合要求 # 注意這里的input_shape設(shè)置模型數(shù)據(jù)輸入時不需要設(shè)置數(shù)據(jù)的數(shù)量 # model.add(LSTM( # units = cell_size, # input_shape = (time_steps,input_size), # )) # 50個memory block輸出的50個值跟輸出層10個神經(jīng)元全連接 # model.add(Dense(10,activation='softmax')) # 定義優(yōu)化器 adam = Adam(lr=1e-3) # 定義優(yōu)化器,loss function,訓(xùn)練過程中計算準(zhǔn)確率 使用交叉熵?fù)p失函數(shù) model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy']) # 訓(xùn)練模型 history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test)) #打印模型摘要 model.summary() loss=history.history['loss'] val_loss=history.history['val_loss'] accuracy=history.history['accuracy'] val_accuracy=history.history['val_accuracy'] # 繪制loss曲線 plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.title('Training and Validation Loss') plt.legend() plt.show() # 繪制acc曲線 plt.plot(accuracy, label='Training accuracy') plt.plot(val_accuracy, label='Validation accuracy') plt.title('Training and Validation Loss') plt.legend() plt.show()
這個可能對文本數(shù)據(jù)比較容易處理,這里用這個模型有點勉強(qiáng),只是簡單測試下。
模型摘要:
acc曲線:
loss曲線:
到此這篇關(guān)于Python實現(xiàn)雙向RNN與堆疊的雙向RNN的示例代碼的文章就介紹到這了,更多相關(guān)Python 雙向RNN內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python如何讀寫二進(jìn)制數(shù)組數(shù)據(jù)
這篇文章主要介紹了Python如何讀寫二進(jìn)制數(shù)組數(shù)據(jù),文中講解非常細(xì)致,代碼幫助大家更好的理解和學(xué)習(xí),感興趣的朋友可以了解下2020-08-08解決pycharm無法刪除invalid interpreter(無效解析器)的問題
這篇文章主要介紹了pycharm無法刪除invalid interpreter(無效解析器)的問題,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-07-07解決pycharm上的jupyter notebook端口被占用問題
今天小編就為大家分享一篇解決pycharm上的jupyter notebook端口被占用問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-12-12pycharm中使用pyplot時報錯MatplotlibDeprecationWarning
最近在使用Pycharm中matplotlib作圖處理時報錯,所以這篇文章主要給大家介紹了關(guān)于pycharm中使用pyplot時報錯MatplotlibDeprecationWarning的相關(guān)資料,需要的朋友可以參考下2023-12-12Python實現(xiàn)將照片變成卡通圖片的方法【基于opencv】
這篇文章主要介紹了Python實現(xiàn)將照片變成卡通圖片的方法,涉及Python基于opencv庫進(jìn)行圖片處理的相關(guān)操作技巧,需要的朋友可以參考下2018-01-01