如何使用 Python 在 Tensorflow 中新增密集層?
可以使用“add”方法將密集層新增到順序模型中,並將層型別指定為“Dense”。首先將層展平,然後新增一層。此新層將應用於整個訓練資料集。
閱讀更多: 什麼是 TensorFlow 以及 Keras 如何與 TensorFlow 協作建立神經網路?
我們將使用 Keras 順序 API,它有助於構建一個順序模型,該模型用於處理簡單的層堆疊,其中每一層只有一個輸入張量和一個輸出張量。
我們正在使用 Google Colaboratory 來執行以下程式碼。Google Colab 或 Colaboratory 幫助在瀏覽器上執行 Python 程式碼,無需任何配置,並可免費訪問 GPU(圖形處理單元)。Colaboratory 建立在 Jupyter Notebook 之上。
print("Adding dense layer on top")
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))
print("Complete architecture of the model")
model.summary()程式碼來源:https://www.tensorflow.org/tutorials/images/cnn
輸出
Adding dense layer on top Complete architecture of the model Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_3 (Conv2D) (None, 30, 30, 32) 896 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 15, 15, 32) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 13, 13, 64) 18496 _________________________________________________________________ max_pooling2d_3 (MaxPooling2 (None, 6, 6, 64) 0 _________________________________________________________________ conv2d_5 (Conv2D) (None, 4, 4, 64) 36928 _________________________________________________________________ flatten (Flatten) (None, 1024) 0 _________________________________________________________________ dense (Dense) (None, 64) 65600 _________________________________________________________________ dense_1 (Dense) (None, 10) 650 ================================================================= Total params: 122,570 Trainable params: 122,570 Non-trainable params: 0 _________________________________________________________________
解釋
- 為了完成模型,將來自卷積基的最後一個輸出張量(形狀為 (4, 4, 64))饋送到一個或多個密集層以執行分類。
- 密集層將以向量作為輸入(為 1D),而當前輸出為 3D 張量。
- 接下來,將 3D 輸出展平為 1D,並在其上新增一個或多個密集層。
- CIFAR 有 10 個輸出類別,因此添加了一個具有 10 個輸出的最終密集層。
- 在經過兩個密集層之前,(4, 4, 64) 輸出被展平為形狀為 (1024) 的向量。
廣告
資料結構
網路
關係資料庫管理系統
作業系統
Java
iOS
HTML
CSS
Android
Python
C 程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP