
XGBoost - 分位數迴歸
XGBoost 一次預測一個主要值,例如所有可能結果的平均值。有時,我們試圖理解每種可能性,包括最壞情況和最佳情況。這就是分位數迴歸的用途。
這就是使用分位數損失函式來訓練獨立 XGBoost 模型的方法。例如,您可以為 0.05、0.5 和 0.95 分位數訓練模型,以獲得預測區間上下限。
由於分位數迴歸,除了均值(平均值)之外,我們還可以預測資料中的其他點或“分位數”。例如:第 10 個百分位數(較差的結果)、第 50 個百分位數(平均結果)和第 90 個百分位數(可接受的結果)。
分位數迴歸如何在 XGBoost 中工作?
XGBoost 通常透過專注於平均值的預測來減少誤差。當我們將 XGBoost 與分位數迴歸結合使用時,我們調整誤差測量。我們不是關注總誤差,而是突出顯示特定分位數與預測之間的差距。
簡單來說,將分位數迴歸與 XGBoost 一起使用 -
它預測給定百分位數的值。
對於許多情況,我們可以計算可能的結果(糟糕的、平均的和好的)。
例如,在進行財務估計時,這在建立最佳和最壞情況策略時非常有用。
XGBoost 的分位數迴歸
我們將匯入必要的庫,藉助 XGBoost 建立分位數迴歸模型,以生成預測區間。
import xgboost as xgb import numpy as np import matplotlib.pyplot as plt
為了訓練和測試,目標值和特徵值是使用合成數據從隨機分佈中生成的。
# Generate synthetic data np.random.seed(42) X_train = np.random.rand(100, 10) y_train = np.random.rand(100) X_test = np.random.rand(20, 10)
為了計算 XGBoost 迴歸器的目標函式所需的梯度和 Hessian 矩陣,我們建立了一個自定義的分位數損失函式。使用三個不同的分位數來訓練模型 - 0.05、0.5(中位數)和 0.95。這些分位數分別對應於預測區間的下限、中位數和上限。訓練後,每個分位數都會對測試集進行預測。
def quantile_loss(quantile_value): def loss(true_values, predicted_values): error = true_values - predicted_values gradient = np.where(error > 0, quantile_value, quantile_value - 1) # Hessian is constant hessian = np.ones_like(error) return gradient, hessian return loss quantile_levels = [0.05, 0.5, 0.95] regression_models = {} for quantile in quantile_levels: regressor = xgb.XGBRegressor(objective=quantile_loss(quantile)) regressor.fit(X_train, y_train) regression_models[quantile] = regressor # Predicting quantiles predictions_05 = regression_models[0.05].predict(X_test) predictions_50 = regression_models[0.5].predict(X_test) predictions_95 = regression_models[0.95].predict(X_test) # Lower and upper bounds lower_prediction = predictions_05 upper_prediction = predictions_95 median_prediction = predictions_50
透過繪製中位數預測並填充上下限之間的差距,我們可以看到資料並有效地顯示中位數預測周圍的預測區間。
# Visualization plt.figure(figsize=(10, 6)) plt.plot(median_prediction, label='Median Prediction', color='green') plt.fill_between(range(len(median_prediction)), lower_prediction, upper_prediction, color='lightcoral', alpha=0.5, label='Prediction Interval') plt.title('Quantile Regression Prediction Interval') plt.xlabel('Test Data Points') plt.ylabel('Predictions') plt.legend() plt.show()
輸出
以下是上述模型的結果 -

廣告