机器学习 - 信息熵



信息熵是一个起源于热力学的概念,后来被应用于各个领域,包括信息论、统计学和机器学习。在机器学习中,信息熵被用作衡量数据集纯度或随机性的指标。具体来说,信息熵用于决策树算法中,以决定如何分割数据以创建更同质的子集。在本文中,我们将讨论机器学习中的信息熵、其属性以及在 Python 中的实现。

信息熵被定义为系统中无序或随机性的度量。在决策树的背景下,信息熵被用作衡量节点纯度的指标。如果节点中的所有示例都属于同一类,则该节点被认为是纯的。相反,如果节点包含来自多个类的示例,则该节点是不纯的。

要计算信息熵,我们首先需要定义数据集中每个类的概率。设 p(i) 为示例属于类 i 的概率。如果我们有 k 个类,则系统的总信息熵,表示为 H(S),计算如下:

$$H\left ( S \right )=-sum\left ( p\left ( i \right )\ast log_{2}\left ( p\left ( i \right ) \right ) \right )$$

其中,求和遍及所有 k 个类。此方程称为香农熵。

例如,假设我们有一个包含 100 个示例的数据集,其中 60 个属于类 A,40 个属于类 B。则类 A 的概率为 0.6,类 B 的概率为 0.4。然后数据集的信息熵为:

$$H\left ( S \right )=-(0.6\times log_{2}(0.6)+ 0.4\times log_{2}(0.4)) = 0.971$$

如果数据集中所有示例都属于同一类,则信息熵为 0,表示纯节点。另一方面,如果示例在所有类中均匀分布,则信息熵较高,表示不纯节点。

在决策树算法中,信息熵用于确定每个节点的最佳分割。目标是创建导致最同质子集的分割。这是通过计算每个可能分割的信息熵并选择导致最低总信息熵的分割来完成的。

例如,假设我们有一个包含两个特征 X1 和 X2 的数据集,目标是预测类标签 Y。我们首先计算整个数据集的信息熵 H(S)。接下来,我们根据每个特征计算每个可能分割的信息熵。例如,我们可以根据 X1 的值或 X2 的值分割数据。每个分割的信息熵计算如下:

$$H\left ( X_{1} \right )=p_{1}\times H\left ( S_{1} \right )+p_{2}\times H\left ( S_{2} \right )H\left ( X_{2} \right )=p_{3}\times H\left ( S_{3} \right )+p_{4}\times H\left ( S_{4} \right )$$

其中,p1、p2、p3 和 p4 是每个子集的概率;H(S1)、H(S2)、H(S3) 和 H(S4) 是每个子集的信息熵。

然后我们选择导致最低总信息熵的分割,它由以下公式给出:

$$H_{split}=H\left ( X_{1} \right )\, if\, H\left ( X_{1} \right )\leq H\left ( X_{2} \right );\: else\: H\left ( X_{2} \right )$$

然后使用此分割来创建决策树的子节点,并递归重复此过程,直到所有节点都变为纯节点或满足停止条件。

示例

让我们举一个例子来了解如何在 Python 中实现它。这里我们将使用“鸢尾花”数据集:

from sklearn.datasets import load_iris
import numpy as np

# Load iris dataset
iris = load_iris()

# Extract features and target
X = iris.data
y = iris.target

# Define a function to calculate entropy
def entropy(y):
   n = len(y)
   _, counts = np.unique(y, return_counts=True)
   probs = counts / n
   return -np.sum(probs * np.log2(probs))

# Calculate the entropy of the target variable
target_entropy = entropy(y)
print(f"Target entropy: {target_entropy:.3f}")

以上代码加载鸢尾花数据集,提取特征和目标,并定义一个用于计算信息熵的函数。entropy() 函数接受目标值的向量并返回该集合的信息熵。

该函数首先计算集合中的示例数量和每个类的计数。然后它计算每个类的比例,并使用这些比例根据信息熵公式计算集合的信息熵。最后,代码计算鸢尾花数据集中目标变量的信息熵并将其打印到控制台。

输出

执行此代码时,它将生成以下输出:

Target entropy: 1.585
广告