機器學習 - 資訊熵



資訊熵是一個起源於熱力學的概念,後來被應用於各個領域,包括資訊理論、統計學和機器學習。在機器學習中,資訊熵被用作衡量資料集純度或隨機性的指標。具體來說,資訊熵用於決策樹演算法中,以決定如何分割資料以建立更同質的子集。在本文中,我們將討論機器學習中的資訊熵、其屬性以及在 Python 中的實現。

資訊熵被定義為系統中無序或隨機性的度量。在決策樹的背景下,資訊熵被用作衡量節點純度的指標。如果節點中的所有示例都屬於同一類,則該節點被認為是純的。相反,如果節點包含來自多個類的示例,則該節點是不純的。

要計算資訊熵,我們首先需要定義資料集中每個類的機率。設 p(i) 為示例屬於類 i 的機率。如果我們有 k 個類,則系統的總資訊熵,表示為 H(S),計算如下:

$$H\left ( S \right )=-sum\left ( p\left ( i \right )\ast log_{2}\left ( p\left ( i \right ) \right ) \right )$$

其中,求和遍及所有 k 個類。此方程稱為夏農熵。

例如,假設我們有一個包含 100 個示例的資料集,其中 60 個屬於類 A,40 個屬於類 B。則類 A 的機率為 0.6,類 B 的機率為 0.4。然後資料集的資訊熵為:

$$H\left ( S \right )=-(0.6\times log_{2}(0.6)+ 0.4\times log_{2}(0.4)) = 0.971$$

如果資料集中所有示例都屬於同一類,則資訊熵為 0,表示純節點。另一方面,如果示例在所有類中均勻分佈,則資訊熵較高,表示不純節點。

在決策樹演算法中,資訊熵用於確定每個節點的最佳分割。目標是建立導致最同質子集的分割。這是透過計算每個可能分割的資訊熵並選擇導致最低總資訊熵的分割來完成的。

例如,假設我們有一個包含兩個特徵 X1 和 X2 的資料集,目標是預測類標籤 Y。我們首先計算整個資料集的資訊熵 H(S)。接下來,我們根據每個特徵計算每個可能分割的資訊熵。例如,我們可以根據 X1 的值或 X2 的值分割資料。每個分割的資訊熵計算如下:

$$H\left ( X_{1} \right )=p_{1}\times H\left ( S_{1} \right )+p_{2}\times H\left ( S_{2} \right )H\left ( X_{2} \right )=p_{3}\times H\left ( S_{3} \right )+p_{4}\times H\left ( S_{4} \right )$$

其中,p1、p2、p3 和 p4 是每個子集的機率;H(S1)、H(S2)、H(S3) 和 H(S4) 是每個子集的資訊熵。

然後我們選擇導致最低總資訊熵的分割,它由以下公式給出:

$$H_{split}=H\left ( X_{1} \right )\, if\, H\left ( X_{1} \right )\leq H\left ( X_{2} \right );\: else\: H\left ( X_{2} \right )$$

然後使用此分割來建立決策樹的子節點,並遞迴重複此過程,直到所有節點都變為純節點或滿足停止條件。

示例

讓我們舉一個例子來了解如何在 Python 中實現它。這裡我們將使用“鳶尾花”資料集:

from sklearn.datasets import load_iris
import numpy as np

# Load iris dataset
iris = load_iris()

# Extract features and target
X = iris.data
y = iris.target

# Define a function to calculate entropy
def entropy(y):
   n = len(y)
   _, counts = np.unique(y, return_counts=True)
   probs = counts / n
   return -np.sum(probs * np.log2(probs))

# Calculate the entropy of the target variable
target_entropy = entropy(y)
print(f"Target entropy: {target_entropy:.3f}")

以上程式碼載入鳶尾花資料集,提取特徵和目標,並定義一個用於計算資訊熵的函式。entropy() 函式接受目標值的向量並返回該集合的資訊熵。

該函式首先計算集合中的示例數量和每個類的計數。然後它計算每個類的比例,並使用這些比例根據資訊熵公式計算集合的資訊熵。最後,程式碼計算鳶尾花資料集中目標變數的資訊熵並將其列印到控制檯。

輸出

執行此程式碼時,它將生成以下輸出:

Target entropy: 1.585
廣告