如何使用 Keras 在 Python 中建立回撥函式並儲存權重?


Tensorflow 是 Google 提供的一個機器學習框架。它是一個開源框架,與 Python 結合使用以實現演算法、深度學習應用程式等等。它用於研究和生產目的。它具有有助於快速執行復雜數學運算的最佳化技術。這是因為它使用了 NumPy 和多維陣列。這些多維陣列也稱為“張量”。

可以使用以下程式碼行在 Windows 上安裝“tensorflow”包:

pip install tensorflow

張量是 TensorFlow 中使用的資料結構。它有助於連線流圖中的邊。此流圖稱為“資料流圖”。張量不過是多維陣列或列表。

Keras 是作為 ONEIROS(開放式神經電子智慧機器人作業系統)專案研究的一部分開發的。Keras 是一個用 Python 編寫的深度學習 API。它是一個高階 API,具有一個高效的介面,有助於解決機器學習問題。它執行在 Tensorflow 框架之上。它的構建是為了幫助快速進行實驗。它提供了開發和封裝機器學習解決方案所必需的基本抽象和構建塊。

Keras 已經存在於 Tensorflow 包中。可以使用以下程式碼行訪問它。

import tensorflow
from tensorflow import keras

我們使用 Google Colaboratory 來執行以下程式碼。Google Colab 或 Colaboratory 幫助透過瀏覽器執行 Python 程式碼,無需任何配置,並且可以免費訪問 GPU(圖形處理單元)。Colaboratory 構建在 Jupyter Notebook 之上。以下是程式碼:

示例

print("Set checkpoint path")
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

print("Creating a callback to save the weights")
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)

print("Model is trained with new callback")
model.fit(train_images,
   train_labels,
   epochs=10,
   validation_data=(test_images, test_labels),
   callbacks=[cp_callback])
ls {checkpoint_dir}

程式碼來源 -  https://www.tensorflow.org/tutorials/keras/save_and_load

輸出

解釋

  • 訓練好的模型可以在不重新訓練或從其停止點開始訓練的情況下使用。

  • “ModelCheckpoint”方法在訓練期間和訓練結束時持續儲存模型。

  • 這樣,檢查點檔案在每個 epoch 之後都會更新。

  • 此模型適合訓練資料。

更新於: 20-Jan-2021

81 次瀏覽

開啟您的 職業生涯

透過完成課程獲得認證

開始學習
廣告