機器學習 - 迭代週期 (Epoch)



在機器學習中,迭代週期(Epoch)指的是在模型訓練過程中完整遍歷整個訓練資料集的一次迭代。簡單來說,就是在訓練階段演算法遍歷整個資料集的次數。

在訓練過程中,演算法對訓練資料進行預測,計算損失,並更新模型引數以減少損失。目標是透過最小化損失函式來最佳化模型的效能。當模型對所有訓練資料都進行了預測後,一個迭代週期就完成了。

迭代週期是訓練過程中的一個重要引數,因為它會顯著影響模型的效能。迭代週期設定得太低會導致模型欠擬合,而設定得太高則會導致過擬合。

欠擬合是指模型未能捕捉資料中的潛在模式,在訓練集和測試集上的表現都很差。當模型過於簡單或訓練不足時就會發生這種情況。在這種情況下,增加迭代週期可以幫助模型從資料中學習更多資訊並提高其效能。

另一方面,過擬合是指模型學習了訓練資料中的噪聲,在訓練集上表現良好,但在測試集上表現很差。當模型過於複雜或訓練迭代週期過多時就會發生這種情況。為了避免過擬合,必須限制迭代週期的數量,並使用其他正則化技術,例如提前停止或 dropout。

Python 實現

在 Python 中,迭代週期的數量是在機器學習模型的訓練迴圈中指定的。例如,當使用 Keras 庫訓練神經網路時,可以使用 "fit" 方法中的 "epochs" 引數設定迭代週期的數量。

示例

# import necessary libraries
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# generate some random data for training
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, size=(100,))

# create a neural network model
model = Sequential()
model.add(Dense(16, input_dim=10, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# compile the model with binary cross-entropy loss and adam optimizer
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# train the model with 10 epochs
model.fit(X_train, y_train, epochs=10)

在這個例子中,我們生成一些隨機訓練資料,並建立一個簡單的神經網路模型,該模型包含一個輸入層、一個隱藏層和一個輸出層。我們使用二元交叉熵損失和 Adam 最佳化器編譯模型,並在 "fit" 方法中將迭代週期數設定為 10。

在訓練過程中,模型對訓練資料進行預測,計算損失,並更新權重以最小化損失。完成 10 個迭代週期後,模型被認為已完成訓練,我們可以使用它對新的、未見過的資料進行預測。

輸出

執行此程式碼時,將產生類似這樣的輸出:

Epoch 1/10
4/4 [==============================] - 31s 2ms/step - loss: 0.7012 - accuracy: 0.4976
Epoch 2/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6995 - accuracy: 0.4390
Epoch 3/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6921 - accuracy: 0.5123
Epoch 4/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6778 - accuracy: 0.5474
Epoch 5/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6819 - accuracy: 0.5542
Epoch 6/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6795 - accuracy: 0.5377
Epoch 7/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6840 - accuracy: 0.5303
Epoch 8/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6795 - accuracy: 0.5554
Epoch 9/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6706 - accuracy: 0.5545
Epoch 10/10
4/4 [==============================] - 0s 1ms/step - loss: 0.6722 - accuracy: 0.5556
廣告