Python+SimpleRNN實(shí)現(xiàn)股票預(yù)測詳解
原理請(qǐng)查看前面幾篇文章。
1、數(shù)據(jù)源
SH600519.csv 是用 tushare 模塊下載的 SH600519 貴州茅臺(tái)的日 k 線數(shù)據(jù),本次例子中只用它的 C 列數(shù)據(jù)(如圖 所示):
用連續(xù) 60 天的開盤價(jià),預(yù)測第 61 天的開盤價(jià)。
2、代碼實(shí)現(xiàn)
按照六步法: import 相關(guān)模塊->讀取貴州茅臺(tái)日 k 線數(shù)據(jù)到變量 maotai,把變量 maotai 中前 2126 天數(shù)據(jù)中的開盤價(jià)作為訓(xùn)練數(shù)據(jù),把變量 maotai 中后 300 天數(shù)據(jù)中的開盤價(jià)作為測試數(shù)據(jù);然后對(duì)開盤價(jià)進(jìn)行歸一化,使送入神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)分布在 0 到 1 之間;
接下來建立空列表分別用于接收訓(xùn)練集輸入特征、訓(xùn)練集標(biāo)簽、測試集輸入特征、測試集標(biāo)簽;
繼續(xù)構(gòu)造數(shù)據(jù)。用 for 循環(huán)遍歷整個(gè)訓(xùn)練數(shù)據(jù),每連續(xù)60 天數(shù)據(jù)作為輸入特征 x_train,第 61 天數(shù)據(jù)作為對(duì)應(yīng)的標(biāo)簽 y_train ,一共生成 2066 組訓(xùn)練數(shù)據(jù),然后打亂訓(xùn)練數(shù)據(jù)的順序并轉(zhuǎn)變?yōu)?array 格式繼而轉(zhuǎn)變?yōu)?RNN 輸入要求的維度;
同理,利用 for 循環(huán)遍歷整個(gè)測試數(shù)據(jù),一共生成 240組測試數(shù)據(jù),測試集不需要打亂順序,但需轉(zhuǎn)變?yōu)?array 格式繼而轉(zhuǎn)變?yōu)?RNN 輸入要求的維度。
用 sequntial 搭建神經(jīng)網(wǎng)絡(luò):
第一層循環(huán)計(jì)算層記憶體設(shè)定 80 個(gè),每個(gè)時(shí)間步推送 h t h_t ht?給下一層,使用 0.2 的 Dropout;
第二層循環(huán)計(jì)算層設(shè)定記憶體有 100 個(gè),僅最后的時(shí)間步推送 h t h_t ht?給下一層,使用 0.2 的 Dropout;
由于輸出值是第 61 天的開盤價(jià)只有一個(gè)數(shù),所以全連接 Dense 是 1->compile 配置訓(xùn)練方法使用 adam 優(yōu)化器,使用均方誤差損失函數(shù)。在股票預(yù)測代碼中,只需觀測 loss,訓(xùn)練迭代打印的時(shí)候也只打印 loss,所以這里就無需給metrics賦值->設(shè)置斷點(diǎn)續(xù)訓(xùn),fit 執(zhí)行訓(xùn)練過程->summary 打印出網(wǎng)絡(luò)結(jié)構(gòu)和參數(shù)統(tǒng)計(jì)。
進(jìn)行 loss 可視化與參數(shù)報(bào)錯(cuò)操作
進(jìn)行股票預(yù)測。用 predict 預(yù)測測試集數(shù)據(jù),然后將預(yù)測值和真實(shí)值從歸一化的數(shù)值變換到真實(shí)數(shù)值,最后用紅色線畫出真實(shí)值曲線 、用藍(lán)色線畫出預(yù)測值曲線。
為了評(píng)價(jià)模型優(yōu)劣,給出了三個(gè)評(píng)判指標(biāo):均方誤差、均方根誤差和平均絕對(duì)誤差,這些誤差越小說明預(yù)測的數(shù)值與真實(shí)值越接近。
RNN 股票預(yù)測 loss 曲線:
RNN 股票預(yù)測曲線:
RNN 股票預(yù)測評(píng)價(jià)指標(biāo):
模型摘要:
3、完整代碼
import numpy as np import tensorflow as tf from tensorflow.keras.layers import Dropout, Dense, SimpleRNN import matplotlib.pyplot as plt import os import pandas as pd from sklearn.preprocessing import MinMaxScaler from sklearn.metrics import mean_squared_error, mean_absolute_error import math # 讀取股票文件 maotai = pd.read_csv('./SH600519.csv') # 前(2426-300=2126)天的開盤價(jià)作為訓(xùn)練集,表格從0開始計(jì)數(shù),2:3 是提取[2:3)列,前閉后開,故提取出C列開盤價(jià) training_set = maotai.iloc[0:2426 - 300, 2:3].values # 后300天的開盤價(jià)作為測試集 test_set = maotai.iloc[2426 - 300:, 2:3].values # 歸一化 sc = MinMaxScaler(feature_range=(0, 1)) # 定義歸一化:歸一化到(0,1)之間 training_set_scaled = sc.fit_transform(training_set) # 求得訓(xùn)練集的最大值,最小值這些訓(xùn)練集固有的屬性,并在訓(xùn)練集上進(jìn)行歸一化 test_set = sc.transform(test_set) # 利用訓(xùn)練集的屬性對(duì)測試集進(jìn)行歸一化 x_train = [] y_train = [] x_test = [] y_test = [] # 測試集:csv表格中前2426-300=2126天數(shù)據(jù) # 利用for循環(huán),遍歷整個(gè)訓(xùn)練集,提取訓(xùn)練集中連續(xù)60天的開盤價(jià)作為輸入特征x_train,第61天的數(shù)據(jù)作為標(biāo)簽,for循環(huán)共構(gòu)建2426-300-60=2066組數(shù)據(jù)。 for i in range(60, len(training_set_scaled)): x_train.append(training_set_scaled[i - 60:i, 0]) y_train.append(training_set_scaled[i, 0]) # 對(duì)訓(xùn)練集進(jìn)行打亂 np.random.seed(7) np.random.shuffle(x_train) np.random.seed(7) np.random.shuffle(y_train) tf.random.set_seed(7) # 將訓(xùn)練集由list格式變?yōu)閍rray格式 x_train, y_train = np.array(x_train), np.array(y_train) # 使x_train符合RNN輸入要求:[送入樣本數(shù), 循環(huán)核時(shí)間展開步數(shù), 每個(gè)時(shí)間步輸入特征個(gè)數(shù)]。 # 此處整個(gè)數(shù)據(jù)集送入,送入樣本數(shù)為x_train.shape[0]即2066組數(shù)據(jù);輸入60個(gè)開盤價(jià),預(yù)測出第61天的開盤價(jià),循環(huán)核時(shí)間展開步數(shù)為60; 每個(gè)時(shí)間步送入的特征是某一天的開盤價(jià),只有1個(gè)數(shù)據(jù),故每個(gè)時(shí)間步輸入特征個(gè)數(shù)為1 x_train = np.reshape(x_train, (x_train.shape[0], 60, 1)) # 測試集:csv表格中后300天數(shù)據(jù) # 利用for循環(huán),遍歷整個(gè)測試集,提取測試集中連續(xù)60天的開盤價(jià)作為輸入特征x_test,第61天的數(shù)據(jù)作為標(biāo)簽y_test,for循環(huán)共構(gòu)建300-60=240組數(shù)據(jù)。 for i in range(60, len(test_set)): x_test.append(test_set[i - 60:i, 0]) y_test.append(test_set[i, 0]) # 測試集變array并reshape為符合RNN輸入要求:[送入樣本數(shù), 循環(huán)核時(shí)間展開步數(shù), 每個(gè)時(shí)間步輸入特征個(gè)數(shù)] x_test, y_test = np.array(x_test), np.array(y_test) x_test = np.reshape(x_test, (x_test.shape[0], 60, 1)) model = tf.keras.Sequential([ SimpleRNN(80, return_sequences=True),# 第一層循環(huán)計(jì)算層:記憶體設(shè)定80個(gè),每個(gè)時(shí)間步推送ht給下一層 Dropout(0.2), #使用0.2的Dropout SimpleRNN(100),# 第二層循環(huán)計(jì)算層,設(shè)定記憶體100個(gè) Dropout(0.2), # Dense(1) # 由于輸出值是第61天的開盤價(jià),只有一個(gè)數(shù),所以Dense是1 ]) model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mean_squared_error') # 損失函數(shù)用均方誤差 # 該應(yīng)用只觀測loss數(shù)值,不觀測準(zhǔn)確率,所以刪去metrics選項(xiàng),一會(huì)在每個(gè)epoch迭代顯示時(shí)只顯示loss值 checkpoint_save_path = "./checkpoint/rnn_stock.ckpt" if os.path.exists(checkpoint_save_path + '.index'): print('-------------load the model-----------------') model.load_weights(checkpoint_save_path) cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True, monitor='val_loss') history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) model.summary() file = open('./weights.txt', 'w') # 參數(shù)提取 for v in model.trainable_variables: file.write(str(v.name) + '\n') file.write(str(v.shape) + '\n') file.write(str(v.numpy()) + '\n') file.close() loss = history.history['loss'] val_loss = history.history['val_loss'] plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.title('Training and Validation Loss') plt.legend() plt.show() ################## predict ###################### # 測試集輸入模型進(jìn)行預(yù)測 predicted_stock_price = model.predict(x_test) # 對(duì)預(yù)測數(shù)據(jù)還原---從(0,1)反歸一化到原始范圍 predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 對(duì)真實(shí)數(shù)據(jù)還原---從(0,1)反歸一化到原始范圍 real_stock_price = sc.inverse_transform(test_set[60:]) # 畫出真實(shí)數(shù)據(jù)和預(yù)測數(shù)據(jù)的對(duì)比曲線 plt.plot(real_stock_price, color='red', label='MaoTai Stock Price') plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price') plt.title('MaoTai Stock Price Prediction') plt.xlabel('Time') plt.ylabel('MaoTai Stock Price') plt.legend() plt.show() ##########evaluate############## # calculate MSE 均方誤差 ---> E[(預(yù)測值-真實(shí)值)^2] (預(yù)測值減真實(shí)值求平方后求均值) mse = mean_squared_error(predicted_stock_price, real_stock_price) # calculate RMSE 均方根誤差--->sqrt[MSE] (對(duì)均方誤差開方) rmse = math.sqrt(mean_squared_error(predicted_stock_price, real_stock_price)) # calculate MAE 平均絕對(duì)誤差----->E[|預(yù)測值-真實(shí)值|](預(yù)測值減真實(shí)值求絕對(duì)值后求均值) mae = mean_absolute_error(predicted_stock_price, real_stock_price) print('均方誤差: %.6f' % mse) print('均方根誤差: %.6f' % rmse) print('平均絕對(duì)誤差: %.6f' % mae)
以上就是Python+SimpleRNN實(shí)現(xiàn)股票預(yù)測詳解的詳細(xì)內(nèi)容,更多關(guān)于Python SimpleRNN股票預(yù)測的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
解析Mac OS下部署Pyhton的Django框架項(xiàng)目的過程
這篇文章主要介紹了Mac OS下部署Pyhton的Django框架項(xiàng)目的過程,還附帶將了一個(gè)gunicorn結(jié)合Nginx來部署Django應(yīng)用的方法,需要的朋友可以參考下2016-05-05基于PyQt5制作Excel數(shù)據(jù)分組匯總器
這篇文章主要介紹了基于PyQt5制作的一個(gè)小工具:Excel數(shù)據(jù)分組匯總器。文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起試一試2022-01-01你知道怎么改進(jìn)Python 二分法和牛頓迭代法求算術(shù)平方根嗎
這篇文章主要介紹了Python編程實(shí)現(xiàn)二分法和牛頓迭代法求平方根代碼的改進(jìn),具有一定參考價(jià)值,需要的朋友可以了解下,希望能夠給你帶來幫助2021-08-08python 進(jìn)程間數(shù)據(jù)共享multiProcess.Manger實(shí)現(xiàn)解析
這篇文章主要介紹了python 進(jìn)程間數(shù)據(jù)共享multiProcess.Manger實(shí)現(xiàn)解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09python異常的傳遞知識(shí)點(diǎn)總結(jié)
在本篇文章里小編給大家整理的是一篇關(guān)于python異常的傳遞知識(shí)點(diǎn)總結(jié),有興趣的朋友們可以學(xué)習(xí)下。2021-06-06python實(shí)現(xiàn)的分層隨機(jī)抽樣案例
這篇文章主要介紹了python實(shí)現(xiàn)的分層隨機(jī)抽樣案例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-02-02Python?Tkinter?Gui運(yùn)行不卡頓(解決多線程解決界面卡死問題)
最近寫的Python代碼不知為何,總是執(zhí)行到一半卡住不動(dòng),所以下面這篇文章主要給大家介紹了關(guān)于Python?Tkinter?Gui運(yùn)行不卡頓,解決多線程解決界面卡死問題的相關(guān)資料,需要的朋友可以參考下2023-02-02pandas使用fillna函數(shù)填充NaN值的代碼實(shí)例
最近在工作中遇到一個(gè)問題,pandas讀取的數(shù)據(jù)中nan在保存后變成空字符串,所以下面這篇文章主要給大家介紹了關(guān)于pandas使用fillna函數(shù)填充NaN值的相關(guān)資料,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-07-07Python 余弦相似度與皮爾遜相關(guān)系數(shù) 計(jì)算實(shí)例
今天小編就為大家分享一篇Python 余弦相似度與皮爾遜相關(guān)系數(shù) 計(jì)算實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-12-12