Caffe2 - 建立你自己的網路



在本課中,你將學習如何在 Caffe2 中定義一個單層神經網路 (NN),並在隨機生成的資料集上執行它。我們將編寫程式碼來圖形化地描繪網路架構,列印輸入、輸出、權重和偏差值。為了理解本課,你必須熟悉神經網路架構、其術語和其中使用的數學

網路架構

讓我們假設我們想構建一個如下圖所示的單層神經網路:

Network Architecture

從數學上講,這個網路由以下 Python 程式碼表示:

Y = X * W^T + b

其中X,W,b是張量,Y是輸出。我們將用一些隨機資料填充所有三個張量,執行網路並檢查Y輸出。為了定義網路和張量,Caffe2 提供了幾個運算元函式。

Caffe2 運算元

在 Caffe2 中,運算元是計算的基本單元。Caffe2 運算元表示如下。

Caffe2 Operators

Caffe2 提供了詳盡的運算元列表。對於我們目前正在設計的網路,我們將使用稱為 FC 的運算元,它計算將輸入向量X傳遞到具有二維權重矩陣W和一維偏差向量b的全連線網路的結果。換句話說,它計算以下數學方程

Y = X * W^T + b

其中X的維度為(M x k)W的維度為(n x k)b(1 x n)。輸出Y的維度將為(M x n),其中M是批次大小。

對於向量XW,我們將使用GaussianFill運算元來建立一些隨機資料。為了生成偏差值b,我們將使用ConstantFill運算元。

我們現在將繼續定義我們的網路。

建立網路

首先,匯入所需的包:

from caffe2.python import core, workspace

接下來,透過呼叫core.Net來定義網路,如下所示:

net = core.Net("SingleLayerFC")

網路的名稱指定為SingleLayerFC。此時,名為 net 的網路物件被建立。到目前為止,它不包含任何層。

建立張量

我們現在將建立網路所需的三個向量。首先,我們將透過呼叫GaussianFill運算元來建立 X 張量,如下所示:

X = net.GaussianFill([], ["X"], mean=0.0, std=1.0, shape=[2, 3], run_once=0)

X向量的維度為2 x 3,平均資料值為 0.0,標準差為1.0

同樣,我們建立W張量,如下所示:

W = net.GaussianFill([], ["W"], mean=0.0, std=1.0, shape=[5, 3], run_once=0)

W向量的尺寸為5 x 3

最後,我們建立大小為 5 的偏差b矩陣。

b = net.ConstantFill([], ["b"], shape=[5,], value=1.0, run_once=0)

現在,程式碼中最重要的部分來了,那就是定義網路本身。

定義網路

我們在以下 Python 語句中定義網路:

Y = X.FC([W, b], ["Y"])

我們在輸入資料X上呼叫FC運算元。權重在W中指定,偏差在 b 中指定。輸出是Y。或者,你可以使用以下更詳細的 Python 語句建立網路。

Y = net.FC([X, W, b], ["Y"])

此時,網路只是被建立了。在我們至少執行一次網路之前,它不包含任何資料。在執行網路之前,我們將檢查其架構。

列印網路架構

Caffe2 在 JSON 檔案中定義網路架構,可以透過在建立的net物件上呼叫 Proto 方法來檢查。

print (net.Proto())

這將產生以下輸出:

name: "SingleLayerFC"
op {
   output: "X"
   name: ""
   type: "GaussianFill"
   arg {
      name: "mean"
      f: 0.0
   }
   arg {
      name: "std"
      f: 1.0
   }
   arg {
      name: "shape"
      ints: 2
      ints: 3
   }
   arg {
      name: "run_once"
      i: 0
   }
}
op {
   output: "W"
   name: ""
   type: "GaussianFill"
   arg {
      name: "mean"
      f: 0.0
   }
   arg {
      name: "std"
      f: 1.0
   }
   arg {
      name: "shape"
      ints: 5
      ints: 3
   }
   arg {
      name: "run_once"
      i: 0
   }
}
op {
   output: "b"
   name: ""
   type: "ConstantFill"
   arg {
      name: "shape"
      ints: 5
   }
   arg {
      name: "value"
      f: 1.0
   }
   arg {
      name: "run_once"
      i: 0
   }
}
op {
   input: "X"
   input: "W"
   input: "b"
   output: "Y"
   name: ""
   type: "FC"
}

正如你在上面的列表中看到的,它首先定義了運算元X,Wb。讓我們以W的定義為例進行檢查。W的型別指定為GausianFill均值定義為浮點數0.0,標準差定義為浮點數1.0形狀5 x 3

op {
   output: "W"
   name: "" type: "GaussianFill"
   arg {
      name: "mean" 
	   f: 0.0
   }
   arg { 
      name: "std" 
      f: 1.0
   }
   arg { 
      name: "shape" 
      ints: 5 
      ints: 3
   }
   ...
}

檢查Xb的定義以加深你的理解。最後,讓我們看看我們單層網路的定義,這裡將其複製如下

op {
   input: "X"
   input: "W"
   input: "b"
   output: "Y"
   name: ""
   type: "FC"
}

在這裡,網路型別為FC(全連線),輸入為X,W,b,輸出為Y。這種網路定義過於冗長,對於大型網路,檢查其內容將變得乏味。幸運的是,Caffe2 為建立的網路提供了圖形表示。

網路圖形表示

要獲得網路的圖形表示,請執行以下程式碼片段,它實際上只有兩行 Python 程式碼。

from caffe2.python import net_drawer
from IPython import display
graph = net_drawer.GetPydotGraph(net, rankdir="LR")
display.Image(graph.create_png(), width=800)

執行程式碼時,你會看到以下輸出:

Graphical Representation

對於大型網路,圖形表示在視覺化和除錯網路定義錯誤方面變得非常有用。

最後,現在是執行網路的時候了。

執行網路

你可以透過在workspace物件上呼叫RunNetOnce方法來執行網路:

workspace.RunNetOnce(net)

網路執行一次後,我們隨機生成的所有資料都將被建立,輸入網路,並建立輸出。執行網路後建立的張量在 Caffe2 中稱為blobs。工作區包含你建立並存儲在記憶體中的blobs。這與 Matlab 非常相似。

執行網路後,你可以使用以下print命令檢查工作區包含的blobs

print("Blobs in the workspace: {}".format(workspace.Blobs()))

你會看到以下輸出:

Blobs in the workspace: ['W', 'X', 'Y', 'b']

請注意,工作區包含三個輸入 blobs——X,Wb。它還包含名為Y的輸出 blob。現在讓我們檢查這些 blobs 的內容。

for name in workspace.Blobs():
   print("{}:\n{}".format(name, workspace.FetchBlob(name)))

你會看到以下輸出:

W:
[[ 1.0426593 0.15479846 0.25635982]
[-2.2461145 1.4581774 0.16827184]
[-0.12009818 0.30771437 0.00791338]
[ 1.2274994 -0.903331 -0.68799865]
[ 0.30834186 -0.53060573 0.88776857]]
X:
[[ 1.6588869e+00 1.5279824e+00 1.1889904e+00]
[ 6.7048723e-01 -9.7490678e-04 2.5114202e-01]]
Y:
[[ 3.2709925 -0.297907 1.2803618 0.837985 1.7562964]
[ 1.7633215 -0.4651525 0.9211631 1.6511179 1.4302125]]
b:
[1. 1. 1. 1. 1.]

請注意,你機器上的資料,或者事實上網路的每次執行的資料都會有所不同,因為所有輸入都是隨機建立的。你現在已成功定義了一個網路並在你的計算機上執行它。

廣告
© . All rights reserved.