[关闭]
@w568w 2020-08-04T07:24:43.000000Z 字数 2671 阅读 1383

笔记:基于LSTM(Long Short-Term Memory)的股票预测

网络结构

采用LSTM*3 + Full-connected为网络。

Created with Raphaël 2.1.2Input shape [16,4]LSTM [16,4]->128LSTM [16,4]->128LSTM [16,4]->128Dense [128]->1Output shape [1]

每次输入16天数据,输出接下来1天的收盘价格。

训练时使用Keras作为框架,数据均已归一化

训练过程

600001(浦发银行)近14年数据为训练集和验证集

使用Geforce GTX 1650显卡、CUDA 10.1环境,在上轮流训练,总用时约 20 分钟。

训练结果

经过若干轮训练,训练集上的Loss稳定在左右,测试集上的Loss稳定在左右,符合精度要求。

在随机抽取的两支股票中获取近14年数据,抽样验证,结果如下:

Figure 1: 000998(隆平高科)

Figure 1:000998(隆平高科)

此处输入图片的描述

Figure 2: 600435(北方导航)

代码

  1. import os
  2. from typing import List, Any
  3. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
  4. from numpy.core._multiarray_umath import ndarray
  5. from tensorflow.keras.layers import Dense
  6. import csv
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. from tensorflow.python.keras import Sequential
  10. from tensorflow.python.keras.layers import LSTM, Dropout
  11. def train(m: Sequential):
  12. for i in range(0, 5):
  13. m.fit(Train_X, Train_Y, batch_size=BATCH_SIZE, epochs=EPOCHES,
  14. validation_split=1 / 3)
  15. m.fit(Train_X, Train_Y, batch_size=BATCH_SIZE, epochs=EPOCHES)
  16. m.save_weights("lstm.mw")
  17. def validate(m: Sequential, num=50):
  18. plt.title("Validate Result")
  19. plt.plot(range(1, num + 1), m.predict(Train_X[0:num]).flatten(order='C'), 'blue',
  20. label='Prediction')
  21. plt.plot(range(1, num + 1), Train_Y[0:num].flatten(order='C'), 'red', label='Real Value')
  22. plt.legend()
  23. plt.show()
  24. def normalization(arr: ndarray):
  25. return (arr - arr.min()) / (arr.max() - arr.min())
  26. BATCH_SIZE = 5
  27. TIME_STEP = 16
  28. DATA_FRAME_NUM = 4
  29. OUTPUT_NUM = 1
  30. EPOCHES = 5
  31. Train_X: ndarray = np.empty([1, 3, 2], dtype=float)
  32. Train_Y = np.empty([], dtype=float)
  33. with open('history_A_stock_k_test_data.csv', 'r') as f:
  34. data_list: List[Any] = list(csv.reader(f))
  35. Train_X = np.empty([0, TIME_STEP, DATA_FRAME_NUM], dtype=float)
  36. Train_Y = np.empty([0, OUTPUT_NUM], dtype=float)
  37. time_step_array: ndarray = np.empty([1, TIME_STEP, DATA_FRAME_NUM], dtype=float)
  38. for i in range(1, len(data_list)):
  39. time_step_array[0][i % TIME_STEP - 1] = data_list[i][1:5]
  40. if i < TIME_STEP:
  41. continue
  42. if i % TIME_STEP == 0:
  43. Train_X = np.append(Train_X, time_step_array, axis=0)
  44. elif i % TIME_STEP == 1:
  45. Train_Y = np.append(Train_Y,
  46. np.expand_dims(np.asarray(data_list[i][4:5], dtype=float), axis=0),
  47. axis=0)
  48. Train_X = normalization(Train_X)
  49. Train_Y = normalization(Train_Y)
  50. model: Sequential = Sequential()
  51. model.add(LSTM(128, return_sequences=True,
  52. input_shape=(TIME_STEP, DATA_FRAME_NUM)))
  53. model.add(LSTM(128, return_sequences=True))
  54. model.add(LSTM(128))
  55. model.add(Dropout(0.2))
  56. model.add(Dense(OUTPUT_NUM, activation="linear"))
  57. model.compile(loss='mse', optimizer='rmsprop')
  58. model.load_weights("lstm.mw")
  59. # 训练他妈的股票预测模型
  60. train(model)

反思

经过又一次将近 30 分钟的训练,达到了如下精度:

此处输入图片的描述

Figure 3: 300059(东方财富)


事实证明,在一定误差范围、一定时期内,股票的走向或许是可预测的。

当然这个模型非常烂,没有考虑外界因素(政策变化、市场走向等等),并且只是简单运用了LSTM的预测能力。

PS:就我所知,股票预测是很多大学人工智能课程的入门教学示范...

这样的模型有很多人做过,我不是第一个,也肯定不是最后一个。

尽管我不懂炒股,但就我浅薄的经济学知识来看,股票真正难搞的不是走向,而是常常出现的、猝不及防的黑天鹅事件。而这一点,在相当程度内,都是机器学习无法预知的。

真指望用这个炒股还是算了...

0202年了,不会还有人炒股吧 不会吧不会吧

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注