如何使用机器学习预测葡萄酒质量?


本教程将从在线资源(如 Kaggle)获取葡萄酒质量数据集。首选数据集是“葡萄酒质量数据集”,可在 "https://www.kaggle.com/datasets/yasserh/wine-quality-dataset." 获取。

该数据集包含一个 .csv 文件,其中包含各种类型的葡萄酒,例如“固定酸度”、“挥发性酸度”、“pH 值”、“密度”等等。在此数据集中,在初始阶段删除了字段名称“质量”,并且进一步训练了模型。

以下是预测葡萄酒质量的 Python 代码。

  • 导入必要的库。

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
  • 导入葡萄酒质量数据集

wine = pd.read_csv('/Users/someswarpal/Downloads/WineQT.csv')
  • 删除名为“quality”的列。

X = wine.drop(columns=['quality'])
y = wine['quality']
  • 将数据拆分为测试集和训练集。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  • 创建线性回归模型

model = LinearRegression()
  • 训练模型

model.fit(X_train, y_train)
  • 对训练集进行预测。

y_pred = model.predict(X_test)
  • 评估模型

mse = mean_squared_error(y_test, y_pred)
print("Mean Squared Error:", mse)
  • 计算每个类别的平均质量

mean_quality = wine.groupby('quality')['quality'].mean()

输出

Mean Squared Error: 0.38242835212919696
  • 找到平均质量最高的类别

best_quality = mean_quality.idxmax()
best_mean_quality = mean_quality.max()
  • 打印最佳葡萄酒的摘要。

print("Summary of Wine Quality:")
print("----------------------------")
print("Best Wine Quality Category:", best_quality)
print("Mean Quality Score:", best_mean_quality)

输出

Summary of Wine Quality:
   ----------------------------
   Best Wine Quality Category: 8
   Mean Quality Score: 8.0
  • 找到平均质量最低的类别

worst_quality = mean_quality.idxmin()
worst_mean_quality = mean_quality.min()
  • 打印最差葡萄酒的摘要

示例

print("Summary of Wine Quality:")
print("----------------------------")
print("Worst Wine Quality Category:", worst_quality)
print("Mean Quality Score:", worst_mean_quality)

输出

Summary of Wine Quality:
----------------------------
Worst Wine Quality Category: 3
Mean Quality Score: 3.0

结论

总之,该代码以多种方式分析和显示关于葡萄酒质量的集合中的数据。它首先读取数据集并将其分成输入特征 (X) 和目标变量 (y)。然后使用训练集来创建和训练线性回归模型。然后在测试集上进行预测,并使用均方误差来衡量模型的性能。

该代码还确定数据集中每个类别的平均质量,并找到平均质量最高的类别。可以创建一些图像,例如散点图、直方图、箱线图、条形图、线形图、相关性热图和饼图。这些图显示了不同因素如何影响葡萄酒的质量。

总的来说,该代码对葡萄酒质量数据集进行了深入的研究,从建模和评估数据到显示数据的分布方式以及它们之间的关系。它展示了如何使用流行的数据分析和可视化库(如 Pandas、NumPy、sci-kit-learn、matplotlib 和 Seaborn)来使分析过程更容易,并提供有用的见解来理解数据集。

更新于: 2023年10月12日

170 次查看

开启您的 职业生涯

通过完成课程获得认证

开始学习
广告