机器学习 - Apriori 算法



Apriori 是一种流行的算法,用于机器学习中的关联规则挖掘。它用于在交易数据库中查找频繁项集,并根据这些项集生成关联规则。该算法由 Rakesh Agrawal 和 Ramakrishnan Srikant 于 1994 年首次提出。

Apriori 算法通过迭代扫描数据库以查找越来越大的频繁项集来工作。它使用“自底向上”的方法,从单个项目开始,逐渐向候选项集添加更多项目,直到找不到更多频繁项集。该算法还采用修剪技术来减少需要检查的候选项集的数量。

以下是 Apriori 算法涉及的步骤概述 -

  • 扫描数据库以查找每个项目的支持度计数。

  • 根据最小支持度阈值生成一组频繁的 1-项集。

  • 通过组合频繁的 1-项集生成一组候选的 2-项集。

  • 再次扫描数据库以查找每个候选 2-项集的支持度计数。

  • 根据最小支持度阈值生成一组频繁的 2-项集,并修剪任何不是频繁的候选 2-项集。

  • 重复步骤 3-5 以生成候选 k-项集和频繁 k-项集,直到找不到更多频繁项集。

示例

在 Python 中,mlxtend 库提供了 Apriori 算法的实现。以下是如何在 sklearn 数据集结合 mlxtend 库在鸢尾花数据集上实现 Apriori 算法的示例。

from mlxtend.frequent_patterns import apriori
from mlxtend.preprocessing import TransactionEncoder
from sklearn import datasets

# Load the iris dataset
iris = datasets.load_iris()

# Convert the dataset into a list of transactions
transactions = []
for i in range(len(iris.data)):
   transaction = []
   transaction.append('sepal_length=' + str(iris.data[i][0]))
   transaction.append('sepal_width=' + str(iris.data[i][1]))
   transaction.append('petal_length=' + str(iris.data[i][2]))
   transaction.append('petal_width=' + str(iris.data[i][3]))
   transaction.append('target=' + str(iris.target[i]))
   transactions.append(transaction)
# Encode the transactions using one-hot encoding
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_ary, columns=te.columns_)

# Find frequent itemsets with a minimum support of 0.3
frequent_itemsets = apriori(df, min_support=0.3, use_colnames=True)

# Print the frequent itemsets
print(frequent_itemsets)

在此示例中,我们从 sklearn 加载鸢尾花数据集,其中包含有关鸢尾花的信息。我们将数据集转换为交易列表,其中每个交易代表一朵花,并包含其四个属性(萼片长度、萼片宽度、花瓣长度和花瓣宽度)以及目标标签(target)的值。然后,我们使用独热编码对交易进行编码,并使用 mlxtend 中的 apriori 函数找到最小支持度为 0.3 的频繁项集。

此代码的输出将显示频繁项集及其对应支持度计数。由于鸢尾花数据集相对较小,我们只找到一个频繁项集 -

输出

   support   itemsets
0  0.333333  (target=0)
1  0.333333  (target=1)
2  0.333333  (target=2)

这表明数据集中 33% 的交易同时包含花瓣长度值为 1.4 和目标标签为 0(对应于鸢尾花数据集中山鸢尾物种)。

Apriori 算法广泛用于市场购物篮分析,以识别客户购买行为中的模式。例如,零售商可以使用该算法查找经常一起购买的商品,以便共同推广以增加销售额。该算法也可用于医疗保健、金融和社交媒体等其他领域,以识别模式并从大型数据集中生成见解。

广告