XGBoost - 分位数回归
XGBoost 一次预测一个主要值,例如所有可能结果的平均值。有时,我们试图理解每种可能性,包括最坏情况和最佳情况。这就是分位数回归的用途。
这就是使用分位数损失函数来训练独立 XGBoost 模型的方法。例如,您可以为 0.05、0.5 和 0.95 分位数训练模型,以获得预测区间上下限。
由于分位数回归,除了均值(平均值)之外,我们还可以预测数据中的其他点或“分位数”。例如:第 10 个百分位数(较差的结果)、第 50 个百分位数(平均结果)和第 90 个百分位数(可接受的结果)。
分位数回归如何在 XGBoost 中工作?
XGBoost 通常通过专注于平均值的预测来减少误差。当我们将 XGBoost 与分位数回归结合使用时,我们调整误差测量。我们不是关注总误差,而是突出显示特定分位数与预测之间的差距。
简单来说,将分位数回归与 XGBoost 一起使用 -
它预测给定百分位数的值。
对于许多情况,我们可以计算可能的结果(糟糕的、平均的和好的)。
例如,在进行财务估计时,这在创建最佳和最坏情况策略时非常有用。
XGBoost 的分位数回归
我们将导入必要的库,借助 XGBoost 建立分位数回归模型,以生成预测区间。
import xgboost as xgb import numpy as np import matplotlib.pyplot as plt
为了训练和测试,目标值和特征值是使用合成数据从随机分布中生成的。
# Generate synthetic data np.random.seed(42) X_train = np.random.rand(100, 10) y_train = np.random.rand(100) X_test = np.random.rand(20, 10)
为了计算 XGBoost 回归器的目标函数所需的梯度和 Hessian 矩阵,我们创建了一个自定义的分位数损失函数。使用三个不同的分位数来训练模型 - 0.05、0.5(中位数)和 0.95。这些分位数分别对应于预测区间的下限、中位数和上限。训练后,每个分位数都会对测试集进行预测。
def quantile_loss(quantile_value):
def loss(true_values, predicted_values):
error = true_values - predicted_values
gradient = np.where(error > 0, quantile_value, quantile_value - 1)
# Hessian is constant
hessian = np.ones_like(error)
return gradient, hessian
return loss
quantile_levels = [0.05, 0.5, 0.95]
regression_models = {}
for quantile in quantile_levels:
regressor = xgb.XGBRegressor(objective=quantile_loss(quantile))
regressor.fit(X_train, y_train)
regression_models[quantile] = regressor
# Predicting quantiles
predictions_05 = regression_models[0.05].predict(X_test)
predictions_50 = regression_models[0.5].predict(X_test)
predictions_95 = regression_models[0.95].predict(X_test)
# Lower and upper bounds
lower_prediction = predictions_05
upper_prediction = predictions_95
median_prediction = predictions_50
通过绘制中位数预测并填充上下限之间的差距,我们可以看到数据并有效地显示中位数预测周围的预测区间。
# Visualization
plt.figure(figsize=(10, 6))
plt.plot(median_prediction, label='Median Prediction', color='green')
plt.fill_between(range(len(median_prediction)), lower_prediction, upper_prediction, color='lightcoral', alpha=0.5, label='Prediction Interval')
plt.title('Quantile Regression Prediction Interval')
plt.xlabel('Test Data Points')
plt.ylabel('Predictions')
plt.legend()
plt.show()
输出
以下是上述模型的结果 -
广告