XGBoost - 分位数回归



XGBoost 一次预测一个主要值,例如所有可能结果的平均值。有时,我们试图理解每种可能性,包括最坏情况和最佳情况。这就是分位数回归的用途。

这就是使用分位数损失函数来训练独立 XGBoost 模型的方法。例如,您可以为 0.05、0.5 和 0.95 分位数训练模型,以获得预测区间上下限。

由于分位数回归,除了均值(平均值)之外,我们还可以预测数据中的其他点或“分位数”。例如:第 10 个百分位数(较差的结果)、第 50 个百分位数(平均结果)和第 90 个百分位数(可接受的结果)。

分位数回归如何在 XGBoost 中工作?

XGBoost 通常通过专注于平均值的预测来减少误差。当我们将 XGBoost 与分位数回归结合使用时,我们调整误差测量。我们不是关注总误差,而是突出显示特定分位数与预测之间的差距。

简单来说,将分位数回归与 XGBoost 一起使用 -

  • 它预测给定百分位数的值。

  • 对于许多情况,我们可以计算可能的结果(糟糕的、平均的和好的)。

例如,在进行财务估计时,这在创建最佳和最坏情况策略时非常有用。

XGBoost 的分位数回归

我们将导入必要的库,借助 XGBoost 建立分位数回归模型,以生成预测区间。

import xgboost as xgb
import numpy as np
import matplotlib.pyplot as plt

为了训练和测试,目标值和特征值是使用合成数据从随机分布中生成的。

# Generate synthetic data
np.random.seed(42)
X_train = np.random.rand(100, 10)
y_train = np.random.rand(100)
X_test = np.random.rand(20, 10)

为了计算 XGBoost 回归器的目标函数所需的梯度和 Hessian 矩阵,我们创建了一个自定义的分位数损失函数。使用三个不同的分位数来训练模型 - 0.05、0.5(中位数)和 0.95。这些分位数分别对应于预测区间的下限、中位数和上限。训练后,每个分位数都会对测试集进行预测。

def quantile_loss(quantile_value):
   def loss(true_values, predicted_values):
      error = true_values - predicted_values
      gradient = np.where(error > 0, quantile_value, quantile_value - 1)
      # Hessian is constant
      hessian = np.ones_like(error)  
      return gradient, hessian
   return loss

quantile_levels = [0.05, 0.5, 0.95]
regression_models = {}

for quantile in quantile_levels:
   regressor = xgb.XGBRegressor(objective=quantile_loss(quantile))
   regressor.fit(X_train, y_train)
   regression_models[quantile] = regressor

# Predicting quantiles
predictions_05 = regression_models[0.05].predict(X_test)
predictions_50 = regression_models[0.5].predict(X_test)
predictions_95 = regression_models[0.95].predict(X_test)

# Lower and upper bounds
lower_prediction = predictions_05
upper_prediction = predictions_95
median_prediction = predictions_50

通过绘制中位数预测并填充上下限之间的差距,我们可以看到数据并有效地显示中位数预测周围的预测区间。

# Visualization
plt.figure(figsize=(10, 6))
plt.plot(median_prediction, label='Median Prediction', color='green')  
plt.fill_between(range(len(median_prediction)), lower_prediction, upper_prediction, color='lightcoral', alpha=0.5, label='Prediction Interval') 
plt.title('Quantile Regression Prediction Interval')
plt.xlabel('Test Data Points')
plt.ylabel('Predictions')
plt.legend()
plt.show()

输出

以下是上述模型的结果 -

Plotting the Median Prediction
广告
© . All rights reserved.