使用GridSearchCV和KerasClassifier進行超引數調整


在機器學習領域,超引數調整在最佳化模型效能方面發揮著至關重要的作用,而一種流行的超引數調整技術是將GridSearchCV與KerasClassifier結合使用。這種強大的組合允許資料科學家和機器學習從業者有效地探索和識別其深度學習模型的最佳超引數集。在本文中,我們將深入探討超引數調整的概念,瞭解GridSearchCV演算法,並探索如何將其與KerasClassifier一起使用。

瞭解超引數

超引數是機器學習模型中由資料科學家或機器學習從業人員設定的引數,而不是從資料本身學習的引數。它們定義了模型的行為和特徵,並且可以極大地影響其效能。超引數的示例包括學習率、批大小、神經網路中隱藏層的數量以及啟用函式的選擇。

超引數調整過程是開發機器學習模型的關鍵步驟。它涉及找到這些超引數的最佳值,這些值直接影響模型如何從資料中學習和泛化。透過仔細選擇和微調這些超引數,我們可以提高模型的效能,使其在進行預測或分類時更準確和可靠。

超引數調整的必要性

超引數調整具有重要意義,因為它允許我們為機器學習模型選擇最合適的超引數,從而導致其效能的顯著改進。透過微調超引數,我們可以提高模型的準確性,緩解過擬合問題,並增強其對新資料和未見資料的準確預測能力。最終,此過程使我們能夠建立一個經過良好最佳化的模型,該模型效能更好,並且能夠很好地推廣到訓練資料之外。

介紹GridSearchCV

GridSearchCV是一種用於超引數最佳化的技術。它系統地搜尋預定義的超引數集,並評估每個組合的模型效能。它窮舉嘗試每種可能的組合以識別最佳超引數集。

GridSearchCV工作流程

GridSearchCV的工作流程包括以下步驟:

  • 定義模型  指定要調整的機器學習模型。

  • 定義超引數網格  建立一個字典,其中包含要探索的超引數及其對應值。

  • 定義評分指標  選擇一個指標來評估模型的效能。

  • 執行網格搜尋  使用訓練資料和超引數網格擬合GridSearchCV物件。

  • 檢索最佳超引數  訪問GridSearchCV找到的最佳超引數。

  • 評估模型  使用最佳超引數訓練模型,並在測試資料上評估其效能。

使用KerasClassifier和GridSearchCV進行超引數調整

KerasClassifier是Keras庫中的一個包裝類,它允許我們將Keras模型與Scikit-learn的GridSearchCV一起使用。透過將KerasClassifier與GridSearchCV結合使用,我們可以輕鬆地調整使用Keras構建的深度學習模型的超引數。

要將KerasClassifier與GridSearchCV一起使用,我們需要將Keras模型定義為函式並將其傳遞給KerasClassifier。然後,我們可以透過指定超引數網格和評分指標來繼續進行常規的GridSearchCV工作流程。

以下是我們將用於使用KerasClassifier和GridSearchCV進行超引數調整的步驟:

演算法

  • 匯入所需的庫  此步驟匯入必要的庫和模組,例如NumPy、scikit-learn和Keras,以使用GridSearchCV和KerasClassifier執行超引數調整。

  • 載入資料集 

  • 將資料分成訓練集和測試集 

  • 定義一個函式來建立Keras模型:定義一個名為`create_model()`的函式來建立一個簡單的Keras模型。

  • 建立KerasClassifier物件 

  • 定義超引數網格  下面的程式定義了一個名為`param_grid`的字典,它指定了要調整的超引數及其對應值

  • 建立GridSearchCV物件

  • 將GridSearchCV物件擬合到訓練資料 

  • 列印最佳引數和得分:在測試資料上評估最佳模型 

示例

# Import the required libraries
import numpy as npp
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier

# Load the Iris dataset
irisd = load_iris()
X = irisd.data
y = irisd.target

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define a function to create the Keras model
def create_model(units=10, activation='relu'):
   model = Sequential()
   model.add(Dense(units=units, activation=activation, input_dim=4))
   model.add(Dense(units=3, activation='softmax'))
   model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
   return model

# Create the KerasClassifier object
model = KerasClassifier(build_fn=create_model)

# Define the hyperparameter grid to search over
param_grid = {
   'units': [5, 10, 15],
   'activation': ['relu', 'sigmoid']
}

# Create the GridSearchCV object
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=3)

# Fit the GridSearchCV object to the training data
grid_result = grid.fit(X_train, y_train)

# Print the best parameters and score
print("Best Parameters: ", grid_result.best_params_)
print("Best Score: ", grid_result.best_score_)

# Evaluate the best model on the test data
best_model = grid_result.best_estimator_
test_accuracy = best_model.score(X_test, y_test)
print("Test Accuracy: ", test_accuracy)

輸出

Best Parameters:  {'activation': 'sigmoid', 'units': 5}
Best Score:  0.42499999205271405
1/1 [==============================] - 0s 74ms/step - loss: 1.1070 - accuracy: 0.1667
Test Accuracy:  0.1666666716337204

使用GridSearchCV和KerasClassifier的好處

GridSearchCV和KerasClassifier的組合提供了以下幾個好處:

  • 自動超引數調整  GridSearchCV執行窮舉搜尋,使我們免於手動測試不同的組合。

  • 改進模型效能  透過識別最佳超引數集,我們可以提高模型的效能並獲得更好的結果。

  • 時間和資源效率  GridSearchCV優化了超引數搜尋過程,減少了所需的時間和計算資源。

超引數調整的最佳實踐

在執行超引數調整時,務必牢記以下最佳實踐:

  • 定義合理的搜尋空間  限制超引數的範圍,以避免低效搜尋或過擬合。

  • 利用交叉驗證  交叉驗證有助於評估模型的效能,並確保所選超引數能夠很好地泛化。

  • 考慮計算約束  注意超引數調整所需的計算資源,尤其是在大型資料集和複雜模型的情況下。

  • 跟蹤和記錄實驗  記錄不同的超引數設定及其對應的效能指標,以跟蹤進度並重現結果。

結論

總之,超引數調整是機器學習模型開發過程中至關重要的一步。GridSearchCV結合KerasClassifier提供了一種高效且自動化的方式來識別深度學習模型的最佳超引數。透過利用這種技術,資料科學家和機器學習從業者可以提高模型效能,獲得更好的結果,並節省時間和計算資源。

更新於:2023年7月11日

1K+ 閱讀量

開啟你的職業生涯

透過完成課程獲得認證

開始學習
廣告

© . All rights reserved.