提升機器學習模型效能(續…)



透過演算法調整提升效能

眾所周知,機器學習模型的引數化方式使得其行為可以針對特定問題進行調整。演算法調整意味著找到這些引數的最佳組合,從而提高機器學習模型的效能。此過程有時稱為超引數最佳化,演算法本身的引數稱為超引數,而機器學習演算法找到的係數稱為引數。

在這裡,我們將討論Python Scikit-learn提供的一些演算法引數調整方法。

網格搜尋引數調整

這是一種引數調整方法。這種方法的關鍵在於它系統地構建和評估網格中指定的每個可能的演算法引數組合的模型。因此,可以說這種演算法具有搜尋特性。

示例

在下面的Python示例中,我們將使用sklearn的GridSearchCV類進行網格搜尋,以評估皮馬印第安人糖尿病資料集上嶺迴歸演算法的各種alpha值。

首先,匯入所需的包,如下所示:

import numpy
from pandas import read_csv
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

現在,我們需要像之前的例子一樣載入皮馬糖尿病資料集:

path = r"C:\pima-indians-diabetes.csv"
headernames = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
data = read_csv(path, names=headernames)
array = data.values
X = array[:,0:8]
Y = array[:,8]

接下來,評估各種alpha值,如下所示:

alphas = numpy.array([1,0.1,0.01,0.001,0.0001,0])
param_grid = dict(alpha=alphas)

現在,我們需要在我們的模型上應用網格搜尋:

model = Ridge()
grid = GridSearchCV(estimator=model, param_grid=param_grid)
grid.fit(X, Y)

使用以下指令碼行列印結果:

print(grid.best_score_)
print(grid.best_estimator_.alpha)

輸出

0.2796175593129722
1.0

上面的輸出給我們提供了最佳分數以及在網格中達到該分數的引數集。在這種情況下,alpha值為1.0。

隨機搜尋引數調整

這是一種引數調整方法。這種方法的關鍵在於它在固定的迭代次數內從隨機分佈中取樣演算法引數。

示例

在下面的Python示例中,我們將使用sklearn的RandomizedSearchCV類進行隨機搜尋,以評估皮馬印第安人糖尿病資料集上嶺迴歸演算法的0到1之間的不同alpha值。

首先,匯入所需的包,如下所示:

import numpy
from pandas import read_csv
from scipy.stats import uniform
from sklearn.linear_model import Ridge
from sklearn.model_selection import RandomizedSearchCV

現在,我們需要像之前的例子一樣載入皮馬糖尿病資料集:

path = r"C:\pima-indians-diabetes.csv"
headernames = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']
data = read_csv(path, names=headernames)
array = data.values
X = array[:,0:8]
Y = array[:,8]

接下來,評估嶺迴歸演算法上的各種alpha值,如下所示:

param_grid = {'alpha': uniform()}
model = Ridge()
random_search = RandomizedSearchCV(estimator=model, param_distributions=param_grid, n_iter=50,
random_state=7)
random_search.fit(X, Y)

使用以下指令碼行列印結果:

print(random_search.best_score_)
print(random_search.best_estimator_.alpha)

輸出

0.27961712703051084
0.9779895119966027

上面的輸出與網格搜尋類似,也給了我們最佳分數。

廣告