使用Python的ARIMA模型进行预测
ARIMA是一种用于时间序列预测的统计模型,它结合了三个组成部分:自回归 (AR)、差分积分 (I) 和移动平均 (MA)。
自回归 (AR) − 此组件模拟观测值与多个滞后观测值之间的依赖关系。它基于这样的思想:时间序列的过去值可用于预测未来值。自回归的阶数,用“p”表示,指定要使用多少个滞后观测值作为预测因子。
差分积分 (I) − 此组件通过去除趋势和季节性来处理时间序列数据的非平稳性。积分阶数,用“d”表示,是指原始时间序列数据需要差分多少次才能使其平稳,即消除趋势和季节性。
移动平均 (MA) − 此组件模拟应用 AR 和 I 组件后时间序列残差误差之间的依赖关系。移动平均的阶数,用“q”表示,指定要使用多少个滞后残差误差作为预测因子。
ARIMA 模型的一般形式为 ARIMA (p, d, q),其中 p、d 和 q 分别是自回归、差分积分和移动平均的阶数。要将 ARIMA 模型用于预测,首先必须确定最适合数据的 p、d 和 q 值。这可以通过称为模型选择的流程来完成,该流程涉及拟合具有不同 p、d 和 q 组合的各种 ARIMA 模型,并选择误差最小的模型。
预测未来12个月的销售额
使用 ARIMA 预测销售额是使用统计技术根据公司的历史销售数据预测公司未来销售额的过程。此过程通常包含以下步骤
收集历史销售数据并将其转换为时间序列格式。
可视化数据以识别任何趋势、季节性或模式。
确定使时间序列平稳所需的差分阶数。
根据数据中的模式选择 ARIMA 模型的阶数 (p, d, q)。
将 ARIMA 模型拟合到数据并对未来销售额进行预测。
评估模型的性能并根据需要进行调整。
使用模型预测未来销售额并根据预测做出决策。
ARIMA 是销售预测中的一种常用方法,因为它可以捕获数据中的复杂模式并处理时间序列中的趋势和季节性。但是,模型的性能可能会受到各种因素的影响,例如数据的质量、参数的选择以及模型捕获数据中潜在模式的能力。
现在让我们来看一个使用 ARIMA 进行预测的示例。
下面使用的 dataset (sales_data.csv) 可在此处获取。
示例
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
# Load the time series data
data = pd.read_csv('sales_data.csv')
# Fit the ARIMA model
model = sm.tsa.ARIMA(data['sales'], order=(2, 1, 1))
model_fit = model.fit()
# Forecast future values
forecast = model_fit.forecast(steps=12)
# Print the forecast
print(forecast[0])
# Plot the time series
data2=np.append(data,forecast[0])
plt.plot(data2)
plt.xlabel('Date')
plt.ylabel('Sales')
plt.title('Synthetic Time Series Data')
plt.show()
输出
[56.29545598 56.60345925 56.90298063 57.19449608 57.47839568 57.7550522 58.02482013 58.28803659 58.54502221 58.79608193 59.04150576 59.28156952]
在此示例中,时间序列数据是特定产品的销售数据,从 CSV 文件加载到 pandas 数据框中。使用 sm.tsa.ARIMA 函数将 ARIMA 模型拟合到销售数据,并将自回归的阶数设置为 2,差分积分的阶数设置为 1,移动平均的阶数设置为 1。
然后使用 model_fit 对象使用 forecast 方法生成未来销售额的预测,其中 steps 参数为 12,以指定要预测的未来值的个数。然后打印预测结果,其中给出了未来 12 个月的预期销售额。
自定义数据集
在本例中,我们将直接在代码中定义数据集。数据最初将以列表的形式存在,稍后将转换为 Pandas 数据框。
此代码随后将 ARIMA 模型拟合到自定义数据集,对接下来的 12 个时间步长进行预测,并将预测结果存储在 predictions 变量中。在此示例中,自定义数据集是包含 12 个值的列表,但是对于任何时间序列数据,拟合 ARIMA 模型和进行预测的过程都是相同的。
示例
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA
# Load custom dataset
data = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]
# Convert data to a pandas DataFrame
df = pd.DataFrame({'values': data})
# Fit the ARIMA model
model = ARIMA(df['values'], order=(1,0,0))
model_fit = model.fit()
# Make predictions
predictions = model_fit.forecast(steps=12)
print(predictions)
# Plot the original dataset and predictions
plt.plot(df['values'], label='Original Data')
plt.plot(predictions, label='Predictions')
plt.legend()
plt.show()
输出
12 118.967858 13 117.955086 14 116.961320 15 115.986203 16 115.029385 17 114.090523 18 113.169280 19 112.265326 20 111.378335 21 110.507989 22 109.653977 23 108.815991 Name: predicted_mean, dtype: float64
波士顿房价数据集
import numpy as np
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
import warnings
warnings.filterwarnings("ignore")
# Load the Boston dataset
boston = load_boston()
data = boston.data
# Convert data to a pandas DataFrame
df = pd.DataFrame(data, columns=boston.feature_names)
df=df.head(20)
# Fit the ARIMA model
model = ARIMA(df['CRIM'], order=(1,0,0))
model_fit = model.fit()
# Make predictions
predictions = model_fit.forecast(steps=12)
print(predictions.tolist())
# Plot the original dataset and predictions
plt.plot(df['CRIM'], label='Original Data')
plt.plot(predictions, label='Predictions')
plt.legend()
plt.show()
输出
[0.6738187961066762, 0.6288621548198372, 0.5899808007068923, 0.5563537401796019, 0.5272709259231514, 0.5021182639951554, 0.4803646470141665, 0.46155073963886595, 0.44527927953934654, 0.4312066890620576, 0.41903582046573945, 0.40850968154143097]
图表中的所有 X 值均为索引值。
结论
ARIMA 是一种强大的时间序列预测方法,可用于在 Python 中预测股票价格。使用 ARIMA 进行预测的过程包括将时间序列数据转换为平稳格式,确定差分、自回归和移动平均项的阶数,将 ARIMA 模型拟合到数据,生成预测并评估模型的性能。Python 中的 statsmodels 库提供了一种方便且高效的方法来执行 ARIMA 预测。但是,必须记住,ARIMA 只是众多可用于股票价格预测的方法之一,并且模型的结果可能会因所用数据的质量和特性而异。
数据结构
网络
关系数据库管理系统(RDBMS)
操作系统
Java
iOS
HTML
CSS
Android
Python
C语言编程
C++
C#
MongoDB
MySQL
Javascript
PHP