XGBoost - 分类
XGBoost 最常见的用途之一是分类。它根据输入特征预测离散的类别标签。分类是使用 XGBClassifier 模块进行的,该模块专门用于处理分类任务。
XGBClassifier 语法
为了提高性能,我们可以调整 XGBoost 中 XGBClassifier 类的超参数。构建 XGBoost 分类器的基本语法如下所示:
model = xgb.XGBClassifier(
objective='multi:softprob',
num_class=num_classes,
max_depth=max_depth,
learning_rate=learning_rate,
subsample=subsample,
colsample_bytree=colsample,
n_estimators=num_estimators
)
以下是 XGBClassifier 语法中使用的超参数的描述:
objective='multi:softprob' - 它是目标参数,对于多类分类是可选的,并返回每个类的概率分数。对于二分类,默认值为 'binary:logistic'。
num_class=num_classes - 它是多类分类任务所需的,显示数据集中存在的类别数量。
max_depth=max_depth - 它是可选参数,显示每棵决策树的最大深度。
learning_rate=learning_rate - 它是可选参数,其中步长收缩避免过拟合。
subsample=subsample - 它是可选参数,显示每棵树使用的样本分数。
colsample_bytree=colsample - 也是可选参数,显示每棵树使用的特征分数。
n_estimators=num_estimators - 它是必需参数,用于查找提升迭代次数并处理模型的整体复杂度。
XGBoost 分类示例
鸢尾花数据集是机器学习中非常流行的数据集。它包含 150 个鸢尾花示例,每个示例都有四个测量值,需要将三种鸢尾花物种分类。
让我们使用鸢尾花数据集来演示如何使用 XGBoost 库进行分类。
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
# Load the Iris dataset
data = load_iris()
X, y = data.data, data.target
# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
#Create an XGBoost classifier
model = xgb.XGBClassifier()
#Train the model on the training data
model.fit(X_train, y_train)
#Make predictions on the test set
predictions = model.predict(X_test)
#Calculate accuracy
accuracy = accuracy_score(y_test, predictions)
print("Model's Accuracy is:", accuracy)
print("\nModel's Classification Report is:")
print(classification_report(y_test, predictions, target_names=data.target_names))
输出
这将导致以下结果:
Model's Accuracy is: 1.0
Model's Classification Report is:
precision recall f1-score support
setosa 1.00 1.00 1.00 10
versicolor 1.00 1.00 1.00 9
virginica 1.00 1.00 1.00 11
accuracy 1.00 30
macro avg 1.00 1.00 1.00 30
weighted avg 1.00 1.00 1.00 30
总结
XGBoost 是一种强大的机器学习工具,尤其适用于分类任务。由于它速度快且具有有助于防止过拟合的功能,因此在许多情况下都能很好地工作。例如,我们使用 XGBoost 将鸢尾花分类到不同的类型中,实现了 1.0 的完美准确率。它的灵活性和效率使 XGBoost 成为许多现实生活中分类问题的绝佳选择。