XGBoost - 分类



XGBoost 最常见的用途之一是分类。它根据输入特征预测离散的类别标签。分类是使用 XGBClassifier 模块进行的,该模块专门用于处理分类任务。

XGBClassifier 语法

为了提高性能,我们可以调整 XGBoost 中 XGBClassifier 类的超参数。构建 XGBoost 分类器的基本语法如下所示:

model = xgb.XGBClassifier(
    objective='multi:softprob',
    num_class=num_classes,      
    max_depth=max_depth,       
    learning_rate=learning_rate,
    subsample=subsample,        
    colsample_bytree=colsample, 
    n_estimators=num_estimators
)

以下是 XGBClassifier 语法中使用的超参数的描述:

  • objective='multi:softprob' - 它是目标参数,对于多类分类是可选的,并返回每个类的概率分数。对于二分类,默认值为 'binary:logistic'。

  • num_class=num_classes - 它是多类分类任务所需的,显示数据集中存在的类别数量。

  • max_depth=max_depth - 它是可选参数,显示每棵决策树的最大深度。

  • learning_rate=learning_rate - 它是可选参数,其中步长收缩避免过拟合。

  • subsample=subsample - 它是可选参数,显示每棵树使用的样本分数。

  • colsample_bytree=colsample - 也是可选参数,显示每棵树使用的特征分数。

  • n_estimators=num_estimators - 它是必需参数,用于查找提升迭代次数并处理模型的整体复杂度。

XGBoost 分类示例

鸢尾花数据集是机器学习中非常流行的数据集。它包含 150 个鸢尾花示例,每个示例都有四个测量值,需要将三种鸢尾花物种分类。

让我们使用鸢尾花数据集来演示如何使用 XGBoost 库进行分类。

   import xgboost as xgb
   from sklearn.datasets import load_iris
   from sklearn.model_selection import train_test_split
   from sklearn.metrics import accuracy_score, classification_report

   # Load the Iris dataset
   data = load_iris()
   X, y = data.data, data.target

   # Split the data into training and test sets
   X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

   #Create an XGBoost classifier
   model = xgb.XGBClassifier()

   #Train the model on the training data
   model.fit(X_train, y_train)

   #Make predictions on the test set
   predictions = model.predict(X_test)

   #Calculate accuracy
   accuracy = accuracy_score(y_test, predictions)

   print("Model's Accuracy is:", accuracy)
   print("\nModel's Classification Report is:")
   print(classification_report(y_test, predictions, target_names=data.target_names))

输出

这将导致以下结果:

Model's Accuracy is: 1.0

Model's Classification Report is:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00         9
   virginica       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30

总结

XGBoost 是一种强大的机器学习工具,尤其适用于分类任务。由于它速度快且具有有助于防止过拟合的功能,因此在许多情况下都能很好地工作。例如,我们使用 XGBoost 将鸢尾花分类到不同的类型中,实现了 1.0 的完美准确率。它的灵活性和效率使 XGBoost 成为许多现实生活中分类问题的绝佳选择。

广告

© . All rights reserved.