如何使用TensorFlow和預訓練網路,即遷移學習來載入資料?


TensorFlow可以使用預訓練網路進行遷移學習來載入資料,方法是使用Keras包中的`get_file`方法。Google API儲存資料集,可以將其作為引數傳遞給`get_file`方法,以便將資料集下載到當前環境。

閱讀更多: 什麼是TensorFlow以及Keras如何與TensorFlow一起建立神經網路?

我們將瞭解如何藉助來自預訓練網路的遷移學習來對貓和狗的影像進行分類。

影像分類遷移學習背後的直覺是,如果一個模型在大型通用資料集上進行訓練,則該模型可以有效地用作視覺世界的通用模型。它已經學習了特徵圖,這意味著使用者不必從頭開始在一個大型資料集上訓練大型模型。

閱讀更多: 如何預訓練定製模型?

我們使用Google Colaboratory執行以下程式碼。Google Colab或Colaboratory幫助在瀏覽器上執行Python程式碼,無需任何配置,並可免費訪問GPU(圖形處理單元)。Colaboratory構建在Jupyter Notebook之上。

示例

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
print("Downloading the data")
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
BATCH_SIZE = 32
IMG_SIZE = (160, 160)

程式碼來源 −https://www.tensorflow.org/tutorials/images/transfer_learning

輸出

Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 1s 0us/step
Downloading the data

解釋

  • 在這個資料集中,我們有數千張貓和狗的影像。

  • 它們已下載並從zip檔案中解壓。

  • 建立了一個`tf.data.Dataset`,用於訓練和驗證。

  • 這是藉助`tf.keras.preprocessing.image_dataset_from_directory`實用程式完成的。

更新於: 2021年2月25日

85 次瀏覽

啟動您的職業生涯

完成課程獲得認證

開始學習
廣告