機器學習 - 網格搜尋



網格搜尋是一種機器學習中的超引數調整技術,它有助於為給定模型找到最佳的超引數組合。它的工作原理是定義一個超引數網格,然後使用所有可能的超引數組合訓練模型,以找到效能最佳的組合。

換句話說,網格搜尋是一種窮舉搜尋方法,其中定義了一組超引數,並在這些超引數的所有可能組合上執行搜尋,以找到提供最佳效能的最佳值。

Python中的實現

在Python中,可以使用sklearn模組中的GridSearchCV類實現網格搜尋。GridSearchCV類以模型、要調整的超引數和評分函式作為輸入。然後,它對所有可能的超引數組合執行窮舉搜尋,並返回提供最佳分數的最佳超引數集。

以下是使用GridSearchCV類在Python中實現網格搜尋的示例:

示例

from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

# Generate a sample dataset
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)

# Define the model and the hyperparameters to tune
model = RandomForestClassifier()
hyperparameters = {'n_estimators': [10, 50, 100], 'max_depth': [None, 5, 10]}

# Define the Grid Search object and fit the data
grid_search = GridSearchCV(model, hyperparameters, scoring='accuracy', cv=5)
grid_search.fit(X, y)

# Print the best hyperparameters and the corresponding score
print("Best hyperparameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)

在此示例中,我們定義了一個RandomForestClassifier模型和一組要調整的超引數,即樹的數量(n_estimators)和每棵樹的最大深度(max_depth)。然後,我們建立一個GridSearchCV物件並使用fit()方法擬合數據。最後,我們列印最佳超引數集和相應的分數。

輸出

執行此程式碼時,將產生以下輸出:

Best hyperparameters: {'max_depth': None, 'n_estimators': 10}
Best score: 0.953
廣告

© . All rights reserved.