機器學習 - 過擬合



過擬合是指模型學習訓練資料中的噪聲,而不是潛在模式。這會導致模型在訓練資料上表現良好,但在新資料上表現不佳。從本質上講,模型變得過於專門化,無法泛化到新資料。

在使用複雜模型(例如深度神經網路)時,過擬合是一個常見問題。這些模型具有許多引數,並且能夠非常緊密地擬合訓練資料。但是,這通常是以犧牲泛化效能為代價的。

過擬合的原因

有幾個因素可能導致過擬合 -

  • 複雜模型 - 如前所述,複雜模型比簡單模型更容易過擬合。這是因為它們具有更多引數,並且能夠更緊密地擬合訓練資料。

  • 訓練資料有限 - 當訓練資料不足時,模型難以學習潛在模式,反而可能學習資料中的噪聲。

  • 訓練資料不具有代表性 - 如果訓練資料不能代表模型試圖解決的問題,那麼模型可能會學習不相關的模式,這些模式無法很好地泛化到新資料。

  • 缺乏正則化 - 正則化是一種透過向成本函式新增懲罰項來防止過擬合的技術。如果不存在此懲罰項,則模型更容易過擬合。

防止過擬合的技術

有幾種技術可用於防止機器學習中的過擬合 -

  • 交叉驗證 - 交叉驗證是一種用於評估模型在新資料上的效能的技術。它涉及將資料分成幾個子集,並依次使用每個子集作為驗證集,同時在剩餘資料上進行訓練。這有助於確保模型能夠很好地泛化到新資料。

  • 提前停止 - 提前停止是一種透過在訓練過程完全收斂之前停止訓練過程來防止模型過擬合的技術。這是透過在訓練期間監控驗證誤差,並在誤差停止改善時停止訓練來完成的。

  • 正則化 - 正則化是一種透過向成本函式新增懲罰項來防止過擬合的技術。懲罰項鼓勵模型具有較小的權重,並有助於防止其擬合訓練資料中的噪聲。

  • Dropout - Dropout 是一種用於深度神經網路中防止過擬合的技術。它涉及在訓練期間隨機丟棄一些神經元,這迫使剩餘的神經元學習更健壯的特徵。

示例

以下是使用 Keras 在 Python 中實現提前停止和 L2 正則化的示例 -

from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping
from keras import regularizers

# define the model architecture
model = Sequential()
model.add(Dense(64, input_dim=X_train.shape[1], activation='relu', kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(1, activation='sigmoid'))

# compile the model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# set up early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=5)

# train the model with early stopping and L2 regularization
history = model.fit(X_train, y_train, validation_split=0.2, epochs=100, batch_size=64, callbacks=[early_stopping])

在此程式碼中,我們使用了 Keras 中的 Sequential 模型來定義模型架構,並且使用 kernel_regularizer 引數向前兩層添加了 L2 正則化。我們還使用 Keras 中的 EarlyStopping 類設定了一個提前停止回撥,它將監控驗證損失並在其停止改善 5 個時期後停止訓練。

在訓練期間,我們傳入 X_train 和 y_train 資料以及 0.2 的驗證拆分以監控驗證損失。我們還設定了 64 的批大小並最多訓練 100 個時期。

輸出

執行此程式碼時,它將生成如下所示的輸出 -

Train on 323 samples, validate on 81 samples
Epoch 1/100
323/323 [==============================] - 0s 792us/sample - loss: -8.9033 - accuracy: 0.0000e+00 - val_loss: -15.1467 - val_accuracy: 0.0000e+00
Epoch 2/100
323/323 [==============================] - 0s 46us/sample - loss: -20.4505 - accuracy: 0.0000e+00 - val_loss: -25.7619 - val_accuracy: 0.0000e+00
Epoch 3/100
323/323 [==============================] - 0s 43us/sample - loss: -31.9206 - accuracy: 0.0000e+00 - val_loss: -36.8155 - val_accuracy: 0.0000e+00
Epoch 4/100
323/323 [==============================] - 0s 46us/sample - loss: -44.2281 - accuracy: 0.0000e+00 - val_loss: -49.0378 - val_accuracy: 0.0000e+00
Epoch 5/100
323/323 [==============================] - 0s 52us/sample - loss: -58.3326 - accuracy: 0.0000e+00 - val_loss: -62.9369 - val_accuracy: 0.0000e+00
Epoch 6/100
323/323 [==============================] - 0s 40us/sample - loss: -74.2131 - accuracy: 0.0000e+00 - val_loss: -78.7068 - val_accuracy: 0.0000e+00
-----continue

透過使用提前停止和 L2 正則化,我們可以幫助防止過擬合併提高模型的泛化效能。

廣告