Python+SimpleRNN實(shí)現(xiàn)股票預(yù)測(cè)詳解
原理請(qǐng)查看前面幾篇文章。
1、數(shù)據(jù)源
SH600519.csv 是用 tushare 模塊下載的 SH600519 貴州茅臺(tái)的日 k 線數(shù)據(jù),本次例子中只用它的 C 列數(shù)據(jù)(如圖 所示):
用連續(xù) 60 天的開(kāi)盤(pán)價(jià),預(yù)測(cè)第 61 天的開(kāi)盤(pán)價(jià)。
2、代碼實(shí)現(xiàn)
按照六步法: import 相關(guān)模塊->讀取貴州茅臺(tái)日 k 線數(shù)據(jù)到變量 maotai,把變量 maotai 中前 2126 天數(shù)據(jù)中的開(kāi)盤(pán)價(jià)作為訓(xùn)練數(shù)據(jù),把變量 maotai 中后 300 天數(shù)據(jù)中的開(kāi)盤(pán)價(jià)作為測(cè)試數(shù)據(jù);然后對(duì)開(kāi)盤(pán)價(jià)進(jìn)行歸一化,使送入神經(jīng)網(wǎng)絡(luò)的數(shù)據(jù)分布在 0 到 1 之間;
接下來(lái)建立空列表分別用于接收訓(xùn)練集輸入特征、訓(xùn)練集標(biāo)簽、測(cè)試集輸入特征、測(cè)試集標(biāo)簽;
繼續(xù)構(gòu)造數(shù)據(jù)。用 for 循環(huán)遍歷整個(gè)訓(xùn)練數(shù)據(jù),每連續(xù)60 天數(shù)據(jù)作為輸入特征 x_train,第 61 天數(shù)據(jù)作為對(duì)應(yīng)的標(biāo)簽 y_train ,一共生成 2066 組訓(xùn)練數(shù)據(jù),然后打亂訓(xùn)練數(shù)據(jù)的順序并轉(zhuǎn)變?yōu)?array 格式繼而轉(zhuǎn)變?yōu)?RNN 輸入要求的維度;
同理,利用 for 循環(huán)遍歷整個(gè)測(cè)試數(shù)據(jù),一共生成 240組測(cè)試數(shù)據(jù),測(cè)試集不需要打亂順序,但需轉(zhuǎn)變?yōu)?array 格式繼而轉(zhuǎn)變?yōu)?RNN 輸入要求的維度。
用 sequntial 搭建神經(jīng)網(wǎng)絡(luò):
第一層循環(huán)計(jì)算層記憶體設(shè)定 80 個(gè),每個(gè)時(shí)間步推送 h t h_t ht?給下一層,使用 0.2 的 Dropout;
第二層循環(huán)計(jì)算層設(shè)定記憶體有 100 個(gè),僅最后的時(shí)間步推送 h t h_t ht?給下一層,使用 0.2 的 Dropout;
由于輸出值是第 61 天的開(kāi)盤(pán)價(jià)只有一個(gè)數(shù),所以全連接 Dense 是 1->compile 配置訓(xùn)練方法使用 adam 優(yōu)化器,使用均方誤差損失函數(shù)。在股票預(yù)測(cè)代碼中,只需觀測(cè) loss,訓(xùn)練迭代打印的時(shí)候也只打印 loss,所以這里就無(wú)需給metrics賦值->設(shè)置斷點(diǎn)續(xù)訓(xùn),fit 執(zhí)行訓(xùn)練過(guò)程->summary 打印出網(wǎng)絡(luò)結(jié)構(gòu)和參數(shù)統(tǒng)計(jì)。
進(jìn)行 loss 可視化與參數(shù)報(bào)錯(cuò)操作
進(jìn)行股票預(yù)測(cè)。用 predict 預(yù)測(cè)測(cè)試集數(shù)據(jù),然后將預(yù)測(cè)值和真實(shí)值從歸一化的數(shù)值變換到真實(shí)數(shù)值,最后用紅色線畫(huà)出真實(shí)值曲線 、用藍(lán)色線畫(huà)出預(yù)測(cè)值曲線。
為了評(píng)價(jià)模型優(yōu)劣,給出了三個(gè)評(píng)判指標(biāo):均方誤差、均方根誤差和平均絕對(duì)誤差,這些誤差越小說(shuō)明預(yù)測(cè)的數(shù)值與真實(shí)值越接近。
RNN 股票預(yù)測(cè) loss 曲線:
RNN 股票預(yù)測(cè)曲線:
RNN 股票預(yù)測(cè)評(píng)價(jià)指標(biāo):
模型摘要:
3、完整代碼
import numpy as np import tensorflow as tf from tensorflow.keras.layers import Dropout, Dense, SimpleRNN import matplotlib.pyplot as plt import os import pandas as pd from sklearn.preprocessing import MinMaxScaler from sklearn.metrics import mean_squared_error, mean_absolute_error import math # 讀取股票文件 maotai = pd.read_csv('./SH600519.csv') # 前(2426-300=2126)天的開(kāi)盤(pán)價(jià)作為訓(xùn)練集,表格從0開(kāi)始計(jì)數(shù),2:3 是提取[2:3)列,前閉后開(kāi),故提取出C列開(kāi)盤(pán)價(jià) training_set = maotai.iloc[0:2426 - 300, 2:3].values # 后300天的開(kāi)盤(pán)價(jià)作為測(cè)試集 test_set = maotai.iloc[2426 - 300:, 2:3].values # 歸一化 sc = MinMaxScaler(feature_range=(0, 1)) # 定義歸一化:歸一化到(0,1)之間 training_set_scaled = sc.fit_transform(training_set) # 求得訓(xùn)練集的最大值,最小值這些訓(xùn)練集固有的屬性,并在訓(xùn)練集上進(jìn)行歸一化 test_set = sc.transform(test_set) # 利用訓(xùn)練集的屬性對(duì)測(cè)試集進(jìn)行歸一化 x_train = [] y_train = [] x_test = [] y_test = [] # 測(cè)試集:csv表格中前2426-300=2126天數(shù)據(jù) # 利用for循環(huán),遍歷整個(gè)訓(xùn)練集,提取訓(xùn)練集中連續(xù)60天的開(kāi)盤(pán)價(jià)作為輸入特征x_train,第61天的數(shù)據(jù)作為標(biāo)簽,for循環(huán)共構(gòu)建2426-300-60=2066組數(shù)據(jù)。 for i in range(60, len(training_set_scaled)): x_train.append(training_set_scaled[i - 60:i, 0]) y_train.append(training_set_scaled[i, 0]) # 對(duì)訓(xùn)練集進(jìn)行打亂 np.random.seed(7) np.random.shuffle(x_train) np.random.seed(7) np.random.shuffle(y_train) tf.random.set_seed(7) # 將訓(xùn)練集由list格式變?yōu)閍rray格式 x_train, y_train = np.array(x_train), np.array(y_train) # 使x_train符合RNN輸入要求:[送入樣本數(shù), 循環(huán)核時(shí)間展開(kāi)步數(shù), 每個(gè)時(shí)間步輸入特征個(gè)數(shù)]。 # 此處整個(gè)數(shù)據(jù)集送入,送入樣本數(shù)為x_train.shape[0]即2066組數(shù)據(jù);輸入60個(gè)開(kāi)盤(pán)價(jià),預(yù)測(cè)出第61天的開(kāi)盤(pán)價(jià),循環(huán)核時(shí)間展開(kāi)步數(shù)為60; 每個(gè)時(shí)間步送入的特征是某一天的開(kāi)盤(pán)價(jià),只有1個(gè)數(shù)據(jù),故每個(gè)時(shí)間步輸入特征個(gè)數(shù)為1 x_train = np.reshape(x_train, (x_train.shape[0], 60, 1)) # 測(cè)試集:csv表格中后300天數(shù)據(jù) # 利用for循環(huán),遍歷整個(gè)測(cè)試集,提取測(cè)試集中連續(xù)60天的開(kāi)盤(pán)價(jià)作為輸入特征x_test,第61天的數(shù)據(jù)作為標(biāo)簽y_test,for循環(huán)共構(gòu)建300-60=240組數(shù)據(jù)。 for i in range(60, len(test_set)): x_test.append(test_set[i - 60:i, 0]) y_test.append(test_set[i, 0]) # 測(cè)試集變array并reshape為符合RNN輸入要求:[送入樣本數(shù), 循環(huán)核時(shí)間展開(kāi)步數(shù), 每個(gè)時(shí)間步輸入特征個(gè)數(shù)] x_test, y_test = np.array(x_test), np.array(y_test) x_test = np.reshape(x_test, (x_test.shape[0], 60, 1)) model = tf.keras.Sequential([ SimpleRNN(80, return_sequences=True),# 第一層循環(huán)計(jì)算層:記憶體設(shè)定80個(gè),每個(gè)時(shí)間步推送ht給下一層 Dropout(0.2), #使用0.2的Dropout SimpleRNN(100),# 第二層循環(huán)計(jì)算層,設(shè)定記憶體100個(gè) Dropout(0.2), # Dense(1) # 由于輸出值是第61天的開(kāi)盤(pán)價(jià),只有一個(gè)數(shù),所以Dense是1 ]) model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mean_squared_error') # 損失函數(shù)用均方誤差 # 該應(yīng)用只觀測(cè)loss數(shù)值,不觀測(cè)準(zhǔn)確率,所以刪去metrics選項(xiàng),一會(huì)在每個(gè)epoch迭代顯示時(shí)只顯示loss值 checkpoint_save_path = "./checkpoint/rnn_stock.ckpt" if os.path.exists(checkpoint_save_path + '.index'): print('-------------load the model-----------------') model.load_weights(checkpoint_save_path) cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True, monitor='val_loss') history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) model.summary() file = open('./weights.txt', 'w') # 參數(shù)提取 for v in model.trainable_variables: file.write(str(v.name) + '\n') file.write(str(v.shape) + '\n') file.write(str(v.numpy()) + '\n') file.close() loss = history.history['loss'] val_loss = history.history['val_loss'] plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.title('Training and Validation Loss') plt.legend() plt.show() ################## predict ###################### # 測(cè)試集輸入模型進(jìn)行預(yù)測(cè) predicted_stock_price = model.predict(x_test) # 對(duì)預(yù)測(cè)數(shù)據(jù)還原---從(0,1)反歸一化到原始范圍 predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 對(duì)真實(shí)數(shù)據(jù)還原---從(0,1)反歸一化到原始范圍 real_stock_price = sc.inverse_transform(test_set[60:]) # 畫(huà)出真實(shí)數(shù)據(jù)和預(yù)測(cè)數(shù)據(jù)的對(duì)比曲線 plt.plot(real_stock_price, color='red', label='MaoTai Stock Price') plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price') plt.title('MaoTai Stock Price Prediction') plt.xlabel('Time') plt.ylabel('MaoTai Stock Price') plt.legend() plt.show() ##########evaluate############## # calculate MSE 均方誤差 ---> E[(預(yù)測(cè)值-真實(shí)值)^2] (預(yù)測(cè)值減真實(shí)值求平方后求均值) mse = mean_squared_error(predicted_stock_price, real_stock_price) # calculate RMSE 均方根誤差--->sqrt[MSE] (對(duì)均方誤差開(kāi)方) rmse = math.sqrt(mean_squared_error(predicted_stock_price, real_stock_price)) # calculate MAE 平均絕對(duì)誤差----->E[|預(yù)測(cè)值-真實(shí)值|](預(yù)測(cè)值減真實(shí)值求絕對(duì)值后求均值) mae = mean_absolute_error(predicted_stock_price, real_stock_price) print('均方誤差: %.6f' % mse) print('均方根誤差: %.6f' % rmse) print('平均絕對(duì)誤差: %.6f' % mae)
以上就是Python+SimpleRNN實(shí)現(xiàn)股票預(yù)測(cè)詳解的詳細(xì)內(nèi)容,更多關(guān)于Python SimpleRNN股票預(yù)測(cè)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
解析Mac OS下部署Pyhton的Django框架項(xiàng)目的過(guò)程
這篇文章主要介紹了Mac OS下部署Pyhton的Django框架項(xiàng)目的過(guò)程,還附帶將了一個(gè)gunicorn結(jié)合Nginx來(lái)部署Django應(yīng)用的方法,需要的朋友可以參考下2016-05-05Python爬蟲(chóng)番外篇之Cookie和Session詳解
這篇文章主要介紹了Python爬蟲(chóng)番外篇之Cookie和Session詳解,具有一定借鑒價(jià)值,需要的朋友可以參考下2017-12-12基于PyQt5制作Excel數(shù)據(jù)分組匯總器
這篇文章主要介紹了基于PyQt5制作的一個(gè)小工具:Excel數(shù)據(jù)分組匯總器。文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起試一試2022-01-01你知道怎么改進(jìn)Python 二分法和牛頓迭代法求算術(shù)平方根嗎
這篇文章主要介紹了Python編程實(shí)現(xiàn)二分法和牛頓迭代法求平方根代碼的改進(jìn),具有一定參考價(jià)值,需要的朋友可以了解下,希望能夠給你帶來(lái)幫助2021-08-08python 進(jìn)程間數(shù)據(jù)共享multiProcess.Manger實(shí)現(xiàn)解析
這篇文章主要介紹了python 進(jìn)程間數(shù)據(jù)共享multiProcess.Manger實(shí)現(xiàn)解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09python異常的傳遞知識(shí)點(diǎn)總結(jié)
在本篇文章里小編給大家整理的是一篇關(guān)于python異常的傳遞知識(shí)點(diǎn)總結(jié),有興趣的朋友們可以學(xué)習(xí)下。2021-06-06python實(shí)現(xiàn)的分層隨機(jī)抽樣案例
這篇文章主要介紹了python實(shí)現(xiàn)的分層隨機(jī)抽樣案例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02Python?Tkinter?Gui運(yùn)行不卡頓(解決多線程解決界面卡死問(wèn)題)
最近寫(xiě)的Python代碼不知為何,總是執(zhí)行到一半卡住不動(dòng),所以下面這篇文章主要給大家介紹了關(guān)于Python?Tkinter?Gui運(yùn)行不卡頓,解決多線程解決界面卡死問(wèn)題的相關(guān)資料,需要的朋友可以參考下2023-02-02pandas使用fillna函數(shù)填充N(xiāo)aN值的代碼實(shí)例
最近在工作中遇到一個(gè)問(wèn)題,pandas讀取的數(shù)據(jù)中nan在保存后變成空字符串,所以下面這篇文章主要給大家介紹了關(guān)于pandas使用fillna函數(shù)填充N(xiāo)aN值的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-07-07Python 余弦相似度與皮爾遜相關(guān)系數(shù) 計(jì)算實(shí)例
今天小編就為大家分享一篇Python 余弦相似度與皮爾遜相關(guān)系數(shù) 計(jì)算實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-12-12