SVRG Optimization in Python Module API

Overview

SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique first introduced in the 2013 paper "Accelerating Stochastic Gradient Descent using Predictive Variance Reduction". It complements SGD (Stochastic Gradient Descent), which is widely used for large-scale optimization but converges slowly asymptotically because of its inherent variance. SGD approximates the full gradient using a small batch of data or a single data sample. This introduces variance and therefore requires a small learning rate to ensure convergence. SVRG remedies the problem in two ways. It keeps a snapshot of the estimated weights that is close to the optimal parameter values, and it maintains the average of the full gradients over a full pass of the data. This average is calculated with respect to the snapshot weights, which are refreshed every m epochs during training. SVRG then uses a different update rule: the gradient with respect to the current parameter values, minus the gradient with respect to the snapshot parameters, plus the average of the full gradients over all data.
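The update rule above can be sketched in plain Python on a hypothetical one-dimensional least-squares problem. This is a minimal illustration of the SVRG algorithm, not MXNet's implementation: the data, the learning rate, and the helper names grad_i, full_grad, and svrg are assumptions. Every update_freq epochs the sketch refreshes the snapshot weights w_snapshot and the full gradient mu at those weights. Each stochastic step then moves along grad_i(w) - grad_i(w_snapshot) + mu.

```python
import random

# Hypothetical toy problem: fit y = w * x (plus noise) by minimizing the mean of
# f_i(w) = 0.5 * (w * x_i - y_i) ** 2 over all samples.
random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(200)]
ys = [3.0 * x + random.gauss(0.0, 0.1) for x in xs]
n = len(xs)

# Closed-form least-squares minimizer, used only to check the result.
w_star = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)


def grad_i(w, i):
    """Gradient of the i-th sample's loss with respect to w."""
    return (w * xs[i] - ys[i]) * xs[i]


def full_grad(w):
    """Average gradient over all samples (one full pass over the data)."""
    return sum(grad_i(w, i) for i in range(n)) / n


def svrg(w, lr=0.5, num_epochs=20, update_freq=2):
    for epoch in range(num_epochs):
        if epoch % update_freq == 0:
            # Refresh the snapshot weights and the full gradient at the snapshot.
            w_snapshot, mu = w, full_grad(w)
        for _ in range(n):
            i = random.randrange(n)
            # SVRG rule: current gradient - snapshot gradient + full gradient.
            v = grad_i(w, i) - grad_i(w_snapshot, i) + mu
            w -= lr * v
    return w


w_final = svrg(0.0)
print(w_final, w_star)
```

The learning rate stays constant throughout. This works here because the variance of v shrinks as w and w_snapshot both approach the optimum.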

Key Characteristics of SVRG:

  • Reduces variance explicitly through an update rule that differs from SGD's.
  • Allows a relatively large learning rate, which leads to faster convergence than SGD.
  • Provides fast (linear) convergence guarantees for smooth and strongly convex functions.
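The variance-reduction property can be checked numerically. The following sketch uses a hypothetical least-squares dataset and hand-picked current and snapshot weights, none of which come from MXNet. It computes, for every possible sampled index, the gradient estimate that plain SGD would use and the estimate that the SVRG rule would use. Both average to the full gradient, but the SVRG estimates vary far less when the current weights are close to the snapshot.

```python
import random
import statistics

# Hypothetical toy data for the per-sample loss 0.5 * (w * x_i - y_i) ** 2.
random.seed(1)
xs = [random.uniform(-1.0, 1.0) for _ in range(500)]
ys = [2.0 * x + random.gauss(0.0, 0.5) for x in xs]
n = len(xs)


def grad_i(w, i):
    return (w * xs[i] - ys[i]) * xs[i]


def full_grad(w):
    return sum(grad_i(w, i) for i in range(n)) / n


w = 1.9           # current weights (hypothetical)
w_snapshot = 1.8  # snapshot weights from the last full pass (hypothetical)
mu = full_grad(w_snapshot)

# Gradient estimate each rule would use for every possible sampled index i.
sgd_estimates = [grad_i(w, i) for i in range(n)]
svrg_estimates = [grad_i(w, i) - grad_i(w_snapshot, i) + mu for i in range(n)]

# Both estimators average to the full gradient at w (they are unbiased).
print(statistics.mean(sgd_estimates), statistics.mean(svrg_estimates), full_grad(w))
# The SVRG estimator has a much smaller spread.
print(statistics.pvariance(sgd_estimates), statistics.pvariance(svrg_estimates))
```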

SVRG optimization is implemented as an SVRGModule in mxnet.contrib.svrg_optimization. SVRGModule extends the existing mxnet.module.Module API and encapsulates the SVRG optimization logic in several new functions. For end users, the API changes relative to Module are minimal.

In distributed training, each worker receives the same snapshot weights from the most recent full-gradient update and calculates the full gradients over its own shard of data. Standard SVRG optimization requires global full gradients. These are calculated by aggregating the full gradients from each worker and averaging over the number of workers. The workaround is to keep an additional set of keys in the KVStore that map to the full gradients. The _SVRGOptimizer wraps two optimizers. The first is an _AssignmentOptimizer, used to accumulate the full gradients in the KVStore. The second is a regular optimizer that applies the actual update rule to the parameters. The _SVRGOptimizer and _AssignmentOptimizer are designed to be used only within SVRGModule.
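The aggregation step can be sketched without a real cluster. In this sketch, four hypothetical workers share the same snapshot weights and each computes the average gradient over its own shard of a toy least-squares dataset. The per-worker results are then summed and divided by the number of workers. Worker count, data, and sharding scheme are all assumptions. The plain sum is only a stand-in for the KVStore aggregation.

```python
import math

# Hypothetical setup: toy 1-D least-squares data split across 4 workers.
xs = [0.1 * i for i in range(40)]
ys = [2.0 * x + 1.0 for x in xs]
w_snapshot = 1.5  # the same snapshot weights on every worker
num_workers = 4


def grad_i(i):
    """Per-sample gradient of 0.5 * (w * x_i - y_i) ** 2 at the snapshot weights."""
    return (w_snapshot * xs[i] - ys[i]) * xs[i]


# Each worker owns an equal-size shard of the sample indices.
shards = [range(k, len(xs), num_workers) for k in range(num_workers)]

# Each worker computes the full gradient over its own shard only.
worker_full_grads = [sum(grad_i(i) for i in shard) / len(shard) for shard in shards]

# Aggregate across workers (modeled here as a plain sum), then divide by the number of workers.
global_full_grad = sum(worker_full_grads) / num_workers

# With equal shard sizes this matches the full gradient computed on one machine.
single_machine = sum(grad_i(i) for i in range(len(xs))) / len(xs)
print(global_full_grad, single_machine)
```

The shards are the same size on purpose. With unequal shards, a plain average over workers would weight some samples more heavily than others.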

Warning

This package contains experimental APIs and may change in the near future.

This document lists the SVRGModule APIs in the MXNet contrib package:

mxnet.contrib.svrg_optimization.svrg_module – An SVRGModule implements the Module API by wrapping an auxiliary module to perform SVRG optimization logic.

Intermediate Level API for SVRGModule

Compared with using a Module, the only extra step when using an SVRGModule is to check whether the current epoch should update the full gradients over all data. The code snippet below demonstrates the suggested usage of SVRGModule with the intermediate-level API.

>>> from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
>>> # model (a Symbol), the data iterator di, and num_epochs are assumed to be defined already.
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
>>> mod.init_params()
>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ), kvstore='local')
>>> for epoch in range(num_epochs):
...     # Recompute the full gradients over all data every update_freq epochs.
...     if epoch % mod.update_freq == 0:
...         mod.update_full_grads(di)
...     di.reset()
...     for batch in di:
...         mod.forward_backward(data_batch=batch)
...         mod.update()

High Level API for SVRGModule

The high-level API usage of SVRGModule is exactly the same as that of the Module API. The code snippet below gives an example of the suggested usage of the high-level API.

>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.fit(di, num_epoch=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))

API reference

An SVRGModule implements the Module API by wrapping an auxiliary module to perform SVRG optimization logic.

class mxnet.contrib.svrg_optimization.svrg_module.SVRGModule(symbol, data_names=('data', ), label_names=('softmax_label', ), logger=logging, context=cpu(0), work_load_list=None, fixed_param_names=None, state_names=None, group2ctxs=None, compression_params=None, update_freq=None)

SVRGModule is a module that encapsulates two Modules to accommodate the SVRG optimization technique. It is functionally the same as the Module API, except that it implements SVRG optimization logic.

Parameters:
  • symbol (Symbol) – The symbol defining the network.
  • data_names (list of str) – Defaults to (‘data’,) for a typical model used in image classification.
  • label_names (list of str) – Defaults to (‘softmax_label’,) for a typical model used in image classification.
  • logger (Logger) – Defaults to logging.
  • context (Context or list of Context) – Defaults to mx.cpu().
  • work_load_list (list of number) – Defaults to None, indicating uniform workload.
  • fixed_param_names (list of str) – Defaults to None, indicating no network parameters are fixed.
  • state_names (list of str) – States are similar to data and label, but are not provided by the data iterator. Instead, they are initialized to 0 and can be set by set_states().
  • group2ctxs (dict of str to context or list of context, or list of dict of str to context) – Defaults to None. Maps the ctx_group attribute to the context assignment.
  • compression_params (dict) – Specifies the type of gradient compression and additional arguments depending on the type of compression being used. For example, 2bit compression requires a threshold; the arguments would then be {‘type’:‘2bit’, ‘threshold’:0.5}. See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
  • update_freq (int) – Specifies how often, in epochs, the full gradients used in SVRG optimization are recomputed. For instance, update_freq = 2 calculates the gradients over all data every two epochs.
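The epochs at which full gradients are recomputed follow from the condition used in the intermediate-level example, epoch % mod.update_freq == 0. The helper below is hypothetical and not part of SVRGModule. It assumes epochs are counted from 0, as in that example, and simply lists the epochs that satisfy the condition.

```python
def full_grad_epochs(num_epochs, update_freq):
    """Epochs at which the condition epoch % update_freq == 0 holds."""
    return [epoch for epoch in range(num_epochs) if epoch % update_freq == 0]


print(full_grad_epochs(6, 2))  # [0, 2, 4]
print(full_grad_epochs(6, 3))  # [0, 3]
```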

Examples

>>> # An example of declaring and using SVRGModule.
>>> mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
>>> mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),),
...         num_epoch=num_epoch, kvstore='local')
reshape(data_shapes, label_shapes=None)

Reshapes both modules for new input shapes.

Parameters:
  • data_shapes (list of (str, tuple)) – Typically data_iter.provide_data.
  • label_shapes (list of (str, tuple)) – Typically data_iter.provide_label.
init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ), force_init=False)

Installs and initializes the SVRGOptimizer. The SVRGOptimizer is a wrapper class around the regular optimizer that is passed in and a special AssignmentOptimizer used to accumulate the full gradients. If the KVStore is ‘local’ or None, the full gradients are accumulated locally without being pushed to the KVStore. Otherwise, additional keys are pushed to accumulate the full gradients in the KVStore.

Parameters:
  • kvstore (str or KVStore) – Default ‘local’.
  • optimizer (str or Optimizer) – Default ‘sgd’.
  • optimizer_params (dict) – Default ((‘learning_rate’, 0.01),). The default value is not a dictionary only to avoid a pylint warning about dangerous default values.
  • force_init (bool) – Default False. Whether to force re-initializing the optimizer if one is already installed.
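The wrapper design can be illustrated with a toy key-value store. In this sketch, full-gradient keys are routed to an assignment step that simply overwrites the stored value with the aggregated push, and parameter keys are routed to an ordinary SGD step. All class names, the "_full" key suffix, and the sum-then-update behavior of the toy store are hypothetical stand-ins for illustration. They are not MXNet's actual classes or key scheme.

```python
class AssignmentOptimizerSketch:
    """Hypothetical stand-in: replaces the stored value with the incoming value."""

    def update(self, stored, incoming):
        return incoming


class SGDSketch:
    """Hypothetical stand-in for a regular optimizer: one plain SGD step."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def update(self, weight, grad):
        return weight - self.learning_rate * grad


class SVRGOptimizerSketch:
    """Routes full-gradient keys to the assignment step and all other keys to SGD."""

    def __init__(self, default_opt, full_grad_suffix="_full"):
        self.default_opt = default_opt
        self.aux_opt = AssignmentOptimizerSketch()
        self.full_grad_suffix = full_grad_suffix

    def update(self, key, stored, incoming):
        if key.endswith(self.full_grad_suffix):
            return self.aux_opt.update(stored, incoming)
        return self.default_opt.update(stored, incoming)


# Toy store: values pushed for a key are summed, then handed to the wrapper optimizer.
store = {"fc_weight": 1.0, "fc_weight_full": 0.0}
opt = SVRGOptimizerSketch(SGDSketch(learning_rate=0.1))


def push(key, values):
    merged = sum(values)
    store[key] = opt.update(key, store[key], merged)


push("fc_weight_full", [0.2, 0.4])  # full-gradient key: stored value becomes the sum, 0.6
push("fc_weight", [0.5])            # parameter key: SGD step, 1.0 - 0.1 * 0.5 = 0.95
print(store)
```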
bind(data_shapes, label_shapes=None, for_training=True, inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write')

Binds the symbols to construct executors for both modules. This is necessary before any computation can be performed with the SVRGModule.

Parameters:
  • data_shapes (list of (str, tuple)) – Typically data_iter.provide_data.
  • label_shapes (list of (str, tuple)) – Typically data_iter.provide_label.
  • for_training (bool) – Default is True. Whether the executors should be bound for training.
  • inputs_need_grad (bool) – Default is False. Whether the gradients to the input data need to be computed. Typically this is not needed. But this might be needed when implementing composition of modules.
  • force_rebind (bool) – Default is False. This function does nothing if the executors are already bound. But with this True, the executors will be forced to rebind.
  • shared_module (Module) – Default is None. This is used in bucketing. When not None, the shared module essentially corresponds to a different bucket – a module with different symbol but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
forward(data_batch, is_train=None)

Forward computation for both modules. It supports data batches with different shapes, such as different batch sizes or different image sizes. If reshaping the data batch involves modifying the symbol or module, such as changing the image layout ordering or switching from training to prediction, the module must be rebound.

See also

BaseModule.forward()

Parameters:
  • data_batch (DataBatch) – Could be any object that implements a similar API.
  • is_train (bool) – Default is None, which means is_train takes the value of self.for_training.
backward(out_grads=None)

Backward computation.

See also

BaseModule.backward()

Parameters: out_grads (NDArray or list of NDArray, optional) – Gradient on the outputs to be propagated back. This parameter is only needed when bind is called on outputs that are not a loss function.
update()

Updates parameters according to the installed optimizer and the gradients computed in the previous forward-backward batch. The gradients in the _exec_group will be overwritten using the gradients calculated by the SVRG update rule.
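The per-parameter gradient overwrite can be sketched as plain arithmetic. This sketch assumes three inputs are available for each parameter: the current gradients from the main module, the gradients at the snapshot weights (from the auxiliary module), and the stored full gradients. Plain Python lists stand in for NDArrays. The parameter names are hypothetical.

```python
def svrg_gradients(cur_grads, snapshot_grads, full_grads):
    """Per-parameter SVRG gradient: g(w) - g(w_snapshot) + full_grad(w_snapshot).

    Each argument maps a parameter name to a flat list of gradient values.
    """
    return {
        name: [c - s + f for c, s, f in zip(cur_grads[name], snapshot_grads[name], full_grads[name])]
        for name in cur_grads
    }


# Hypothetical parameter names and gradient values.
cur = {"fc_weight": [0.5, -1.0], "fc_bias": [0.25]}
snap = {"fc_weight": [0.25, -0.5], "fc_bias": [0.0]}
full = {"fc_weight": [0.125, 0.0], "fc_bias": [0.5]}
print(svrg_gradients(cur, snap, full))
# {'fc_weight': [0.375, -0.5], 'fc_bias': [0.75]}
```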

When KVStore is used to update parameters for multi-device or multi-machine training, a copy of the parameters is stored in KVStore. Note that for row_sparse parameters, this function does update the copy of parameters in KVStore, but doesn’t broadcast the updated parameters to all devices / machines. Please call prepare to broadcast row_sparse parameters with the next batch of data.

See also

BaseModule.update()

update_full_grads(train_data)

Computes the gradients over all data with respect to the snapshot weights, which are refreshed every update_freq epochs. In a distributed environment, it accumulates the full gradients in the KVStore.

Parameters: train_data (DataIter) – Training data iterator.
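One way to picture a full-gradient pass is the following sketch. It freezes a copy of the current weights as the snapshot, runs over every batch, and averages the per-batch mean gradients. It assumes equal-size batches, so that averaging batch means equals averaging over all samples. The function names and data are hypothetical, and this is not MXNet's implementation.

```python
def update_full_grads_sketch(weights, batches, batch_grad):
    """Snapshot the weights, then average batch gradients over one full pass of the data."""
    snapshot = list(weights)  # frozen copy used until the next refresh
    total = [0.0] * len(weights)
    num_batches = 0
    for batch in batches:
        grads = batch_grad(snapshot, batch)
        total = [t + g for t, g in zip(total, grads)]
        num_batches += 1
    full_grads = [t / num_batches for t in total]
    return snapshot, full_grads


def lsq_batch_grad(weights, batch):
    """Mean gradient of 0.5 * (w * x - y) ** 2 over a batch of (x, y) pairs."""
    (w,) = weights
    return [sum((w * x - y) * x for x, y in batch) / len(batch)]


# Hypothetical data in two equal-size batches.
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 5.0), (4.0, 9.0)]]
snapshot, full_grads = update_full_grads_sketch([1.0], batches, lsq_batch_grad)
print(snapshot, full_grads)  # [1.0] [-7.75]
```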
fit(train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ), eval_end_callback=None, eval_batch_end_callback=None, initializer=, arg_params=None, aux_params=None, allow_missing=False, force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None, sparse_row_id_fn=None)

Trains the module parameters.

Parameters:
  • train_data (DataIter) – Train DataIter.
  • eval_data (DataIter) – If not None, will be used as validation set and the performance after each epoch will be evaluated.
  • eval_metric (str or EvalMetric) – Defaults to ‘acc’ (accuracy). The performance measure displayed during training. Other possible predefined metrics are: ‘ce’ (CrossEntropy), ‘f1’, ‘mae’, ‘mse’, ‘rmse’, ‘top_k_accuracy’.
  • epoch_end_callback (function or list of functions) – Each callback will be called with the current epoch, symbol, arg_params and aux_params.
  • batch_end_callback (function or list of function) – Each callback will be called with a BatchEndParam.
  • kvstore (str or KVStore) – Defaults to ‘local’.
  • optimizer (str or Optimizer) – Defaults to ‘sgd’.
  • optimizer_params (dict) – Defaults to (('learning_rate', 0.01),). The parameters for the optimizer constructor. The default value is not a dict, just to avoid pylint warning on dangerous default values.
  • eval_end_callback (function or list of function) – These will be called at the end of each full evaluation, with the metrics over the entire evaluation set.
  • eval_batch_end_callback (function or list of function) – These will be called at the end of each mini-batch during evaluation.
  • initializer (Initializer) – The initializer is called to initialize the module parameters when they are not already initialized.
  • arg_params (dict) – Defaults to None, if not None, should be existing parameters from a trained model or loaded from a checkpoint (previously saved model). In this case, the value here will be used to initialize the module parameters, unless they are already initialized by the user via a call to init_params or fit. arg_params has a higher priority than initializer.
  • aux_params (dict) – Defaults to None. Similar to arg_params, except for auxiliary states.
  • allow_missing (bool) – Defaults to False. Indicates whether to allow missing parameters when arg_params and aux_params are not None. If this is True, then the missing parameters will be initialized via the initializer.
  • force_rebind (bool) – Defaults to False. Whether to force rebinding the executors if already bound.
  • force_init (bool) – Defaults to False. Indicates whether to force initialization even if the parameters are already initialized.
  • begin_epoch (int) – Defaults to 0. Indicates the starting epoch. Usually, if resumed from a checkpoint saved at a previous training phase at epoch N, then this value should be N+1.
  • num_epoch (int) – Number of epochs for training.
  • sparse_row_id_fn (A callback function) – The function takes data_batch as an input and returns a dict of str -> NDArray. The resulting dict is used for pulling row_sparse parameters from the kvstore, where the str key is the name of the param, and the value is the row id of the param to pull.
  • validation_metric (str or EvalMetric) – The performance measure used to display during validation.
prepare(data_batch, sparse_row_id_fn=None)

Prepares both modules for processing a data batch.

Usually involves switching bucket and reshaping. For modules that contain row_sparse parameters in KVStore, it prepares the row_sparse parameters based on the sparse_row_id_fn.

When KVStore is used to update parameters for multi-device or multi-machine training, a copy of the parameters is stored in KVStore. Note that for row_sparse parameters, update() updates the copy of parameters in KVStore but doesn’t broadcast the updated parameters to all devices / machines. The prepare function is used to broadcast row_sparse parameters with the next batch of data.

Parameters:
  • data_batch (DataBatch) – The current batch of data for forward computation.
  • sparse_row_id_fn (A callback function) – The function takes data_batch as an input and returns a dict of str -> NDArray. The resulting dict is used for pulling row_sparse parameters from the kvstore, where the str key is the name of the param, and the value is the row id of the param to pull.