使用Python實現決策樹


決策樹是一種主要應用於資料分類場景的演算法。它是一種樹形結構,其中每個節點代表特徵,每條邊代表做出的決策。從根節點開始,我們繼續評估特徵進行分類,並做出遵循特定邊的決策。每當出現新的資料點時,都會反覆應用此方法,然後在研究或應用於分類場景的所有必要特徵後得出最終結論。因此,決策樹演算法是一種監督學習模型,用於根據一系列訓練變數預測因變數。

示例

我們將使用kaggle上提供的藥物測試資料。第一步,我們將使用pandas從csv檔案讀取資料,並檢視其內容和結構。

import pandas as pd

datainput = pd.read_csv("drug.csv", delimiter=",") #https://www.kaggle.com/gangliu/drugsets
print(datainput)

執行以上程式碼得到以下結果

   Age Sex BP   Cholesterol Na_to_K  Drug
0  23   F HIGH  HIGH        25.355   drugY
1  47   M LOW   HIGH        13.093   drugC
2  47   M LOW   HIGH        10.114  drugC
3  28   F NORMAL HIGH        7.798  drugX
4  61   F LOW    HIGH       18.043  drugY
.. ... .. ... ... ... ...
195 56  F LOW    HIGH       11.567  drugC
196 16  M LOW    HIGH       12.006  drugC
197 52  M NORMAL HIGH     9.894 drugX
[200 rows x 6 columns]

資料預處理

下一步,我們對上述資料進行預處理,以獲取資料中不同文字值的數值。這有助於訓練和測試關於針對給定年齡、性別、血壓等值使用某種藥物的決策的樣本資料。

示例

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

datainput = pd.read_csv("drug.csv", delimiter=",")

X = datainput[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values

from sklearn import preprocessing
label_gender = preprocessing.LabelEncoder()
label_gender.fit(['F','M'])
X[:,1] = label_gender.transform(X[:,1])

label_BP = preprocessing.LabelEncoder()
label_BP.fit([ 'LOW', 'NORMAL', 'HIGH'])
X[:,2] = label_BP.transform(X[:,2])

label_Chol = preprocessing.LabelEncoder()
label_Chol.fit([ 'NORMAL', 'HIGH'])
X[:,3] = label_Chol.transform(X[:,3])

# Printing the first 6 records
print(X[0:6])

執行以上程式碼得到以下結果:

[[23 0 0 0 25.355]
   [47 1 1 0 13.093]
   [47 1 1 0 10.113999999999999]
   [28 0 2 0 7.797999999999999]
   [61 0 1 0 18.043]
   [22 0 2 0 8.607000000000001]
]

轉換因變數

接下來,我們還將因變數轉換為數值,以便它可以用於訓練和評估資料集。

示例

import pandas as pd

datainput = pd.read_csv("drug.csv", delimiter=",")
X = datainput[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values

y = datainput["Drug"]

print(y[0:6])

輸出

執行以上程式碼得到以下結果

0    drugY
1    drugC
2    drugC
3    drugX
4    drugY
5    drugX
Name: Drug, dtype: object

訓練資料集

接下來,我們使用提供的30%的資料作為訓練資料集。這將作為建立其餘70%(我們將稱為測試資料)分類的基礎。

示例

import pandas as pd
from sklearn.model_selection import train_test_split

datainput = pd.read_csv("drug.csv", delimiter=",")

X = datainput[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values

y = datainput["Drug"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=3)

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

輸出

執行以上程式碼得到以下結果

(140, 5)
(60, 5)
(140,)
(60,)

從訓練資料集獲取結果

接下來,我們可以應用決策樹來檢視訓練資料集的結果。在這裡,我們根據輸入建立一個樹,並使用稱為熵的標準。最後,我們計算決策樹的準確性。

示例

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics

datainput = pd.read_csv("drug.csv", delimiter=",")

X = datainput[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values

# Data Preprocessing
from sklearn import preprocessing

label_gender = preprocessing.LabelEncoder()
label_gender.fit(['F', 'M'])
X[:, 1] = label_gender.transform(X[:, 1])

label_BP = preprocessing.LabelEncoder()
label_BP.fit(['LOW', 'NORMAL', 'HIGH'])
X[:, 2] = label_BP.transform(X[:, 2])

label_Chol = preprocessing.LabelEncoder()
label_Chol.fit(['NORMAL', 'HIGH'])
X[:, 3] = label_Chol.transform(X[:, 3])

y = datainput["Drug"]

# train_test_split
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=3)

drugTree = DecisionTreeClassifier(criterion="entropy", max_depth=4)

drugTree.fit(X_train, y_train)
predicted = drugTree.predict(X_test)

print(predicted)

print("\nDecisionTrees's Accuracy: ", metrics.accuracy_score(y_test, predicted))

輸出

執行以上程式碼得到以下結果

['drugY' 'drugX' 'drugX' 'drugX' 'drugX' 'drugC' 'drugY' 'drugA' 'drugB'
'drugA' 'drugY' 'drugA' 'drugY' 'drugY' 'drugX' 'drugY' 'drugX' 'drugX'
'drugB' 'drugX' 'drugX' 'drugY' 'drugY' 'drugY' 'drugX' 'drugB' 'drugY'
'drugY' 'drugA' 'drugX' 'drugB' 'drugC' 'drugC' 'drugX' 'drugX' 'drugC'
'drugY' 'drugX' 'drugX' 'drugX' 'drugA' 'drugY' 'drugC' 'drugY' 'drugA'
'drugY' 'drugY' 'drugY' 'drugY' 'drugY' 'drugB' 'drugX' 'drugY' 'drugX'
'drugY' 'drugY' 'drugA' 'drugX' 'drugY' 'drugX']

DecisionTrees's Accuracy: 0.9833333333333333

更新於:2020年1月2日

1K+ 次瀏覽

啟動你的職業生涯

完成課程獲得認證

開始學習
廣告
© . All rights reserved.