- LightGBM 教程
- LightGBM - 首页
- LightGBM - 概述
- LightGBM - 架构
- LightGBM - 安装
- LightGBM - 核心参数
- LightGBM - Boosting算法
- LightGBM - 树增长策略
- LightGBM - 数据集结构
- LightGBM - 二元分类
- LightGBM - 回归
- LightGBM - 排序
- LightGBM - Python 实现
- LightGBM - 参数调整
- LightGBM - 绘图功能
- LightGBM - 早停训练
- LightGBM - 特征交互约束
- LightGBM 与其他 Boosting 算法的比较
- LightGBM 有用资源
- LightGBM - 有用资源
- LightGBM - 讨论
LightGBM - 回归
流行的机器学习方法 LightGBM(轻量级梯度提升机)用于回归和分类应用。当用于回归时,它会创建一系列决策树,每棵树都试图通过减少前一棵树的误差来最小化损失函数(例如均方误差)。
LightGBM 如何用于回归?
LightGBM 的基础,梯度提升,按顺序依次创建多个决策树。每棵树都努力纠正前一棵树所犯的错误。
与其他按层增长树的提升算法不同,LightGBM 按叶节点增长树。这意味着在扩展模型时,它会优化损失减少(即,最能改进模型的叶子节点)。这会产生更深、更准确的树,但需要仔细调整以避免过拟合。
为了减少预期结果和实际结果之间的差异,LightGBM 使用两种类型的回归任务损失函数——均方误差 (MSE) 和平均绝对误差 (MAE)。
何时使用 LightGBM 回归
以下是一些可以使用 LightGBM 进行回归的情况:
当给定大型数据集时。
当需要快速高效的模型时。
当您的数据包含大量特征(列)或缺失值时。
使用 LightGBM 进行回归的示例
现在让我们看看如何创建一个 LightGBM 回归模型。这些步骤将帮助您了解该过程的每个步骤是如何工作的。
步骤 1 - 安装所需的库
在开始之前,请确保您已安装必要的库。Scikit-learn 用于数据处理,lightgbm 用于 LightGBM 模型。
pip install pandas scikit-learn lightgbm
步骤 2 - 加载数据
首先,使用 pandas 加载数据集。此数据集包含与健康相关的资料,包括年龄、性别、BMI、子女数量、居住地、吸烟状况和医疗费用。
import pandas as pd # Load the dataset from your local file path data = pd.read_csv('/My Docs/Python/medical_cost.csv') # Display the first few rows of the dataset print(data.head())
输出
这将产生以下结果:
age sex bmi children smoker region charges 0 19 female 27.900 0 yes southwest 16884.92400 1 18 male 33.770 1 no southeast 1725.55230 2 28 male 33.000 3 no southeast 4449.46200 3 33 male 22.705 0 no northwest 21984.47061 4 32 male 28.880 0 no northwest 3866.85520
步骤 3 - 分离特征和目标变量
现在正在分离目标变量 (y) 和特征 (X)。在本例中,我们希望使用其他特征来预测“费用”列。
# 'charges' is the target column that we want to predict # All columns except 'charges' are features X = data.drop('charges', axis=1) # The 'charges' column is the target variable y = data['charges']
步骤 4 - 处理分类数据
数据集中的分类特征(性别、吸烟者和地区)需要转换为数值格式,因为 LightGBM 使用数值数据。独热编码用于将这些类别列转换为二进制格式(0 和 1)。
# Convert categorical variables to numerical X = pd.get_dummies(X, drop_first=True)
这里:
pd.get_dummies() 用于为每个类别生成额外的二进制列。
drop_first=True 通过消除每个分类变量的第一个类别来避免多重共线性。
步骤 5 - 分割数据
为了了解模型的性能,我们将数据分成两组——训练集(占数据的 80%)和测试集(占数据的 20%)。
train_test_split() 用于随机分割数据,同时保持给定的比例 (test_size=0.2)。
使用 random_state = 42 可以确保分割结果可重现。
from sklearn.model_selection import train_test_split # Split the data into training and test sets X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
步骤 6:初始化 LightGBM 回归器
现在我们将为回归初始化 LightGBM 模型。LGBMRegressor 是 LightGBM 的实现,专门用于回归任务。LGBMRegressor 模型非常高效和灵活,可以有效地处理大型数据集。
from lightgbm import LGBMRegressor # Initialize the LightGBM regressor model model = LGBMRegressor()
步骤 7:训练模型
接下来,我们将使用训练数据 (X_train 和 y_train) 来训练模型。这里使用 fit() 方法通过查找训练数据中的模式并预测目标变量(费用)来训练模型。
# Train the model on the training data model.fit(X_train, y_train)
输出
运行上述代码后,我们将得到以下结果:
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001000 seconds. You can set `force_col_wise=true` to remove the overhead. [LightGBM] [Info] Total Bins 319 [LightGBM] [Info] Number of data points in the train set: 1070, number of used features: 8 [LightGBM] [Info] Start training from score 13346.089733 LGBMRegressori LGBMRegressor()
步骤 8:进行预测
训练后,我们使用模型对测试集 (X_test) 进行预测。model.predict(X_test) 根据从训练数据中学习到的模式生成测试集的预测值。
# Predict on the test set y_pred = model.predict(X_test)
步骤 9:评估模型
我们将使用均方误差 (MSE) 来衡量模型的性能,这是一个常用的回归指标。均方误差或 MSE 计算的是预期值和实际值之间的差异的平方平均值。较低的 MSE 值表示更好的性能。
from sklearn.metrics import mean_squared_error # Calculate the MSE mse = mean_squared_error(y_test, y_pred) print(f'Mean Squared Error: {mse}')
输出
这将生成以下输出:
Mean Squared Error: 20557383.0620152
分析 MSE 值以了解模型预测目标变量的准确程度。如果 MSE 值很高,请考虑通过调整超参数或收集新数据来更新模型。
可视化均方误差 (MSE)
要查看均方误差,请使用 MSE 值创建一个条形图。这提供了对问题严重程度的清晰直观的表示。
这里,您可以看到如何使用 matplotlib(一个流行的用于绘图的 Python 库)来绘制它:
import matplotlib.pyplot as plt from sklearn.metrics import mean_squared_error # Example data (replace these with your actual values) # Actual values y_test = [3, -0.5, 2, 7] # Predicted values y_pred = [2.5, 0.0, 2, 8] # Calculate the MSE mse = mean_squared_error(y_test, y_pred) # Plotting the Mean Squared Error plt.figure(figsize=(6, 4)) plt.bar(['Mean Squared Error'], [mse], color='blue') plt.ylabel('Error Value') plt.title('Mean Squared Error (MSE)') plt.show()
输出
以下是上述代码的结果: