TensorFlow 中的線性分類器


由於其簡單性和有效性,線性分類器長期以來一直是機器學習的支柱。一個名為 TensorFlow 的流行機器學習框架為這些模型提供了全面的支援。本文介紹了 TensorFlow 的線性分類器,解釋了它們的工作原理以及如何在應用程式中使用它們。

瞭解線性分類器

線性分類器使用直線、平面或超平面將資料劃分為不同的類別。由於分割線相對於輸入空間是線性的,因此稱為“線性”邊界。二元或多類線性分類器應用於輸入和輸出之間關係大致線性的問題。

TensorFlow:簡要概述

TensorFlow 是一個開源機器學習框架,由 Google Brain 團隊建立。它提供了一個完整的工具、庫和社群資源生態系統,用於構建機器學習演算法和模型。TensorFlow 的主要優勢在於它能夠進行高階和低階計算,這使得使用者能夠相對輕鬆地構建複雜的機器學習模型。

使用 TensorFlow 實現線性分類器

為了建立線性分類器,TensorFlow 提供了 tf.estimator API,特別是 tf.estimator.LinearClassifier。它包含構建、評估、預測和使用模型所涉及的所有推理。

安裝 TensorFlow

首先確保 TensorFlow 已安裝。使用 pip 來完成此操作

pip install tensorflow

示例 1:簡單的線性分類器

讓我們來看一個簡單的示例,其中我們使用線性分類器對 Iris 資料集進行分類。Iris 多變數資料集是由英國統計學家和生物學家 Ronald Fisher 開發的。它包含三種鳶尾花物種的 50 個樣本。

首先讓我們載入 Iris 資料集,然後匯入所需的庫 -

import tensorflow as tf
from sklearn import datasets

# Load Iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

定義特徵列後,我們將構建線性分類器

# Define feature columns
feature_columns = [tf.feature_column.numeric_column('x', shape=X.shape[1:])]

# Build linear classifier
classifier = tf.estimator.LinearClassifier(feature_columns=feature_columns, n_classes=3)

# Define input function
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
   x={'x': X},
   y=y,
   num_epochs=None,
   shuffle=True
)

# Train the classifier
classifier.train(input_fn=input_fn, steps=5000)

此程式碼中的特徵列首先被定義,它們描述了資料集中每個特徵的資料型別。然後,我們使用 tf.estimator.LinearClassifier 構建線性分類器。我們使用 numpy_input_fn 函式將我們的資料饋送到分類器,然後使用 .train() 方法訓練分類器。

示例 2:評估分類器

現在分類器已經過訓練,我們可以評估其效能。在這個例子中,我們將使用 Iris 資料集的一部分,這些資料我們沒有用於訓練 -

# Define the test inputs
test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
   x={'x': X_test},
   y=y_test,
   num_epochs=1,
   shuffle=False
)

# Evaluate accuracy
accuracy_score = classifier.evaluate(input_fn=test_input_fn)['accuracy']

print(f'\nTest Accuracy: {accuracy_score}\n')

在此示例中,我們為測試資料建立了一個新的輸入函式,然後使用 .evaluate() 方法評估分類器的準確性。

示例 3:進行預測

我們可以使用我們訓練過的分類器對新資料進行預測。讓我們透過使用我們的分類器預測新花的種類來演示這一點

# New flower data
new_flower = np.array([[5.1, 3.3, 1.7, 0.5]], dtype=float)

# Define the input function for predictions
predict_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
   x={'x': new_flower},
   num_epochs=1,
   shuffle=False
)

# Get the predictions
predictions = list(classifier.predict(input_fn=predict_input_fn))
predicted_class = predictions[0]['class_ids'][0]

print(f'\nPredicted Iris Class: {predicted_class}\n')

在此示例中,我們使用四個度量標準來定義一朵新花。然後,我們使用訓練過的分類器預測新花的類別。結果是預測的鳶尾花種類。

結論

線性分類器是最簡單但最有效的機器學習模型之一,尤其是在處理線性可分資料時。透過提供一種簡單而靈活的方法來建立線性分類器,TensorFlow 的 tf.estimator API 使得在您自己的應用程式中使用這些模型變得更加容易。

在這篇文章中,介紹了線性分類器的概念,並使用 TensorFlow 演示瞭如何使用它們。我們討論瞭如何構建分類器、評估其有效性和使用新資料進行預測。這些示例顯示了構建和應用線性分類器的基本步驟。

請記住,結果的質量在很大程度上取決於您使用的資料集以及用於準備它的方法,例如特徵選擇和資料歸一化。始終使用測試集評估分類器,以獲得對分類器效能的有意義的理解。

TensorFlow 是一款非常強大的工具,它提供了一系列功能來構建複雜的機器學習模型。這僅僅是它對線性分類器的支援的冰山一角。隨著您進行更多的研究,您將發現各種最先進的方法和技術來構建可靠且有效的機器學習模型。

更新於: 2023年7月18日

253 次瀏覽

開啟您的 職業生涯

透過完成課程獲得認證

開始學習
廣告
© . All rights reserved.