LightGBM - 回归



流行的机器学习方法 LightGBM(轻量级梯度提升机)用于回归和分类应用。当用于回归时,它会创建一系列决策树,每棵树都试图通过减少前一棵树的误差来最小化损失函数(例如均方误差)。

LightGBM 如何用于回归?

LightGBM 的基础,梯度提升,按顺序依次创建多个决策树。每棵树都努力纠正前一棵树所犯的错误。

与其他按层增长树的提升算法不同,LightGBM 按叶节点增长树。这意味着在扩展模型时,它会优化损失减少(即,最能改进模型的叶子节点)。这会产生更深、更准确的树,但需要仔细调整以避免过拟合。

为了减少预期结果和实际结果之间的差异,LightGBM 使用两种类型的回归任务损失函数——均方误差 (MSE) 和平均绝对误差 (MAE)。

何时使用 LightGBM 回归

以下是一些可以使用 LightGBM 进行回归的情况:

  • 当给定大型数据集时。

  • 当需要快速高效的模型时。

  • 当您的数据包含大量特征(列)或缺失值时。

使用 LightGBM 进行回归的示例

现在让我们看看如何创建一个 LightGBM 回归模型。这些步骤将帮助您了解该过程的每个步骤是如何工作的。

步骤 1 - 安装所需的库

在开始之前,请确保您已安装必要的库。Scikit-learn 用于数据处理,lightgbm 用于 LightGBM 模型。

pip install pandas scikit-learn lightgbm

步骤 2 - 加载数据

首先,使用 pandas 加载数据集。此数据集包含与健康相关的资料,包括年龄、性别、BMI、子女数量、居住地、吸烟状况和医疗费用。

import pandas as pd

# Load the dataset from your local file path
data = pd.read_csv('/My Docs/Python/medical_cost.csv')

# Display the first few rows of the dataset
print(data.head())

输出

这将产生以下结果:

   age     sex     bmi  children smoker     region      charges
0   19  female  27.900         0    yes  southwest  16884.92400
1   18    male  33.770         1     no  southeast   1725.55230
2   28    male  33.000         3     no  southeast   4449.46200
3   33    male  22.705         0     no  northwest  21984.47061
4   32    male  28.880         0     no  northwest   3866.85520

步骤 3 - 分离特征和目标变量

现在正在分离目标变量 (y) 和特征 (X)。在本例中,我们希望使用其他特征来预测“费用”列。

# 'charges' is the target column that we want to predict
# All columns except 'charges' are features
X = data.drop('charges', axis=1)  

# The 'charges' column is the target variable
y = data['charges']  

步骤 4 - 处理分类数据

数据集中的分类特征(性别、吸烟者和地区)需要转换为数值格式,因为 LightGBM 使用数值数据。独热编码用于将这些类别列转换为二进制格式(0 和 1)。

# Convert categorical variables to numerical 
X = pd.get_dummies(X, drop_first=True)

这里:

  • pd.get_dummies() 用于为每个类别生成额外的二进制列。

  • drop_first=True 通过消除每个分类变量的第一个类别来避免多重共线性。

步骤 5 - 分割数据

为了了解模型的性能,我们将数据分成两组——训练集(占数据的 80%)和测试集(占数据的 20%)。

  • train_test_split() 用于随机分割数据,同时保持给定的比例 (test_size=0.2)。

  • 使用 random_state = 42 可以确保分割结果可重现。

from sklearn.model_selection import train_test_split

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

步骤 6:初始化 LightGBM 回归器

现在我们将为回归初始化 LightGBM 模型。LGBMRegressor 是 LightGBM 的实现,专门用于回归任务。LGBMRegressor 模型非常高效和灵活,可以有效地处理大型数据集。

from lightgbm import LGBMRegressor

# Initialize the LightGBM regressor model
model = LGBMRegressor()

步骤 7:训练模型

接下来,我们将使用训练数据 (X_train 和 y_train) 来训练模型。这里使用 fit() 方法通过查找训练数据中的模式并预测目标变量(费用)来训练模型。

# Train the model on the training data
model.fit(X_train, y_train)

输出

运行上述代码后,我们将得到以下结果:

[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001000 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 319
[LightGBM] [Info] Number of data points in the train set: 1070, number of used features: 8
[LightGBM] [Info] Start training from score 13346.089733

 LGBMRegressori
LGBMRegressor()

步骤 8:进行预测

训练后,我们使用模型对测试集 (X_test) 进行预测。model.predict(X_test) 根据从训练数据中学习到的模式生成测试集的预测值。

# Predict on the test set
y_pred = model.predict(X_test)

步骤 9:评估模型

我们将使用均方误差 (MSE) 来衡量模型的性能,这是一个常用的回归指标。均方误差或 MSE 计算的是预期值和实际值之间的差异的平方平均值。较低的 MSE 值表示更好的性能。

from sklearn.metrics import mean_squared_error

# Calculate the MSE
mse = mean_squared_error(y_test, y_pred)
print(f'Mean Squared Error: {mse}')

输出

这将生成以下输出:

Mean Squared Error: 20557383.0620152

分析 MSE 值以了解模型预测目标变量的准确程度。如果 MSE 值很高,请考虑通过调整超参数或收集新数据来更新模型。

可视化均方误差 (MSE)

要查看均方误差,请使用 MSE 值创建一个条形图。这提供了对问题严重程度的清晰直观的表示。

这里,您可以看到如何使用 matplotlib(一个流行的用于绘图的 Python 库)来绘制它:

import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

# Example data (replace these with your actual values)
# Actual values
y_test = [3, -0.5, 2, 7] 

# Predicted values
y_pred = [2.5, 0.0, 2, 8]  

# Calculate the MSE
mse = mean_squared_error(y_test, y_pred)

# Plotting the Mean Squared Error
plt.figure(figsize=(6, 4))
plt.bar(['Mean Squared Error'], [mse], color='blue')
plt.ylabel('Error Value')
plt.title('Mean Squared Error (MSE)')
plt.show()

输出

以下是上述代码的结果:

Visualize the Mean Squared Error
广告