Apache MXNet - KVStore 和視覺化



本章介紹 Python 包 KVStore 和視覺化。

KVStore 包

KVStore 代表鍵值儲存。它是多裝置訓練中使用的關鍵元件。它很重要,因為在單機或多機上,引數在裝置之間的通訊是透過一個或多個帶有引數 KVStore 的伺服器傳輸的。

讓我們透過以下幾點來了解 KVStore 的工作原理

  • KVStore 中的每個值都由一個和一個表示。

  • 網路中的每個引數陣列都分配一個,而該引數陣列的權重由表示。

  • 之後,工作節點在處理完一個批次後推送梯度。它們還在處理新批次之前拉取更新後的權重。

簡單來說,我們可以說 KVStore 是一個數據共享的地方,每個裝置都可以將資料推入和拉出。

資料推入和拉出

KVStore 可以被認為是跨不同裝置(如 GPU 和計算機)共享的單個物件,每個裝置都可以將資料推入和拉出。

以下是裝置需要遵循的將資料推入和拉出的實現步驟

實現步驟

初始化 - 第一步是初始化值。在本例中,我們將一個 (int, NDArray) 對初始化到 KVStore 中,然後將值拉出 -

import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())

輸出

這將產生以下輸出 -

[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]

推送、聚合和更新 - 初始化後,我們可以將具有相同形狀的新值推送到 KVStore 中的相同鍵中 -

kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())

輸出

輸出如下所示 -

[[8. 8. 8.]
 [8. 8. 8.]
 [8. 8. 8.]]

用於推送的資料可以儲存在任何裝置上,例如 GPU 或計算機。我們還可以將多個值推送到同一個鍵中。在這種情況下,KVStore 將首先對所有這些值求和,然後按如下方式推送聚合值 -

contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())

輸出

您將看到以下輸出 -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

對於每次應用的推送操作,KVStore 將把推送的值與已儲存的值合併。這將藉助更新器完成。這裡,預設更新器是 ASSIGN。

def update(key, input, stored):
   print("update on key: %d" % key)
   
   stored += input * 2
kv.set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())

輸出

執行上述程式碼時,您應該看到以下輸出 -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

示例

kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())

輸出

以下是程式碼的輸出 -

update on key: 3
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

拉取 - 與推送一樣,我們也可以透過一次呼叫將值拉取到多個裝置上,如下所示 -

b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

輸出

輸出如下所示 -

[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

完整的實現示例

以下是完整的實現示例 -

import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
   print("update on key: %d" % key)
   stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

處理鍵值對

我們上面實現的所有操作都涉及單個鍵,但 KVStore 還提供了一個鍵值對列表的介面 -

對於單個裝置

以下是一個示例,演示了針對單個裝置的鍵值對列表的 KVStore 介面 -

keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())

輸出

您將收到以下輸出 -

update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
 [3. 3. 3.]
 [3. 3. 3.]]

對於多個裝置

以下是一個示例,演示了針對多個裝置的鍵值對列表的 KVStore 介面 -

b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())

輸出

您將看到以下輸出 -

update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
 [11. 11. 11.]
 [11. 11. 11.]]

視覺化包

視覺化包是 Apache MXNet 包,用於將神經網路 (NN) 表示為由節點和邊組成的計算圖。

視覺化神經網路

在下面的示例中,我們將使用mx.viz.plot_network來視覺化神經網路。以下是先決條件 -

先決條件

  • Jupyter notebook

  • Graphviz 庫

實現示例

在下面的示例中,我們將視覺化用於線性矩陣分解的示例 NN -

import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')

# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50

# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)

# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)

# predict by the inner product and then do sum
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)

# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)

# Visualize the network
mx.viz.plot_network(N_net)
廣告
© . All rights reserved.