如何使用TensorFlow儲存和載入MNIST資料集的權重?


TensorFlow是由Google提供的機器學習框架。它是一個開源框架,與Python結合使用,可以實現演算法、深度學習應用程式等等。它被用於研究和生產環境。它具有最佳化技術,有助於快速執行復雜的數學運算。這是因為它使用了NumPy和多維陣列。這些多維陣列也被稱為“張量”。

可以使用以下程式碼行在Windows上安裝“tensorflow”包:

pip install tensorflow

張量是TensorFlow中使用的資料結構。它有助於連線資料流圖中的邊。這個資料流圖被稱為“資料流圖”。張量只不過是多維陣列或列表。

當訓練時間過長時,模型往往會過擬合,並且在測試資料上的泛化能力較差。因此,訓練步驟的數量必須保持平衡。這意味著,必須考慮所有資料案例才能進行有效的訓練。這樣,模型在測試資料上的泛化能力更好。否則,可以進行正則化。

Keras是一個用Python編寫的深度學習API。它是一個高階API,具有高效的介面,可以幫助解決機器學習問題。它執行在TensorFlow框架之上。它的構建是為了幫助快速實驗。它提供了開發和封裝機器學習解決方案所必需的基本抽象和構建塊。

Keras已經存在於TensorFlow包中。可以使用以下程式碼行訪問它。

import tensorflow
from tensorflow import keras

我們使用Google Colaboratory執行以下程式碼。Google Colab或Colaboratory有助於在瀏覽器上執行Python程式碼,無需任何配置,並且可以免費訪問GPU(圖形處理單元)。Colaboratory構建在Jupyter Notebook之上。以下是程式碼片段:

示例

!pip install -q pyyaml h5py
import os

import tensorflow as tf
from tensorflow import keras

print("The version of Tensorflow is : ")
print(tf.version.VERSION)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
print("Splitting training and test data")
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

print("Reshaping the training and test data")
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

程式碼來源: https://www.tensorflow.org/tutorials/keras/save_and_load

輸出

解釋

  • 匯入所需的包併為其設定別名。

  • 獲取前1000個示例以提高執行速度。

更新於:2021年1月20日

瀏覽量:122

啟動您的職業生涯

完成課程獲得認證

開始學習
廣告