Apache MXNet - KVStore 和可视化



本章介绍 Python 包 KVStore 和可视化。

KVStore 包

KVStore 代表键值存储。它是多设备训练中使用的关键组件。它很重要,因为在单机或多机上,参数在设备之间的通信是通过一个或多个带有参数 KVStore 的服务器传输的。

让我们通过以下几点来了解 KVStore 的工作原理

  • KVStore 中的每个值都由一个和一个表示。

  • 网络中的每个参数数组都分配一个,而该参数数组的权重由表示。

  • 之后,工作节点在处理完一个批次后推送梯度。它们还在处理新批次之前拉取更新后的权重。

简单来说,我们可以说 KVStore 是一个数据共享的地方,每个设备都可以将数据推入和拉出。

数据推入和拉出

KVStore 可以被认为是跨不同设备(如 GPU 和计算机)共享的单个对象,每个设备都可以将数据推入和拉出。

以下是设备需要遵循的将数据推入和拉出的实现步骤

实现步骤

初始化 - 第一步是初始化值。在本例中,我们将一个 (int, NDArray) 对初始化到 KVStore 中,然后将值拉出 -

import mxnet as mx
kv = mx.kv.create('local') # create a local KVStore.
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())

输出

这将产生以下输出 -

[[2. 2. 2.]
[2. 2. 2.]
[2. 2. 2.]]

推送、聚合和更新 - 初始化后,我们可以将具有相同形状的新值推送到 KVStore 中的相同键中 -

kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a)
print(a.asnumpy())

输出

输出如下所示 -

[[8. 8. 8.]
 [8. 8. 8.]
 [8. 8. 8.]]

用于推送的数据可以存储在任何设备上,例如 GPU 或计算机。我们还可以将多个值推送到同一个键中。在这种情况下,KVStore 将首先对所有这些值求和,然后按如下方式推送聚合值 -

contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())

输出

您将看到以下输出 -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

对于每次应用的推送操作,KVStore 将把推送的值与已存储的值合并。这将借助更新器完成。这里,默认更新器是 ASSIGN。

def update(key, input, stored):
   print("update on key: %d" % key)
   
   stored += input * 2
kv.set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())

输出

执行上述代码时,您应该看到以下输出 -

[[4. 4. 4.]
 [4. 4. 4.]
 [4. 4. 4.]]

示例

kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())

输出

以下是代码的输出 -

update on key: 3
[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

拉取 - 与推送一样,我们也可以通过一次调用将值拉取到多个设备上,如下所示 -

b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

输出

输出如下所示 -

[[6. 6. 6.]
 [6. 6. 6.]
 [6. 6. 6.]]

完整的实现示例

以下是完整的实现示例 -

import mxnet as mx
kv = mx.kv.create('local')
shape = (3,3)
kv.init(3, mx.nd.ones(shape)*2)
a = mx.nd.zeros(shape)
kv.pull(3, out = a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape)*8)
kv.pull(3, out = a) # pull out the value
print(a.asnumpy())
contexts = [mx.cpu(i) for i in range(4)]
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.push(3, b)
kv.pull(3, out = a)
print(a.asnumpy())
def update(key, input, stored):
   print("update on key: %d" % key)
   stored += input * 2
kv._set_updater(update)
kv.pull(3, out=a)
print(a.asnumpy())
kv.push(3, mx.nd.ones(shape))
kv.pull(3, out=a)
print(a.asnumpy())
b = [mx.nd.ones(shape, ctx) for ctx in contexts]
kv.pull(3, out = b)
print(b[1].asnumpy())

处理键值对

我们上面实现的所有操作都涉及单个键,但 KVStore 还提供了一个键值对列表的接口 -

对于单个设备

以下是一个示例,演示了针对单个设备的键值对列表的 KVStore 接口 -

keys = [5, 7, 9]
kv.init(keys, [mx.nd.ones(shape)]*len(keys))
kv.push(keys, [mx.nd.ones(shape)]*len(keys))
b = [mx.nd.zeros(shape)]*len(keys)
kv.pull(keys, out = b)
print(b[1].asnumpy())

输出

您将收到以下输出 -

update on key: 5
update on key: 7
update on key: 9
[[3. 3. 3.]
 [3. 3. 3.]
 [3. 3. 3.]]

对于多个设备

以下是一个示例,演示了针对多个设备的键值对列表的 KVStore 接口 -

b = [[mx.nd.ones(shape, ctx) for ctx in contexts]] * len(keys)
kv.push(keys, b)
kv.pull(keys, out = b)
print(b[1][1].asnumpy())

输出

您将看到以下输出 -

update on key: 5
update on key: 7
update on key: 9
[[11. 11. 11.]
 [11. 11. 11.]
 [11. 11. 11.]]

可视化包

可视化包是 Apache MXNet 包,用于将神经网络 (NN) 表示为由节点和边组成的计算图。

可视化神经网络

在下面的示例中,我们将使用mx.viz.plot_network来可视化神经网络。以下是先决条件 -

先决条件

  • Jupyter notebook

  • Graphviz 库

实现示例

在下面的示例中,我们将可视化用于线性矩阵分解的示例 NN -

import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')

# Set the dummy dimensions
k = 64
max_user = 100
max_item = 50

# The user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)

# The item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)

# predict by the inner product and then do sum
N_net = user * item
N_net = mx.symbol.sum_axis(data = N_net, axis = 1)
N_net = mx.symbol.Flatten(data = N_net)

# Defining the loss layer
N_net = mx.symbol.LinearRegressionOutput(data = N_net, label = score)

# Visualize the network
mx.viz.plot_network(N_net)
广告