mxnet.kvstore

Key value store interface of MXNet for parameter synchronization.

Classes

KVStore(handle)

A key-value store for synchronization of values, over multiple devices.

Functions

create([name])

Creates a new KVStore.

class mxnet.kvstore.KVStore(handle)

Bases: object

A key-value store for synchronization of values, over multiple devices.

Methods

init(key, value)

Initializes a single or a sequence of key-value pairs into the store.

load_optimizer_states(fname)

Loads the optimizer (updater) state from the file.

pull(key[, out, priority, ignore_sparse])

Pulls a single value or a sequence of values from the store.

push(key, value[, priority])

Pushes a single or a sequence of key-value pairs into the store.

pushpull(key, value[, out, priority])

Pushes and then pulls a single value or a sequence of values to and from the store.

row_sparse_pull(key[, out, priority, row_ids])

Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values from the store with specified row_ids.

save_optimizer_states(fname[, dump_optimizer])

Saves the optimizer (updater) state to a file.

set_gradient_compression(compression_params)

Specifies the type of low-bit quantization for gradient compression and additional arguments depending on the type of compression being used.

set_optimizer(optimizer)

Registers an optimizer with the kvstore.

Attributes

num_workers

Returns the number of worker nodes.

rank

Returns the rank of this worker node.

type

Returns the type of this kvstore.

init(key, value)

Initializes a single or a sequence of key-value pairs into the store.

Each key must be initialized before push or pull is called on it. When multiple workers invoke init for the same key, only the value supplied by the worker with rank 0 is used. This function returns after the data has been initialized successfully.
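The rank-0 rule can be illustrated with a tiny single-process model. ToyKVStore and init_from_workers are hypothetical names, not MXNet API, and values are plain lists of floats; the sketch simulates every worker calling init for the same key, keeps only the rank-0 value, and refuses to pull a key that was never initialized.

```python
# Hypothetical single-process model of the init rule; NOT MXNet code.
class ToyKVStore:
    def __init__(self):
        self._store = {}

    def init_from_workers(self, key, values_by_rank):
        """Simulate every worker calling init(key, value) for the same key.

        values_by_rank maps worker rank -> the value that worker supplied.
        Only the value from rank 0 is kept.
        """
        self._store[key] = list(values_by_rank[0])

    def pull(self, key):
        # A key must be initialized before it can be pulled (or pushed).
        if key not in self._store:
            raise KeyError(f"key {key!r} has not been initialized")
        return list(self._store[key])


kv = ToyKVStore()
kv.init_from_workers('3', {0: [2.0, 2.0, 2.0], 1: [9.0, 9.0, 9.0]})
print(kv.pull('3'))  # [2.0, 2.0, 2.0]: worker 1's value is ignored
```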

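Parameters
  • key (str, int, or sequence of str or int) – Keys.

  • value (NDArray, RowSparseNDArray, or sequence of NDArray or RowSparseNDArray) – Values corresponding to the keys.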

Examples

>>> import mxnet as mx
>>> # init a single key-value pair
>>> shape = (2,3)
>>> kv = mx.kv.create('local')
>>> kv.init('3', mx.nd.ones(shape)*2)
>>> a = mx.nd.zeros(shape)
>>> kv.pull('3', out=a)
>>> print(a.asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]
>>> # init a list of key-value pairs
>>> keys = ['5', '7', '9']
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
>>> # init a row_sparse value
>>> kv.init('4', mx.nd.ones(shape).tostype('row_sparse'))
>>> b = mx.nd.sparse.zeros('row_sparse', shape)
>>> kv.row_sparse_pull('4', row_ids=mx.nd.array([0, 1]), out=b)
>>> print(b)
<RowSparseNDArray 2x3 @cpu(0)>

load_optimizer_states(fname)

Loads the optimizer (updater) state from the file.

Parameters

fname (str) – Path to input states file.

property num_workers

Returns the number of worker nodes.

Returns

size – The number of worker nodes.

Return type

int

pull(key, out=None, priority=0, ignore_sparse=True)

Pulls a single value or a sequence of values from the store.

This function returns immediately after adding an operator to the engine. Subsequent attempts to read from the out variable will be blocked until the pull operation completes.

pull is executed asynchronously, after all previous pull calls and only the last push call for the same input key(s) have finished.

The returned values are guaranteed to be the latest values in the store.

pull with RowSparseNDArray is not supported for dist kvstore. Please use row_sparse_pull instead.
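To make the shape rules for out concrete, here is a minimal pure-Python sketch, assuming values are flat lists of floats rather than NDArrays. toy_pull is a hypothetical helper, not part of MXNet, and it ignores asynchrony and priorities: it only shows that a single key may be pulled into one buffer or into one buffer per device, and that a sequence of keys takes one destination per key.

```python
# Hypothetical helper, not part of MXNet: values are flat lists of floats.
def toy_pull(store, key, out):
    """Copy the latest value(s) for key into out, in place."""
    if isinstance(key, (list, tuple)):
        # A sequence of keys: out[i] holds the destination(s) for key[i].
        for k, o in zip(key, out):
            toy_pull(store, k, o)
        return
    # A single key: out is one buffer, or a list with one buffer per device.
    buffers = out if isinstance(out[0], list) else [out]
    for buf in buffers:
        buf[:] = store[key]  # overwrite in place, like writing into an NDArray


store = {'3': [2.0, 2.0], '5': [1.0, 1.0], '7': [3.0, 3.0]}
per_device = [[0.0, 0.0] for _ in range(4)]  # one buffer per "GPU"
toy_pull(store, '3', per_device)
outs = [[0.0, 0.0] for _ in range(2)]        # one buffer per key
toy_pull(store, ['5', '7'], outs)
```

Each output receives its own copy, so later writes to an output buffer do not change the stored value.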

Parameters
  • key (str, int, or sequence of str or int) – Keys.

  • out (NDArray or list of NDArray or list of list of NDArray) – Values corresponding to the keys.

  • priority (int, optional) – The priority of the pull operation. Higher priority pull operations are likely to be executed before other pull actions.

  • ignore_sparse (bool, optional, default True) – Whether to ignore sparse arrays in the request.

Examples

>>> # pull a single key-value pair
>>> a = mx.nd.zeros(shape)
>>> kv.pull('3', out=a)
>>> print(a.asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]
>>> # pull into multiple devices
>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.pull('3', out=b)
>>> print(b[1].asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]
>>> # pull a list of key-value pairs.
>>> # On single device
>>> keys = ['5', '7', '9']
>>> b = [mx.nd.zeros(shape) for _ in keys]
>>> kv.pull(keys, out=b)
>>> print(b[1].asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]
>>> # On multiple devices
>>> keys = ['6', '8', '10']
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus] for _ in keys]
>>> kv.pull(keys, out=b)
>>> print(b[1][1].asnumpy())
[[ 2.  2.  2.]
 [ 2.  2.  2.]]

push(key, value, priority=0)

Pushes a single or a sequence of key-value pairs into the store.

This function returns immediately after adding an operator to the engine. The actual operation is executed asynchronously. If there are consecutive pushes to the same key, the pushes are not guaranteed to be serialized. The execution of a push does not guarantee that all previous pushes are finished. There is no synchronization between workers; use _barrier() to synchronize all workers.
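The aggregation behavior in the examples for this method (four devices each pushing ones yields 4) can be modeled in a few lines. This is a hypothetical pure-Python sketch, not MXNet's implementation: toy_push, the flat-list values, and the optional updater argument are illustrative assumptions, and asynchrony is ignored. Per-device values for a key are summed element-wise; with no optimizer registered, the sum replaces the stored value, which matches the 8 and 4 outputs in the examples.

```python
# Hypothetical model, not MXNet code: values are flat lists of floats.
def toy_push(store, key, value, updater=None):
    """Merge value for key and apply it to the stored entry."""
    if key not in store:
        raise KeyError(f"key {key!r} must be initialized before push")
    # value is a single buffer or a list of per-device buffers for this key.
    buffers = value if isinstance(value[0], list) else [value]
    merged = [sum(column) for column in zip(*buffers)]  # element-wise sum across devices
    if updater is None:
        store[key] = merged  # no optimizer registered: the merged value replaces the old one
    else:
        updater(key, merged, store[key])  # optimizer registered: update the stored weight in place


store = {'3': [0.0, 0.0, 0.0]}
toy_push(store, '3', [8.0, 8.0, 8.0])
after_single = list(store['3'])                       # [8.0, 8.0, 8.0]
toy_push(store, '3', [[1.0] * 3 for _ in range(4)])  # four "GPUs" each push ones
after_multi = list(store['3'])                        # [4.0, 4.0, 4.0]
```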

Parameters
  • key (str, int, or sequence of str or int) – Keys.

  • value (NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray, or list of list of NDArray or RowSparseNDArray) – Values corresponding to the keys.

  • priority (int, optional) – The priority of the push operation. Higher priority push operations are likely to be executed before other push actions.

Examples

>>> # push a single key-value pair
>>> kv.push('3', mx.nd.ones(shape)*8)
>>> kv.pull('3', out=a) # pull out the value
>>> print(a.asnumpy())
[[ 8.  8.  8.]
 [ 8.  8.  8.]]
>>> # aggregate the values and push
>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push('3', b)
>>> kv.pull('3', out=a)
>>> print(a.asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]
>>> # push a list of keys.
>>> # single device
>>> keys = ['4', '5', '6']
>>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
>>> b = [mx.nd.zeros(shape) for _ in keys]
>>> kv.pull(keys, out=b)
>>> print(b[1].asnumpy())
[[ 1.  1.  1.]
 [ 1.  1.  1.]]
>>> # multiple devices:
>>> keys = ['7', '8', '9']
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus] for _ in keys]
>>> kv.push(keys, b)
>>> kv.pull(keys, out=b)
>>> print(b[1][1].asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]
>>> # push a row_sparse value
>>> b = mx.nd.sparse.zeros('row_sparse', shape)
>>> kv.init('10', mx.nd.sparse.zeros('row_sparse', shape))
>>> kv.push('10', mx.nd.ones(shape).tostype('row_sparse'))
>>> # pull out the value
>>> kv.row_sparse_pull('10', row_ids=mx.nd.array([0, 1]), out=b)
>>> print(b)
<RowSparseNDArray 2x3 @cpu(0)>

pushpull(key, value, out=None, priority=0)

Pushes and then pulls a single value or a sequence of values to and from the store.

This function is a coalesced form of the push and pull operations. It returns immediately after adding an operator to the engine. Subsequent attempts to read from the out variable will be blocked until the pull operation completes.

value is pushed to the kvstore server for the specified keys, and the updated values are pulled from the server to out. If out is not specified, the pulled values are written to value. The returned values are guaranteed to be the latest values in the store.

pushpull with RowSparseNDArray is not supported for dist kvstore.
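Conceptually, pushpull behaves like a push followed by a pull of the same key, with out defaulting to value. The sketch below assumes a single key, flat lists of floats as values, and no registered optimizer; toy_pushpull is a hypothetical name, and the sketch ignores the asynchronous engine that the real call uses.

```python
# Hypothetical model, not MXNet code: values are flat lists of floats and no
# optimizer is registered, so a push simply stores the merged value.
def toy_pushpull(store, key, value, out=None):
    buffers = value if isinstance(value[0], list) else [value]
    # Push: sum the per-device values element-wise and store the result.
    store[key] = [sum(column) for column in zip(*buffers)]
    # Pull: write the latest value to out, or back into value if out is None.
    if out is None:
        targets = buffers
    else:
        targets = out if isinstance(out[0], list) else [out]
    for buf in targets:
        buf[:] = store[key]


store = {'3': [0.0, 0.0]}
a = [0.0, 0.0]
toy_pushpull(store, '3', [8.0, 8.0], out=a)   # a == [8.0, 8.0]
per_device = [[1.0, 1.0] for _ in range(4)]
toy_pushpull(store, '3', per_device)          # no out: results land in per_device
```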

Parameters
  • key (str, int, or sequence of str or int) – Keys.

  • value (NDArray, RowSparseNDArray, list of NDArray or RowSparseNDArray, or list of list of NDArray or RowSparseNDArray) – Values corresponding to the keys.

  • out (NDArray or list of NDArray or list of list of NDArray) – Values corresponding to the keys.

  • priority (int, optional) – The priority of the pull operation. Higher priority pull operations are likely to be executed before other pull actions.

Examples

>>> # push and pull a single key-value pair
>>> kv.pushpull('3', mx.nd.ones(shape)*8, out=a)
>>> print(a.asnumpy())
[[ 8.  8.  8.]
 [ 8.  8.  8.]]
>>> # aggregate the values, then push and pull
>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.pushpull('3', b, out=a)
>>> print(a.asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]
>>> # push and pull a list of keys.
>>> # single device
>>> keys = ['4', '5', '6']
>>> b = [mx.nd.zeros(shape) for _ in keys]
>>> kv.pushpull(keys, [mx.nd.ones(shape)]*len(keys), out=b)
>>> print(b[1].asnumpy())
[[ 1.  1.  1.]
 [ 1.  1.  1.]]
>>> # multiple devices:
>>> keys = ['7', '8', '9']
>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus] for _ in keys]
>>> kv.pushpull(keys, b)
>>> print(b[1][1].asnumpy())
[[ 4.  4.  4.]
 [ 4.  4.  4.]]

property rank

Returns the rank of this worker node.

Returns

rank – The rank of this node, which is in the range [0, num_workers).

Return type

int

row_sparse_pull(key, out=None, priority=0, row_ids=None)

Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull is invoked just once and the result is broadcast to all the remaining outputs.

row_sparse_pull is executed asynchronously after all previous pull/row_sparse_pull calls and the last push call for the same input key(s) are finished.

The returned values are guaranteed to be the latest values in the store.

Parameters
  • key (str, int, or sequence of str or int) – Keys.

  • out (RowSparseNDArray or list of RowSparseNDArray or list of list of RowSparseNDArray) – Values corresponding to the keys. The stype is expected to be row_sparse.

  • priority (int, optional) – The priority of the pull operation. Higher priority pull operations are likely to be executed before other pull actions.

  • row_ids (NDArray or list of NDArray) – The row_ids for which to pull for each value. Each row_id is a 1-D NDArray whose values need not be unique or sorted.
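The row selection rule, including the duplicate and unsorted row_ids cases in the examples, can be modeled with a set of requested row indices. The helper toy_row_sparse_pull is hypothetical and works on a dense list of rows: it returns a dense list in which requested rows are copied and all other rows are zero. The real result is a RowSparseNDArray that stores only the selected rows, so the dense output here corresponds to what asnumpy() prints.

```python
# Hypothetical helper, not MXNet code: stored_rows is a dense list of rows.
def toy_row_sparse_pull(stored_rows, row_ids):
    """Return a dense view of the pulled result: requested rows copied, others zero."""
    wanted = set(row_ids)  # duplicates collapse and order does not matter
    width = len(stored_rows[0])
    return [list(row) if i in wanted else [0.0] * width
            for i, row in enumerate(stored_rows)]


ones = [[1.0] * 3 for _ in range(3)]
sorted_ids = toy_row_sparse_pull(ones, [0, 2])
duplicate_ids = toy_row_sparse_pull(ones, [2, 2])
unsorted_ids = toy_row_sparse_pull(ones, [1, 0])
```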

Examples

>>> shape = (3, 3)
>>> kv.init('3', mx.nd.ones(shape).tostype('row_sparse'))
>>> a = mx.nd.sparse.zeros('row_sparse', shape)
>>> row_ids = mx.nd.array([0, 2], dtype='int64')
>>> kv.row_sparse_pull('3', out=a, row_ids=row_ids)
>>> print(a.asnumpy())
[[ 1.  1.  1.]
 [ 0.  0.  0.]
 [ 1.  1.  1.]]
>>> duplicate_row_ids = mx.nd.array([2, 2], dtype='int64')
>>> kv.row_sparse_pull('3', out=a, row_ids=duplicate_row_ids)
>>> print(a.asnumpy())
[[ 0.  0.  0.]
 [ 0.  0.  0.]
 [ 1.  1.  1.]]
>>> unsorted_row_ids = mx.nd.array([1, 0], dtype='int64')
>>> kv.row_sparse_pull('3', out=a, row_ids=unsorted_row_ids)
>>> print(a.asnumpy())
[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 0.  0.  0.]]

save_optimizer_states(fname, dump_optimizer=False)

Saves the optimizer (updater) state to a file. This is often used when checkpointing the model during training.

Parameters
  • fname (str) – Path to the output states file.

  • dump_optimizer (bool, default False) – Whether to also save the optimizer itself. This would also save optimizer information such as learning rate and weight decay schedules.

set_gradient_compression(compression_params)

Specifies the type of low-bit quantization for gradient compression and additional arguments depending on the type of compression being used.

2bit gradient compression takes a positive float threshold. The technique works by thresholding values: positive values in the gradient above the threshold are set to threshold, negative values whose absolute values are greater than the threshold are set to the negative of threshold, and values whose absolute values are less than the threshold are set to 0. Each value in the gradient is therefore in one of three states, which can be represented with 2 bits, so every 16 float values in the original gradient can be represented using one 32-bit float. This compressed representation can reduce communication costs. The difference between the thresholded values and the original values is stored at the sender’s end as a residual and added to the gradient in the next iteration.

When kvstore is ‘local’, gradient compression is used to reduce communication between multiple devices (GPUs). The gradient is quantized on each GPU that computed it, then sent to the GPU that merges the gradients; this receiving GPU dequantizes the gradients and merges them. Note that this increases memory usage on each GPU because of the stored residual array.

When kvstore is ‘dist’, gradient compression is used to reduce communication from worker to server. The gradient is quantized on each worker that computed it, then sent to the server, which dequantizes the data and merges the gradients from all workers. Note that this increases CPU memory usage on each worker because of the stored residual array. Only worker-to-server communication is compressed in this setting: if a machine has multiple GPUs, GPU-to-GPU and GPU-to-CPU communication is currently not compressed, and server-to-worker communication (in the case of pull) is not compressed either.

To use 2bit compression, specify type as 2bit. Specifying only type uses the default value for the threshold. To fully specify the arguments for 2bit compression, pass a dictionary that includes threshold, such as {'type': '2bit', 'threshold': 0.5}.
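The 2bit scheme described above can be sketched in pure Python: quantize each gradient entry plus its carried-over residual to one of three levels, keep the quantization error as the new residual, and pack the resulting 2-bit codes 16 per 32-bit word. The function names, the code assignment (0 for zero, 1 for +threshold, 2 for -threshold), the bit layout, and the handling of values exactly equal to threshold are assumptions for illustration; MXNet's actual kernels and wire format may differ.

```python
# Illustrative sketch of 2-bit threshold quantization with a residual.
# Code assignment and bit layout are assumptions, not MXNet's format.
def quantize_2bit(grad, residual, threshold):
    """Quantize grad + residual to {-threshold, 0, +threshold}.

    Returns (codes, new_residual): code 1 means +threshold, 2 means
    -threshold, 0 means zero. Treating |v| == threshold as "above" is an assumption.
    """
    codes, new_residual = [], []
    for g, r in zip(grad, residual):
        v = g + r  # add the error carried over from the previous iteration
        if v >= threshold:
            q, code = threshold, 1
        elif v <= -threshold:
            q, code = -threshold, 2
        else:
            q, code = 0.0, 0
        codes.append(code)
        new_residual.append(v - q)  # kept on the sender, added next time
    return codes, new_residual


def pack_codes(codes):
    """Pack 2-bit codes into 32-bit words, 16 codes per word."""
    words = []
    for start in range(0, len(codes), 16):
        word = 0
        for offset, code in enumerate(codes[start:start + 16]):
            word |= code << (2 * offset)
        words.append(word)
    return words


def dequantize(words, n, threshold):
    """Recover n thresholded values from packed words (receiver side)."""
    lookup = {0: 0.0, 1: threshold, 2: -threshold}
    return [lookup[(words[i // 16] >> (2 * (i % 16))) & 0b11] for i in range(n)]


threshold = 0.5
grad = [0.7, -0.2, -0.9, 0.3]
codes, residual = quantize_2bit(grad, [0.0] * 4, threshold)
# codes == [1, 0, 2, 0]; residual is approximately [0.2, -0.2, -0.4, 0.3]
codes2, residual2 = quantize_2bit([0.4, 0.0, 0.0, 0.3], residual, threshold)
# codes2 == [1, 0, 0, 1]: carried-over residuals push entries 0 and 3 over the threshold
```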

Parameters

compression_params (dict) – A dictionary specifying the type and parameters for gradient compression. The key type in this dictionary is a required string argument and specifies the type of gradient compression. Currently type can only be 2bit. Other keys in this dictionary are optional and specific to the type of gradient compression.

set_optimizer(optimizer)

Registers an optimizer with the kvstore.

When using a single machine, this function updates the local optimizer. If using multiple machines and this operation is invoked from a worker node, it serializes the optimizer with pickle and sends it to all servers. The function returns after all servers have been updated.
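Once an optimizer is registered, a push no longer overwrites the stored weight; the optimizer's update rule is applied to it instead. The sketch below models that step with plain SGD at a learning rate of 0.01, chosen because it reproduces the -0.01 in the example for this method; make_sgd_updater is a hypothetical name, values are flat lists of floats, and momentum, weight decay, and gradient rescaling are omitted.

```python
# Hypothetical sketch: plain SGD (no momentum, no weight decay) as an updater.
def make_sgd_updater(learning_rate=0.01):
    def update(key, grad, weight):
        """Update the stored weight in place: w <- w - learning_rate * grad."""
        for i, g in enumerate(grad):
            weight[i] -= learning_rate * g
    return update


store = {3: [0.0, 0.0, 0.0, 0.0]}  # a 2x2 weight initialized to zeros, flattened
updater = make_sgd_updater()
grad = [1.0, 1.0, 1.0, 1.0]
updater(3, grad, store[3])         # the step a push performs once an optimizer is set
print(store[3])                    # [-0.01, -0.01, -0.01, -0.01]
```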

Parameters

optimizer (Optimizer) – The new optimizer for the store.

Examples

>>> kv = mx.kv.create()
>>> shape = (2, 2)
>>> weight = mx.nd.zeros(shape)
>>> kv.init(3, weight)
>>> # set the optimizer for kvstore as the default SGD optimizer
>>> kv.set_optimizer(mx.optimizer.SGD())
>>> grad = mx.nd.ones(shape)
>>> kv.push(3, grad)
>>> kv.pull(3, out=weight)
>>> # weight is updated via gradient descent
>>> weight.asnumpy()
array([[-0.01, -0.01],
       [-0.01, -0.01]], dtype=float32)

property type

Returns the type of this kvstore.

Returns

type – The type of this kvstore, as a string.

Return type

str

mxnet.kvstore.create(name='local')

Creates a new KVStore.

For single machine training, there are two commonly used types:

local: Copies all gradients to CPU memory and updates weights there.

device: Aggregates gradients and updates weights on GPUs. With this setting, the KVStore also attempts to use GPU peer-to-peer communication, potentially accelerating the communication.

For distributed training, KVStore also supports a number of types:

dist_sync: Behaves similarly to local but with one major difference. With dist_sync, batch-size now means the batch size used on each machine. So if there are n machines and we use batch size b, then dist_sync behaves like local with batch size n * b.

dist_device_sync: Identical to dist_sync, except that it differs from dist_sync in the same way that device differs from local.

dist_async: Performs asynchronous updates. The weights are updated whenever gradients are received from any machine. No two updates happen on the same weight at the same time. However, the order is not guaranteed.
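The statement that dist_sync with per-machine batch size b behaves like local with batch size n * b rests on summation: if each machine sums its per-sample gradients and the store adds the per-machine results (the same summing that push performs across devices), the total equals the gradient summed over all n * b samples. Below is a toy check with a hypothetical one-parameter least-squares model; the model, data, and the values n = 3 and b = 4 are illustrative assumptions.

```python
# Hypothetical toy model: loss = sum over samples of (w * x - y) ** 2.
def grad_sum(w, batch):
    """Sum of per-sample gradients d/dw (w*x - y)**2 = 2*(w*x - y)*x."""
    return sum(2.0 * (w * x - y) * x for x, y in batch)


n, b = 3, 4                                           # 3 machines, batch size 4 on each
data = [(float(i), 2.0 * i + 1.0) for i in range(n * b)]
shards = [data[k * b:(k + 1) * b] for k in range(n)]  # each machine's batch
w = 0.5

per_machine = [grad_sum(w, shard) for shard in shards]
aggregated = sum(per_machine)         # what summing pushes across machines yields
single_machine = grad_sum(w, data)    # one machine, batch size n * b
print(aggregated, single_machine)     # -1650.0 -1650.0
```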

Parameters

name ({'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async'}) – The type of KVStore.

Returns

kv – The created KVStore.

Return type

KVStore