mxnet.optimizer

Optimizer API of MXNet.

Classes

Optimizer([rescale_grad, param_idx2name, …])

The base class inherited by all optimizers.

Test(**kwargs)

The Test optimizer

Updater(optimizer)

Updater for kvstore.

SGD([learning_rate, momentum, lazy_update, …])

The SGD optimizer with momentum and weight decay.

SGLD([learning_rate, use_fused_step])

Stochastic Gradient Riemannian Langevin Dynamics.

Signum([learning_rate, momentum, wd_lh, …])

The Signum optimizer that takes the sign of gradient or momentum.

DCASGD([learning_rate, momentum, lamda, …])

The DCASGD optimizer.

NAG([learning_rate, momentum, …])

Nesterov accelerated gradient.

AdaBelief([learning_rate, beta1, beta2, …])

The AdaBelief optimizer.

AdaGrad([learning_rate, epsilon, use_fused_step])

AdaGrad optimizer.

AdaDelta([learning_rate, rho, epsilon, …])

The AdaDelta optimizer.

Adam([learning_rate, beta1, beta2, epsilon, …])

The Adam optimizer.

Adamax([learning_rate, beta1, beta2, …])

The AdaMax optimizer.

Nadam([learning_rate, beta1, beta2, …])

The Nesterov Adam optimizer.

Ftrl([learning_rate, lamda1, beta, …])

The Ftrl optimizer.

FTML([learning_rate, beta1, beta2, epsilon, …])

The FTML optimizer.

LARS([learning_rate, momentum, eta, …])

The LARS optimizer from ‘Large Batch Training of Convolutional Networks’ (https://arxiv.org/abs/1708.03888).

LAMB([learning_rate, beta1, beta2, epsilon, …])

LAMB Optimizer.

RMSProp([learning_rate, rho, momentum, …])

The RMSProp optimizer.

LANS([learning_rate, beta1, beta2, epsilon, …])

LANS Optimizer.

Functions

create(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

register(klass)

Registers a new optimizer.

get_updater(optimizer)

Returns a closure of the updater needed for kvstore.

class Optimizer(rescale_grad=1.0, param_idx2name=None, wd=0.0, clip_gradient=None, learning_rate=None, lr_scheduler=None, sym=None, begin_num_update=0, multi_precision=False, param_dict=None, aggregate_num=None, use_fused_step=None, **kwargs)

Bases: object

The base class inherited by all optimizers.

Parameters
  • rescale_grad (float, optional, default 1.0) – Multiply the gradient with rescale_grad before updating. This is often set to 1.0/batch_size.

  • param_idx2name (dict from int to string, optional, default None) – A dictionary that maps int index to string name.

  • clip_gradient (float, optional, default None) – Clip the gradient by projecting onto the box [-clip_gradient, clip_gradient].

  • learning_rate (float) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • lr_scheduler (LRScheduler, optional, default None) – The learning rate scheduler.

  • wd (float, optional, default 0.0) – The weight decay (or L2 regularization) coefficient. Modifies objective by adding a penalty for having large weights.

  • sym (Symbol, optional, default None) – The Symbol this optimizer is applying to.

  • begin_num_update (int, optional, default 0) – The initial number of updates.

  • multi_precision (bool, optional, default False) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

  • param_dict (dict of int -> gluon.Parameter, default None) – Dictionary of parameter index to gluon.Parameter, used to lookup parameter attributes such as lr_mult, wd_mult, etc. param_dict shall not be deep copied.

  • aggregate_num (int, optional, default None) – Number of weights to be aggregated in a list. They are passed to the optimizer for a single optimization step. By default, only one weight is aggregated. When aggregate_num is set to numpy.inf, all the weights are aggregated.

  • use_fused_step (bool, optional, default None) – Whether or not to use fused kernels for optimizer. When use_fused_step=False, step is called, otherwise, fused_step is called.

Properties
  • learning_rate – The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate.
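
The precedence between learning_rate and lr_scheduler in the parameter list above can be condensed into a small function. This is a sketch under stated assumptions, not MXNet's constructor. `ToyScheduler` and `resolve_learning_rate` are hypothetical names, and the scheduler is modeled only by a `base_lr` attribute.

```python
from typing import Optional


class ToyScheduler:
    """Hypothetical stand-in for an LRScheduler; only base_lr is modeled."""

    def __init__(self, base_lr: float) -> None:
        self.base_lr = base_lr


def resolve_learning_rate(learning_rate: Optional[float],
                          lr_scheduler: Optional[ToyScheduler]) -> float:
    """Apply the documented precedence between learning_rate and lr_scheduler."""
    if lr_scheduler is None:
        # No scheduler: use learning_rate, or fall back to 0.01 if it is None too.
        return 0.01 if learning_rate is None else learning_rate
    if learning_rate is not None:
        # An explicit learning_rate overwrites the scheduler's rate.
        lr_scheduler.base_lr = learning_rate
    return lr_scheduler.base_lr


print(resolve_learning_rate(None, None))               # 0.01
print(resolve_learning_rate(None, ToyScheduler(0.5)))  # 0.5
print(resolve_learning_rate(0.1, ToyScheduler(0.5)))   # 0.1
```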

Methods

create_optimizer(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

create_state(index, weight)

Creates auxiliary state for a given weight.

create_state_multi_precision(index, weight)

Creates auxiliary state for a given weight, including FP32 high precision copy if original weight is FP16.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

register(klass)

Registers a new optimizer.

set_learning_rate(lr)

Sets a new learning rate of the optimizer.

set_lr_mult(args_lr_mult)

Sets an individual learning rate multiplier for each parameter.

set_wd_mult(args_wd_mult)

Sets an individual weight decay multiplier for each parameter.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

update(indices, weights, grads, states)

Call step to perform a single optimization update if use_fused_step is False, otherwise fused_step is called.

update_multi_precision(indices, weights, …)

Call step to perform a single optimization update if use_fused_step is False, otherwise fused_step is called.

static create_optimizer(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

Note

We can use the alias create for Optimizer.create_optimizer.

Parameters
  • name (str) – Name of the optimizer. Should be the name of a subclass of Optimizer. Case insensitive.

  • kwargs (dict) – Parameters for the optimizer.

Returns

An instantiated optimizer.

Return type

Optimizer

Examples

>>> import mxnet as mx
>>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
>>> type(sgd)
<class 'mxnet.optimizer.SGD'>
>>> adam = mx.optimizer.create('adam', learning_rate=.1)
>>> type(adam)
<class 'mxnet.optimizer.Adam'>

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

create_state_multi_precision(index, weight)

Creates auxiliary state for a given weight, including an FP32 high-precision copy if the original weight is FP16.

This method is provided to perform automatic mixed precision training for optimizers that do not support it themselves.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states. New operators that fuse the optimizer’s update should be placed in this function.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

static register(klass)

Registers a new optimizer.

Once an optimizer is registered, we can create an instance of this optimizer with create_optimizer later.
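
Conceptually, registration and creation act like a name-keyed registry. The plain-Python sketch below illustrates that idea under stated assumptions. `_REGISTRY`, `BaseOptimizer`, `register` and `create` are hypothetical re-implementations, not MXNet code. Lowercasing the key is our assumption for how the case-insensitive lookup mentioned under create_optimizer could work.

```python
from typing import Dict

# Hypothetical registry mapping lowercase class names to optimizer classes.
_REGISTRY: Dict[str, type] = {}


class BaseOptimizer:
    """Hypothetical minimal base class standing in for Optimizer."""

    def __init__(self, **kwargs) -> None:
        self.kwargs = kwargs


def register(klass: type) -> type:
    """Class decorator: store klass under its lowercased name and return it."""
    _REGISTRY[klass.__name__.lower()] = klass
    return klass


def create(name: str, **kwargs) -> BaseOptimizer:
    """Look up a registered class case-insensitively and instantiate it."""
    key = name.lower()
    if key not in _REGISTRY:
        raise ValueError(f"Cannot find optimizer {name!r}")
    return _REGISTRY[key](**kwargs)


@register
class MyOptimizer(BaseOptimizer):
    pass


opt = create("MYOPTIMIZER", learning_rate=0.1)
print(type(opt).__name__, opt.kwargs)  # MyOptimizer {'learning_rate': 0.1}
```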

Examples

>>> import mxnet as mx
>>> @mx.optimizer.Optimizer.register
... class MyOptimizer(mx.optimizer.Optimizer):
...     pass
>>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer')
>>> print(type(optim))
<class '__main__.MyOptimizer'>

set_learning_rate(lr)

Sets a new learning rate of the optimizer.

Parameters

lr (float) – The new learning rate of the optimizer.

set_lr_mult(args_lr_mult)

Sets an individual learning rate multiplier for each parameter.

If you specify a learning rate multiplier for a parameter, then the learning rate for the parameter will be set as the product of the global learning rate self.lr and its multiplier.
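
The product rule above can be illustrated with a small worked sketch. It computes a per-parameter learning rate from a global rate and a multiplier table keyed by name or index. The helper `effective_lr` is hypothetical. Preferring a name-keyed entry over an index-keyed one, and using a multiplier of 1.0 for unlisted parameters, are assumptions made for illustration, not MXNet internals.

```python
from typing import Dict, Optional, Union

Key = Union[str, int]


def effective_lr(global_lr: float, lr_mult: Dict[Key, float],
                 index: int, name: Optional[str] = None) -> float:
    """Return global_lr times the parameter's multiplier (assumed 1.0 if unset).

    A name-keyed entry is preferred; an index-keyed entry is the
    backward-compatible fallback.
    """
    if name is not None and name in lr_mult:
        return global_lr * lr_mult[name]
    return global_lr * lr_mult.get(index, 1.0)


mults = {"fc1_weight": 0.5, 3: 2.0}
print(effective_lr(0.1, mults, index=0, name="fc1_weight"))  # 0.05
print(effective_lr(0.1, mults, index=3))                     # 0.2
print(effective_lr(0.1, mults, index=7, name="fc2_bias"))    # 0.1
```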

Note

The default learning rate multiplier of a Variable can be set with lr_mult argument in the constructor.

Parameters

args_lr_mult (dict of str/int to float) –

For each of its key-value entries, the learning rate multiplier for the parameter specified in the key will be set as the given value.

You can specify the parameter with either its name or its index. If you use the name, you should pass sym in the constructor, and the name you specified in the key of args_lr_mult should match the name of the parameter in sym. If you use the index, it should correspond to the index of the parameter used in the update method.

Specifying a parameter by its index is only supported for backward compatibility, and we recommend using the name instead.

set_wd_mult(args_wd_mult)

Sets an individual weight decay multiplier for each parameter.

Note

The default weight decay multiplier for a Variable can be set with its wd_mult argument in the constructor.

Parameters

args_wd_mult (dict of string/int to float) –

For each of its key-value entries, the weight decay multiplier for the parameter specified in the key will be set as the given value.

You can specify the parameter with either its name or its index. If you use the name, you should pass sym in the constructor, and the name you specified in the key of args_wd_mult should match the name of the parameter in sym. If you use the index, it should correspond to the index of the parameter used in the update method.

Specifying a parameter by its index is only supported for backward compatibility, and we recommend using the name instead.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

update(indices, weights, grads, states)

Call step to perform a single optimization update if use_fused_step is False, otherwise fused_step is called.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

update_multi_precision(indices, weights, grads, states)

Call step to perform a single optimization update if use_fused_step is False, otherwise fused_step is called. Mixed precision version.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

class Test(**kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

The Test optimizer

Methods

create_state(index, weight)

Creates a state to duplicate weight.

step(indices, weights, grads, states)

Performs w += rescale_grad * grad.

create_state(index, weight)

Creates a state to duplicate weight.

step(indices, weights, grads, states)

Performs w += rescale_grad * grad.
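
The behavior documented for Test can be mirrored in plain Python. This is a sketch, not MXNet's code. Plain lists of floats stand in for NDArrays, and `ToyTestOptimizer` is a hypothetical name. As documented, the rule adds the rescaled gradient to the weight rather than subtracting it, and the state is only a copy of the weight.

```python
from typing import List


class ToyTestOptimizer:
    """Hypothetical list-based mirror of the documented Test optimizer."""

    def __init__(self, rescale_grad: float = 1.0) -> None:
        self.rescale_grad = rescale_grad

    def create_state(self, index: int, weight: List[float]) -> List[float]:
        # The state is a copy (duplicate) of the weight.
        return list(weight)

    def step(self, indices, weights, grads, states) -> None:
        # w += rescale_grad * grad, applied in place to every weight in the batch.
        for weight, grad in zip(weights, grads):
            for i, g in enumerate(grad):
                weight[i] += self.rescale_grad * g


opt = ToyTestOptimizer(rescale_grad=0.5)
w = [1.0, 2.0]
state = opt.create_state(0, w)
opt.step([0], [w], [[2.0, -4.0]], [state])
print(w, state)  # [2.0, 0.0] [1.0, 2.0]
```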

create(name, **kwargs)

Instantiates an optimizer with a given name and kwargs.

Note

We can use the alias create for Optimizer.create_optimizer.

Parameters
  • name (str) – Name of the optimizer. Should be the name of a subclass of Optimizer. Case insensitive.

  • kwargs (dict) – Parameters for the optimizer.

Returns

An instantiated optimizer.

Return type

Optimizer

Examples

>>> import mxnet as mx
>>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
>>> type(sgd)
<class 'mxnet.optimizer.SGD'>
>>> adam = mx.optimizer.create('adam', learning_rate=.1)
>>> type(adam)
<class 'mxnet.optimizer.Adam'>

register(klass)

Registers a new optimizer.

Once an optimizer is registered, we can create an instance of this optimizer with create_optimizer later.

Examples

>>> import mxnet as mx
>>> @mx.optimizer.Optimizer.register
... class MyOptimizer(mx.optimizer.Optimizer):
...     pass
>>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer')
>>> print(type(optim))
<class '__main__.MyOptimizer'>

class Updater(optimizer)

Bases: object

Updater for kvstore.

Methods

get_states([dump_optimizer])

Gets updater states.

set_states(states)

Sets updater states.

sync_state_context(state, context)

Syncs the state to the given context.

get_states(dump_optimizer=False)

Gets updater states.

Parameters

dump_optimizer (bool, default False) – Whether to also save the optimizer itself. This would also save optimizer information such as learning rate and weight decay schedules.

set_states(states)

Sets updater states.

sync_state_context(state, context)

Syncs the state to the given context.

get_updater(optimizer)

Returns a closure of the updater needed for kvstore.

Parameters

optimizer (Optimizer) – The optimizer.

Returns

updater – The closure of the updater.

Return type

function
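
To show what "a closure of the updater" means, here is a self-contained sketch. For illustration only, it assumes an updater called as `updater(index, grad, weight)` that creates per-index state on first use and then delegates to the optimizer's `update`. `ToySGD` and `make_updater` are hypothetical, and lists stand in for NDArrays.

```python
from typing import Callable, Dict, List


class ToySGD:
    """Hypothetical plain-SGD optimizer with no auxiliary state."""

    def __init__(self, learning_rate: float = 0.1) -> None:
        self.lr = learning_rate

    def create_state(self, index: int, weight: List[float]):
        return None  # plain SGD needs no state

    def update(self, indices, weights, grads, states) -> None:
        for weight, grad in zip(weights, grads):
            for i, g in enumerate(grad):
                weight[i] -= self.lr * g


def make_updater(optimizer) -> Callable[[int, List[float], List[float]], None]:
    """Return a closure that owns the per-index states of `optimizer`."""
    states: Dict[int, object] = {}

    def updater(index: int, grad: List[float], weight: List[float]) -> None:
        if index not in states:
            # State is created once per weight, on its first update.
            states[index] = optimizer.create_state(index, weight)
        optimizer.update([index], [weight], [grad], [states[index]])

    return updater


upd = make_updater(ToySGD(learning_rate=0.5))
w = [1.0, 1.0]
upd(0, [0.2, -0.2], w)
print(w)  # [0.9, 1.1]
```

The closure keeps the state dictionary private, so each index gets exactly one state object for the lifetime of the updater.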

class SGD(learning_rate=0.1, momentum=0.0, lazy_update=False, multi_precision=False, use_fused_step=True, aggregate_num=1, **kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

The SGD optimizer with momentum and weight decay.

If the storage type of grad is row_sparse and lazy_update is True, lazy updates are applied by:

for row in grad.indices:
    rescaled_grad[row] = clip(rescale_grad * grad[row] + wd * weight[row], clip_gradient)
    state[row] = momentum * state[row] + lr * rescaled_grad[row]
    weight[row] = weight[row] - state[row]
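
A plain-Python rendering of the lazy rule makes the row-wise behavior concrete. This sketch rests on several assumptions. A row_sparse gradient is modeled as a hypothetical dict from row index to a list of values, `clip` is elementwise clipping to [-clip_gradient, clip_gradient], and clipping is placed as in the pseudocode above. Rows absent from the gradient keep both their weight and their momentum untouched.

```python
from typing import Dict, List, Optional


def clip(x: float, bound: Optional[float]) -> float:
    """Clip x to [-bound, bound]; no-op when bound is None."""
    return x if bound is None else max(-bound, min(bound, x))


def sgd_mom_lazy_update(weight: List[List[float]], state: List[List[float]],
                        grad_rows: Dict[int, List[float]], lr: float,
                        momentum: float, wd: float = 0.0,
                        rescale_grad: float = 1.0,
                        clip_gradient: Optional[float] = None) -> None:
    """Update only the rows listed in grad_rows (a toy row_sparse gradient)."""
    for row, g in grad_rows.items():
        for j, gj in enumerate(g):
            rescaled = clip(rescale_grad * gj + wd * weight[row][j], clip_gradient)
            state[row][j] = momentum * state[row][j] + lr * rescaled
            weight[row][j] -= state[row][j]


weight = [[1.0, 1.0], [1.0, 1.0]]
state = [[0.0, 0.0], [0.5, 0.5]]
sgd_mom_lazy_update(weight, state, {0: [1.0, -1.0]}, lr=0.5, momentum=0.9)
print(weight)  # [[0.5, 1.5], [1.0, 1.0]]
print(state)   # [[0.5, -0.5], [0.5, 0.5]]
```

Row 1's momentum stays at 0.5 instead of decaying, which is the semantic difference from the standard update.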

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

update_multi_precision(indices, weights, …)

Override update_multi_precision.

The sparse update only updates the momentum for the weights whose row_sparse gradient indices appear in the current batch, rather than updating it for all indices. Compared with the original update, it can provide large improvements in model training throughput for some applications. However, its semantics differ slightly from those of the original update, and it may lead to different empirical results.

When update_on_kvstore is set to False (either globally via the MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter in Trainer), the SGD optimizer can perform an aggregated update of parameters, which may improve performance. The aggregation size is controlled by aggregate_num, which defaults to 1.

When lazy updates are not applied, standard updates are applied by:

rescaled_grad = clip(rescale_grad * grad, clip_gradient) + wd * weight
state = momentum * state + lr * rescaled_grad
weight = weight - state

For details of the update algorithm see sgd_update and sgd_mom_update.
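
For comparison, the standard (dense) momentum rule can be sketched over flat lists. This sketch makes a few assumptions. Lists of floats replace NDArrays, and `sgd_mom_update_dense` is a hypothetical helper, not the sgd_mom_update operator. Clipping is applied before the weight-decay term is added, matching the formula above. With momentum=0 the rule reduces to plain SGD, weight -= lr * rescaled_grad.

```python
from typing import List, Optional


def sgd_mom_update_dense(weight: List[float], state: List[float],
                         grad: List[float], lr: float, momentum: float,
                         wd: float = 0.0, rescale_grad: float = 1.0,
                         clip_gradient: Optional[float] = None) -> None:
    """In-place dense SGD-with-momentum step on flat lists."""
    for i, g in enumerate(grad):
        g = rescale_grad * g
        if clip_gradient is not None:
            g = max(-clip_gradient, min(clip_gradient, g))
        rescaled = g + wd * weight[i]  # weight decay added after clipping
        state[i] = momentum * state[i] + lr * rescaled
        weight[i] -= state[i]


w, s = [1.0, -2.0], [0.0, 0.0]
sgd_mom_update_dense(w, s, [4.0, -4.0], lr=0.25, momentum=0.5, clip_gradient=1.0)
print(w, s)  # [0.75, -1.75] [0.25, -0.25]
```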

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.1) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • momentum (float, default 0.) – The momentum value.

  • lazy_update (bool, default False) – If True, lazy updates are applied if the storage types of weight and grad are both row_sparse.

  • multi_precision (bool, default False) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

  • aggregate_num (int, default 1) – Number of weights to be aggregated in a list. They are passed to the optimizer for a single optimization step.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for optimizer. When use_fused_step=False, step is called, otherwise, fused_step is called.

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

update_multi_precision(indices, weights, grads, states)

Override update_multi_precision.

class SGLD(learning_rate=0.1, use_fused_step=False, **kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

Stochastic Gradient Riemannian Langevin Dynamics.

This class implements the optimizer described in the paper Stochastic Gradient Riemannian Langevin Dynamics on the Probability Simplex, available at https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf.

Parameters
  • learning_rate (float, default 0.1) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • use_fused_step (bool, default False) – Whether or not to use fused kernels for optimizer. When use_fused_step=False, step is called, otherwise, fused_step is called.

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

class Signum(learning_rate=0.01, momentum=0.9, wd_lh=0.0, use_fused_step=True, **kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

The Signum optimizer that takes the sign of gradient or momentum.

The optimizer updates the weight by:

rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
state = momentum * state + (1-momentum)*rescaled_grad
weight = (1 - lr * wd_lh) * weight - lr * sign(state)
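
Only the sign of the momentum enters the weight update. As a result, the size of each coordinate's step does not depend on the gradient's magnitude. The list-based sketch below follows the three formulas above. `signum_update` is a hypothetical helper, not the signum_update operator, and plain floats stand in for NDArrays.

```python
from typing import List, Optional


def sign(x: float) -> int:
    """Return -1, 0 or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def signum_update(weight: List[float], state: List[float], grad: List[float],
                  lr: float, momentum: float = 0.9, wd: float = 0.0,
                  wd_lh: float = 0.0, rescale_grad: float = 1.0,
                  clip_gradient: Optional[float] = None) -> None:
    """In-place Signum step following the documented formulas."""
    for i, g in enumerate(grad):
        if clip_gradient is not None:
            # clip(grad, clip_gradient) is applied before rescaling, as documented.
            g = max(-clip_gradient, min(clip_gradient, g))
        rescaled = rescale_grad * g + wd * weight[i]
        state[i] = momentum * state[i] + (1 - momentum) * rescaled
        # Only sign(state) reaches the weight; wd_lh applies decoupled decay.
        weight[i] = (1 - lr * wd_lh) * weight[i] - lr * sign(state[i])


w, s = [1.0, 1.0], [0.0, 0.0]
signum_update(w, s, [3.0, -0.5], lr=0.125, momentum=0.5)
print(w, s)  # [0.875, 1.125] [1.5, -0.25]
```

Both coordinates move by exactly lr = 0.125, even though their gradients differ by a factor of six.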

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

References

Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018). signSGD: Compressed Optimisation for Non-Convex Problems. In ICML’18.

See: https://arxiv.org/abs/1802.04434

For details of the update algorithm see signsgd_update and signum_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.01) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • momentum (float, default 0.9) – The momentum value.

  • wd_lh (float, default 0.0) – The amount of decoupled weight decay regularization; see details in the original paper at https://arxiv.org/abs/1711.05101.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for optimizer. When use_fused_step=False, step is called, otherwise, fused_step is called.

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

class DCASGD(learning_rate=0.1, momentum=0.0, lamda=0.04, use_fused_step=False, **kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

The DCASGD optimizer.

This class implements the optimizer described in Asynchronous Stochastic Gradient Descent with Delay Compensation for Distributed Deep Learning, available at https://arxiv.org/abs/1609.08326.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.1) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • momentum (float, default 0.0) – The momentum value.

  • lamda (float, default 0.04) – Scaling factor for the delay-compensation (DC) term.

  • use_fused_step (bool, default False) – Whether or not to use fused kernels for optimizer. When use_fused_step=False, step is called, otherwise, fused_step is called.

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

class NAG(learning_rate=0.1, momentum=0.9, multi_precision=False, use_fused_step=True, **kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

Nesterov accelerated gradient.

This optimizer updates each weight by:

grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
state = momentum * state + lr * grad
weight = weight - (momentum * state + lr * grad)
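
The look-ahead term in the last line distinguishes NAG from plain momentum SGD. Here the weight moves by the freshly updated state scaled by momentum, plus the current gradient step. The list-based sketch below implements the three formulas. `nag_update` is a hypothetical helper, and floats stand in for NDArrays.

```python
from typing import List, Optional


def nag_update(weight: List[float], state: List[float], grad: List[float],
               lr: float, momentum: float = 0.9, wd: float = 0.0,
               rescale_grad: float = 1.0,
               clip_gradient: Optional[float] = None) -> None:
    """In-place NAG step following the documented formulas."""
    for i, g in enumerate(grad):
        g = g * rescale_grad
        if clip_gradient is not None:
            g = max(-clip_gradient, min(clip_gradient, g))
        g += wd * weight[i]
        state[i] = momentum * state[i] + lr * g
        # Look-ahead: updated state scaled by momentum, plus the gradient step.
        weight[i] -= momentum * state[i] + lr * g


w, s = [1.0], [0.0]
for _ in range(2):
    nag_update(w, s, [1.0], lr=0.5, momentum=0.5)
print(w, s)  # [-0.625] [0.75]
```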

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

update_multi_precision(indices, weights, …)

Override update_multi_precision.

Parameters
  • learning_rate (float, default 0.1) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • momentum (float, default 0.9) – The momentum value.

  • multi_precision (bool, default False) – Flag to control the internal precision of the optimizer. False: results in using the same precision as the weights (default), True: makes internal 32-bit copy of the weights and applies gradients in 32-bit precision even if actual weights used in the model have lower precision. Turning this on can improve convergence and accuracy when training with float16.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for optimizer. When use_fused_step=False, step is called, otherwise, fused_step is called.

create_state(index, weight)

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. This function is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to this parameter.

  • states (List of any obj) – List of state returned by create_state().

update_multi_precision(indices, weights, grads, states)

Override update_multi_precision.

class AdaBelief(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, correct_bias=True, use_fused_step=True, **kwargs)

Bases: mxnet.optimizer.optimizer.Optimizer

The AdaBelief optimizer.

This class implements the optimizer described in Adapting Stepsizes by the Belief in Observed Gradients, available at https://arxiv.org/pdf/2010.07468.pdf.

Methods

create_state(index, weight)

State creation function.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

Updates are applied by:

grad = clip(grad * rescale_grad, clip_gradient) + wd * w
m = beta1 * m + (1 - beta1) * grad
s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon
lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
w = w - lr * (m / (sqrt(s) + epsilon))

Also, we can turn off the bias correction term and the updates are as follows:

grad = clip(grad * rescale_grad, clip_gradient) + wd * w
m = beta1 * m + (1 - beta1) * grad
s = beta2 * s + (1 - beta2) * ((grad - m)**2) + epsilon
lr = learning_rate
w = w - lr * (m / (sqrt(s) + epsilon))
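
The two variants above differ only in how lr is computed. The sketch below implements both behind a `correct_bias` switch and operates on lists of floats. `adabelief_update` is a hypothetical helper, and `t` is assumed to be the 1-based count of updates applied so far. Unlike Adam, which accumulates grad**2, the second-moment term here tracks (grad - m)**2, the deviation of the gradient from its running mean.

```python
import math
from typing import List, Optional


def adabelief_update(weight: List[float], m: List[float], s: List[float],
                     grad: List[float], t: int, learning_rate: float = 0.001,
                     beta1: float = 0.9, beta2: float = 0.999,
                     epsilon: float = 1e-6, wd: float = 0.0,
                     rescale_grad: float = 1.0,
                     clip_gradient: Optional[float] = None,
                     correct_bias: bool = True) -> None:
    """In-place AdaBelief step; t is the 1-based update count."""
    if correct_bias:
        lr = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    else:
        lr = learning_rate
    for i, g in enumerate(grad):
        g = g * rescale_grad
        if clip_gradient is not None:
            g = max(-clip_gradient, min(clip_gradient, g))
        g += wd * weight[i]
        m[i] = beta1 * m[i] + (1 - beta1) * g
        # "Belief" term: squared deviation of the gradient from its running mean.
        s[i] = beta2 * s[i] + (1 - beta2) * (g - m[i]) ** 2 + epsilon
        weight[i] -= lr * m[i] / (math.sqrt(s[i]) + epsilon)


w, m, s = [1.0, -1.0], [0.0, 0.0], [0.0, 0.0]
for t in range(1, 4):
    adabelief_update(w, m, s, [1.0, -1.0], t)
print(w)
```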

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-6) – Small value to avoid division by 0.

  • correct_bias (bool, default True) – Whether to apply bias correction. Set to False to skip it, as is done for Adam in the BERT TensorFlow repository.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

State creation function.

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class AdaGrad(learning_rate=0.01, epsilon=1e-06, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

AdaGrad optimizer.

This class implements the AdaGrad optimizer described in Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.

This optimizer updates each weight by:

grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
history += square(grad)
weight -= learning_rate * grad / (sqrt(history) + epsilon)
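A minimal scalar sketch of the AdaGrad rule above in plain Python; `adagrad_step` is a hypothetical helper, not MXNet's implementation, which operates on NDArrays. Feeding the same gradient repeatedly shows the effective step size shrinking as the squared-gradient history accumulates.

```python
import math


def adagrad_step(weight, grad, history, learning_rate=0.01, epsilon=1e-6,
                 wd=0.0, rescale_grad=1.0, clip_gradient=None):
    """One scalar AdaGrad update (hypothetical helper); returns (weight, history)."""
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = max(-clip_gradient, min(clip_gradient, grad))
    grad += wd * weight
    # The history only grows, so the effective step size only shrinks.
    history += grad * grad
    weight -= learning_rate * grad / (math.sqrt(history) + epsilon)
    return weight, history


# Apply the same gradient three times and record how far each step moves.
w, h = 1.0, 0.0
deltas = []
for _ in range(3):
    new_w, h = adagrad_step(w, 2.0, h)
    deltas.append(w - new_w)
    w = new_w
```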

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.01) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • epsilon (float, default 1e-6) – Small value to avoid division by 0.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False or grad is not sparse, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class AdaDelta(learning_rate=1.0, rho=0.9, epsilon=1e-06, use_fused_step=False, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The AdaDelta optimizer.

This class implements AdaDelta, an optimizer described in ADADELTA: An adaptive learning rate method, available at https://arxiv.org/abs/1212.5701.

This optimizer updates each weight by:

grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
acc_grad = rho * acc_grad + (1. - rho) * grad * grad
delta = sqrt(acc_delta + epsilon) / sqrt(acc_grad + epsilon) * grad
acc_delta = rho * acc_delta + (1. - rho) * delta * delta
weight -= learning_rate * delta
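The AdaDelta rule above can be sketched for a single scalar weight in plain Python. `adadelta_step` is a hypothetical helper for illustration only; MXNet applies the rule to NDArrays.

```python
import math


def adadelta_step(weight, grad, acc_grad, acc_delta, learning_rate=1.0,
                  rho=0.9, epsilon=1e-6, wd=0.0, rescale_grad=1.0,
                  clip_gradient=None):
    """One scalar AdaDelta update (hypothetical helper).

    Returns (weight, acc_grad, acc_delta).
    """
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = max(-clip_gradient, min(clip_gradient, grad))
    grad += wd * weight
    acc_grad = rho * acc_grad + (1. - rho) * grad * grad
    # Step scale = RMS of past deltas divided by RMS of past gradients.
    delta = math.sqrt(acc_delta + epsilon) / math.sqrt(acc_grad + epsilon) * grad
    acc_delta = rho * acc_delta + (1. - rho) * delta * delta
    weight -= learning_rate * delta
    return weight, acc_grad, acc_delta


w, ag, ad = adadelta_step(1.0, 1.0, 0.0, 0.0)
```

Because acc_delta starts at zero, early steps are on the order of learning_rate * sqrt(epsilon) / RMS(grad).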

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 1.0) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • rho (float, default 0.9) – Decay rate for both squared gradients and delta.

  • epsilon (float, default 1e-6) – Small value to avoid division by 0.

  • use_fused_step (bool, default False) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=False, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Adam optimizer.

This class implements the optimizer described in Adam: A Method for Stochastic Optimization, available at http://arxiv.org/abs/1412.6980.

If the storage type of grad is row_sparse and lazy_update is True, lazy updates at step t are applied by:

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient) + wd * weight[row]
    m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
    v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
    lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
    w[row] = w[row] - lr * m[row] / (sqrt(v[row]) + epsilon)

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

The lazy update only updates the first and second moment estimates (m and v) for the rows whose indices appear in the current batch's row_sparse gradient, rather than for all rows. Compared with the standard update, it can substantially improve training throughput for some applications. However, its semantics differ slightly from those of the standard update, and it may lead to different empirical results.
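To make the row-sparse semantics concrete, here is a plain-Python sketch of the lazy update. The row-sparse gradient is modeled as a dict from row index to gradient row, which is an assumption for illustration (MXNet uses row_sparse NDArrays), and `adam_lazy_update` is a hypothetical name.

```python
import math


def adam_lazy_update(weight, m, v, grad_rows, t, learning_rate=0.001,
                     beta1=0.9, beta2=0.999, epsilon=1e-8, wd=0.0,
                     rescale_grad=1.0, clip_gradient=None):
    """Lazy Adam update on a row-sparse gradient (hypothetical helper).

    weight, m, v: lists of rows (lists of floats), updated in place.
    grad_rows: dict mapping row index -> gradient row, standing in for a
    row_sparse gradient's indices and data.
    """
    lr = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    for row, g_row in grad_rows.items():  # only rows present in this batch
        for j, g in enumerate(g_row):
            g = g * rescale_grad
            if clip_gradient is not None:
                g = max(-clip_gradient, min(clip_gradient, g))
            g += wd * weight[row][j]
            m[row][j] = beta1 * m[row][j] + (1 - beta1) * g
            v[row][j] = beta2 * v[row][j] + (1 - beta2) * g * g
            weight[row][j] -= lr * m[row][j] / (math.sqrt(v[row][j]) + epsilon)


weight = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
m = [[0.0, 0.0] for _ in range(3)]
v = [[0.0, 0.0] for _ in range(3)]
adam_lazy_update(weight, m, v, {1: [0.5, -0.5]}, t=1)
```

Rows 0 and 2 keep both their weights and their m/v state untouched. In the standard (dense) update, rows absent from the batch would still have m and v decayed, and once m is nonzero their weights would keep moving.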

Otherwise, standard updates at step t are applied by:

rescaled_grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
w = w - lr * m / (sqrt(v) + epsilon)
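A scalar plain-Python sketch of the standard update above (the helper `adam_step` is hypothetical, not MXNet's code). It illustrates one consequence of folding bias correction into lr: starting from zero m and v, the first step has magnitude close to learning_rate regardless of the gradient's scale.

```python
import math


def adam_step(w, grad, m, v, t, learning_rate=0.001, beta1=0.9, beta2=0.999,
              epsilon=1e-8, wd=0.0, rescale_grad=1.0, clip_gradient=None):
    """One scalar Adam update (hypothetical helper); returns (w, m, v)."""
    rescaled_grad = grad * rescale_grad
    if clip_gradient is not None:
        rescaled_grad = max(-clip_gradient, min(clip_gradient, rescaled_grad))
    rescaled_grad += wd * w
    m = beta1 * m + (1 - beta1) * rescaled_grad
    v = beta2 * v + (1 - beta2) * rescaled_grad ** 2
    # Bias correction is applied to the step size, not to m and v themselves.
    lr = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    w = w - lr * m / (math.sqrt(v) + epsilon)
    return w, m, v


# First-step weights for gradients spanning six orders of magnitude.
first_steps = [adam_step(0.0, g, 0.0, 0.0, t=1)[0] for g in (1e-3, 1.0, 1e3)]
```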

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

For details of the update algorithm, see adam_update.

Parameters
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-8) – Small value to avoid division by 0.

  • lazy_update (bool, default False) – Default is False. If True, lazy updates are applied if the storage types of weight and grad are both row_sparse.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class Adamax(learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-08, use_fused_step=False, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The AdaMax optimizer.

It is a variant of Adam based on the infinity norm, described in Section 7 of http://arxiv.org/abs/1412.6980.

The optimizer updates the weight by:

grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
m = beta1 * m + (1 - beta1) * grad
u = maximum(beta2 * u, abs(grad))
weight -= lr / (1 - beta1**t) * m / (u + epsilon)
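A scalar plain-Python sketch of the AdaMax rule above; `adamax_step` is a hypothetical helper, and `learning_rate` plays the role of lr in the formula.

```python
def adamax_step(weight, grad, m, u, t, learning_rate=0.002, beta1=0.9,
                beta2=0.999, epsilon=1e-8, wd=0.0, rescale_grad=1.0,
                clip_gradient=None):
    """One scalar AdaMax update (hypothetical helper); returns (weight, m, u)."""
    grad = grad * rescale_grad
    if clip_gradient is not None:
        grad = max(-clip_gradient, min(clip_gradient, grad))
    grad += wd * weight
    m = beta1 * m + (1 - beta1) * grad
    # u is an exponentially weighted infinity norm: a max, not a mean of squares.
    u = max(beta2 * u, abs(grad))
    # Only m is bias-corrected, matching the formula above.
    weight -= learning_rate / (1 - beta1 ** t) * m / (u + epsilon)
    return weight, m, u


w, m, u = adamax_step(0.0, -4.0, 0.0, 0.0, t=1)
```

On the first step m / u reduces to (1 - beta1) * sign(grad), so the weight moves by about learning_rate against the gradient.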

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.002) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-8) – Small value to avoid division by 0.

  • use_fused_step (bool, default False) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class Nadam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, schedule_decay=0.004, use_fused_step=False, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Nesterov Adam optimizer.

Much as Adam is essentially RMSProp with momentum, Nadam is Adam with Nesterov momentum; see http://cs229.stanford.edu/proj2015/054_report.pdf.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-8) – Small value to avoid division by 0.

  • schedule_decay (float, default 0.004) – Exponential decay rate for the momentum schedule.

  • use_fused_step (bool, default False) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class Ftrl(learning_rate=0.1, lamda1=0.01, beta=1.0, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The Ftrl optimizer.

Referenced from Ad Click Prediction: a View from the Trenches, available at http://dl.acm.org/citation.cfm?id=2488200.

eta (the per-coordinate learning rate):
\[\eta_{t,i} = \frac{\text{learning\_rate}}{\beta + \sqrt{\sum_{s=1}^{t} g_{s,i}^2}}\]

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

The optimizer updates the weight by:

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
n += rescaled_grad**2
w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)
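A scalar plain-Python sketch of the dense FTRL rule above (`ftrl_step` is a hypothetical helper). Note that wd acts as the L2 term in the denominator rather than being added to the gradient, and that any coordinate with abs(z) <= lamda1 is set exactly to zero, which is what gives FTRL its sparse weights.

```python
import math


def ftrl_step(weight, grad, z, n, learning_rate=0.1, lamda1=0.01, beta=1.0,
              wd=0.0, rescale_grad=1.0, clip_gradient=None):
    """One scalar FTRL update (hypothetical helper); returns (weight, z, n)."""
    g = grad * rescale_grad
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    z += g - (math.sqrt(n + g * g) - math.sqrt(n)) * weight / learning_rate
    n += g * g
    if abs(z) > lamda1:
        sign_z = 1.0 if z > 0 else -1.0
        weight = (sign_z * lamda1 - z) / ((beta + math.sqrt(n)) / learning_rate + wd)
    else:
        weight = 0.0  # the L1 term pins small coordinates exactly at zero
    return weight, z, n


w_small, _, _ = ftrl_step(0.0, 0.005, 0.0, 0.0)  # abs(z) stays below lamda1
w_large, z1, n1 = ftrl_step(0.0, 1.0, 0.0, 0.0)
```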

If the storage types of weight, state and grad are all row_sparse, sparse updates are applied by:

for row in grad.indices:
    rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
    z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
    n[row] += rescaled_grad[row]**2
    w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)

The sparse update only updates z and n for the rows whose indices appear in the current batch's row_sparse gradient, rather than for all rows. Compared with the dense update, it can substantially improve training throughput for some applications. However, its semantics differ slightly from those of the dense update, and it may lead to different empirical results.

For details of the update algorithm, see ftrl_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.1) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • lamda1 (float, default 0.01) – L1 regularization coefficient.

  • beta (float, default 1.0) – Per-coordinate learning rate correlation parameter.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class FTML(learning_rate=0.0025, beta1=0.6, beta2=0.999, epsilon=1e-08, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The FTML optimizer.

This class implements the optimizer described in FTML - Follow the Moving Leader in Deep Learning, available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

Denote time step by t. The optimizer updates the weight by:

rescaled_grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
v = beta2 * v + (1 - beta2) * square(rescaled_grad)
d_t = (1 - power(beta1, t)) / lr * (square_root(v / (1 - power(beta2, t))) + epsilon)
z = beta1 * z + (1 - beta1) * rescaled_grad - (d_t - beta1 * d_(t-1)) * weight
weight = - z / d_t
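A scalar plain-Python sketch of the FTML rule above. `ftml_step` is a hypothetical helper; the caller is assumed to carry d_(t-1) between calls (passed as `d_prev`, starting at 0) and `t` starts at 1.

```python
import math


def ftml_step(weight, grad, v, z, d_prev, t, lr=0.0025, beta1=0.6,
              beta2=0.999, epsilon=1e-8, wd=0.0, rescale_grad=1.0,
              clip_gradient=None):
    """One scalar FTML update (hypothetical helper).

    Returns (weight, v, z, d_t); pass d_t back in as d_prev on the next call.
    """
    g = grad * rescale_grad
    if clip_gradient is not None:
        g = max(-clip_gradient, min(clip_gradient, g))
    g += wd * weight
    v = beta2 * v + (1 - beta2) * g * g
    d_t = (1 - beta1 ** t) / lr * (math.sqrt(v / (1 - beta2 ** t)) + epsilon)
    z = beta1 * z + (1 - beta1) * g - (d_t - beta1 * d_prev) * weight
    # The new weight is computed from z directly, not as an increment.
    weight = -z / d_t
    return weight, v, z, d_t


w1, v1, z1, d1 = ftml_step(0.0, 1.0, 0.0, 0.0, 0.0, t=1)
```

From a zero weight, the first step lands at roughly -lr * sign(grad).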

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

For details of the update algorithm, see ftml_update.

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.0025) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.6) – 0 < beta1 < 1. Generally close to 0.5.

  • beta2 (float, default 0.999) – 0 < beta2 < 1. Generally close to 1.

  • epsilon (float, default 1e-8) – Small value to avoid division by 0.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

class LARS(learning_rate=0.1, momentum=0.0, eta=0.001, epsilon=1e-08, lazy_update=False, use_fused_step=True, aggregate_num=1, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The LARS optimizer from ‘Large Batch Training of Convolutional Networks’ (https://arxiv.org/abs/1708.03888).

Behaves mostly like SGD with momentum and weight decay, but adaptively scales the learning rate for each layer:

w_norm = L2norm(weights)
g_norm = L2norm(gradients)
if w_norm > 0 and g_norm > 0:
    lr_layer = lr * eta * w_norm / (g_norm + weight_decay * w_norm + epsilon)
else:
    lr_layer = lr
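The layer-wise learning rate above can be sketched in plain Python for one layer given as flat lists of floats. `lars_layer_lr` is a hypothetical helper. The `eta` argument is an assumption: the LARS paper multiplies the trust ratio by a coefficient of this kind, and passing eta=1.0 drops it and leaves the bare ratio.

```python
import math


def lars_layer_lr(lr, weights, grads, weight_decay, epsilon=1e-8, eta=1.0):
    """Layer-wise LARS learning rate (hypothetical helper).

    weights, grads: flat lists of floats for one layer.
    eta: optional trust-coefficient multiplier (assumption; 1.0 disables it).
    """
    w_norm = math.sqrt(sum(x * x for x in weights))
    g_norm = math.sqrt(sum(x * x for x in grads))
    if w_norm > 0 and g_norm > 0:
        return lr * eta * w_norm / (g_norm + weight_decay * w_norm + epsilon)
    # Fall back to the global rate when either norm is zero.
    return lr


scaled = lars_layer_lr(0.1, [3.0, 4.0], [0.0, 0.5], weight_decay=0.0)
fallback = lars_layer_lr(0.1, [1.0], [0.0], weight_decay=0.0)
```

Here the weight norm (5.0) is ten times the gradient norm (0.5), so the layer's rate is scaled up tenfold from 0.1 to about 1.0.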

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

update_multi_precision(indices, weights, …)

Override update_multi_precision.

Parameters
  • learning_rate (float, default 0.1) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • momentum (float, default 0.) – The momentum value.

  • eta (float, default 0.001) – LARS coefficient used to scale the learning rate.

  • epsilon (float, default 1e-8) – Small value to avoid division by 0.

  • lazy_update (bool, default False) – Default is False. If True, lazy updates are applied if the storage types of weight and grad are both row_sparse.

  • aggregate_num (int, default 1) – Number of weights to be aggregated in a list. They are passed to the optimizer for a single optimization step.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters into the individual learning rates and weight decays. Learning rates and weight decay may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to these parameters.

  • states (List of any obj) – List of states returned by create_state().

update_multi_precision(indices, weights, grads, states)[source]

Override update_multi_precision.

class LAMB(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, lower_bound=None, upper_bound=None, bias_correction=True, aggregate_num=4, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

LAMB Optimizer.

Referenced from ‘Large Batch Optimization for Deep Learning: Training BERT in 76 minutes’ (https://arxiv.org/pdf/1904.00962.pdf)

Parameters
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-6) – Small value to avoid division by 0.

  • lower_bound (float, default None) – Lower limit of the weight norm.

  • upper_bound (float, default None) – Upper limit of the weight norm.

  • bias_correction (bool, default True) – Whether or not to apply bias correction.

  • aggregate_num (int, default 4) – Number of weights to be aggregated in a list. They are passed to the optimizer for a single optimization step. By default, up to 4 weights are aggregated.

  • use_fused_step (bool, default True) – Whether or not to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

update_multi_precision(indices, weights, …)

Override update_multi_precision.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, such as momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters, used to look up their individual learning rates and weight decays. Learning rates and weight decays may be adjusted per parameter via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().
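How indices select per-parameter learning rates and weight decays can be sketched as a lookup: each index is mapped to a parameter name, and that name's multiplier scales the base value. The helper effective_lr_wd, its arguments, and the assumption that missing multipliers default to 1.0 are hypothetical; this is not MXNet's internal code.

```python
def effective_lr_wd(indices, idx2name, base_lr, base_wd, lr_mult, wd_mult):
    """Hypothetical helper resolving per-parameter learning rates and weight decays.

    idx2name maps an index to a parameter name; lr_mult and wd_mult map names
    to multipliers, in the spirit of set_lr_mult() and set_wd_mult().
    Parameters without an entry are assumed to use a multiplier of 1.0.
    """
    lrs, wds = [], []
    for index in indices:
        # Fall back to the raw index as the key when no name is known.
        name = idx2name.get(index, index)
        lrs.append(base_lr * lr_mult.get(name, 1.0))
        wds.append(base_wd * wd_mult.get(name, 1.0))
    return lrs, wds


# Usage: double the bias learning rate and disable its weight decay.
lrs, wds = effective_lr_wd(
    indices=[0, 1],
    idx2name={0: "fc_weight", 1: "fc_bias"},
    base_lr=0.01,
    base_wd=1e-4,
    lr_mult={"fc_bias": 2.0},
    wd_mult={"fc_bias": 0.0},
)
print(lrs, wds)  # [0.01, 0.02] [0.0001, 0.0]
```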

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters, used to look up their individual learning rates and weight decays. Learning rates and weight decays may be adjusted per parameter via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().
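The relationship between aggregate_num, use_fused_step, fused_step and step can be sketched with a hypothetical driver: parameters are split into consecutive groups of at most aggregate_num, and each group is passed to fused_step when use_fused_step is true and to step otherwise, using the (indices, weights, grads, states) signature documented in this section. The function run_update and the Recorder class are illustrative only; MXNet's actual grouping logic is not described here and may differ.

```python
def run_update(optimizer, indices, weights, grads, states, aggregate_num=4):
    """Hypothetical driver: group parameters, then call fused_step or step."""
    for start in range(0, len(indices), aggregate_num):
        end = start + aggregate_num
        group = (indices[start:end], weights[start:end],
                 grads[start:end], states[start:end])
        if optimizer.use_fused_step:
            optimizer.fused_step(*group)  # fused-kernel path
        else:
            optimizer.step(*group)        # non-fused path


class Recorder:
    """Stand-in optimizer that only records which method received which indices."""

    def __init__(self, use_fused_step):
        self.use_fused_step = use_fused_step
        self.calls = []

    def fused_step(self, indices, weights, grads, states):
        self.calls.append(("fused", list(indices)))

    def step(self, indices, weights, grads, states):
        self.calls.append(("step", list(indices)))


# Usage: 10 parameters with aggregate_num=4 form groups of 4, 4 and 2.
n = 10
fused = Recorder(use_fused_step=True)
run_update(fused, list(range(n)), [None] * n, [None] * n, [None] * n, 4)
plain = Recorder(use_fused_step=False)
run_update(plain, list(range(n)), [None] * n, [None] * n, [None] * n, 4)
print(fused.calls)
```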

update_multi_precision(indices, weights, grads, states)[source]

Overrides Optimizer.update_multi_precision().

class RMSProp(learning_rate=0.001, rho=0.9, momentum=0.9, epsilon=1e-08, centered=False, clip_weights=None, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

The RMSProp optimizer.

Two versions of RMSProp are implemented:

If centered=False, we follow http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by Tieleman & Hinton, 2012. For details of the update algorithm see rmsprop_update.

If centered=True, we follow Eqs. (38)-(45) of http://arxiv.org/pdf/1308.0850v5.pdf by Alex Graves, 2013. For details of the update algorithm see rmspropalex_update.
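The difference between the two versions can be sketched for a single scalar weight. The plain version scales the gradient by a moving average of squared gradients; the centered version also tracks a moving average of gradients, so that n - g_avg**2 estimates the gradient variance, and applies heavy-ball momentum. This is a minimal sketch: the function names are hypothetical, epsilon placement follows the cited formulations as assumed here and may differ from MXNet's rmsprop_update and rmspropalex_update kernels, and weight decay, gradient rescaling and clipping are omitted.

```python
import math


def rmsprop_plain(w, g, n, lr=0.001, rho=0.9, epsilon=1e-8):
    """Tieleman & Hinton style RMSProp for one scalar weight; returns (w, n)."""
    n = rho * n + (1 - rho) * g * g            # moving average of g^2
    w = w - lr * g / (math.sqrt(n) + epsilon)  # scale step by RMS of gradients
    return w, n


def rmsprop_centered(w, g, n, g_avg, delta, lr=0.001, rho=0.9,
                     momentum=0.9, epsilon=1e-8):
    """Graves style centered RMSProp for one scalar weight.

    Returns (w, n, g_avg, delta).
    """
    n = rho * n + (1 - rho) * g * g            # moving average of g^2
    g_avg = rho * g_avg + (1 - rho) * g        # moving average of g
    # n - g_avg^2 is a running estimate of the gradient variance.
    delta = momentum * delta - lr * g / math.sqrt(n - g_avg * g_avg + epsilon)
    w = w + delta                              # heavy-ball momentum step
    return w, n, g_avg, delta


# Usage: first step from w=1.0 with gradient 1.0 and zero-initialized state.
print(rmsprop_plain(1.0, 1.0, 0.0, lr=0.1))
print(rmsprop_centered(1.0, 1.0, 0.0, 0.0, 0.0, lr=0.1))
```

The centered variant needs three state buffers (n, g_avg, delta) per weight, while the plain variant needs only n.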

This optimizer accepts the following parameters in addition to those accepted by Optimizer.

Parameters
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • rho (float, default 0.9) – Decay factor of the moving average of past squared gradients.

  • momentum (float, default 0.9) – Heavy ball momentum factor. Only used if centered=True.

  • epsilon (float, default 1e-8) – Small value to avoid division by 0.

  • centered (bool, default False) –

    Flag to control which version of RMSProp to use:

    True: use Graves's version of RMSProp.
    False: use Tieleman & Hinton's version of RMSProp.

  • clip_weights (float, optional) – Clips weights into the range [-clip_weights, clip_weights].

  • use_fused_step (bool, default True) – Whether to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.
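The precedence rules stated for learning_rate (an explicit value overwrites the scheduler's rate, None defers to lr_scheduler, and 0.01 is used when both are None) can be expressed as a small resolution function. This is a sketch under stated assumptions: the scheduler is modeled by a hypothetical ConstantScheduler with a writable base_lr attribute, and resolve_initial_lr is an illustrative name, not part of MXNet.

```python
DEFAULT_LR = 0.01  # fallback stated in the parameter description


class ConstantScheduler:
    """Hypothetical stand-in for an lr_scheduler exposing base_lr."""

    def __init__(self, base_lr):
        self.base_lr = base_lr


def resolve_initial_lr(learning_rate, lr_scheduler):
    """Apply the documented precedence between learning_rate and lr_scheduler."""
    if learning_rate is not None:
        if lr_scheduler is not None:
            # An explicit learning rate overwrites the scheduler's rate.
            lr_scheduler.base_lr = learning_rate
        return learning_rate
    if lr_scheduler is not None:
        return lr_scheduler.base_lr  # defer to the scheduler
    return DEFAULT_LR                # neither was given


# Usage
sched = ConstantScheduler(0.05)
print(resolve_initial_lr(None, None))    # 0.01
print(resolve_initial_lr(None, sched))   # 0.05
print(resolve_initial_lr(0.001, sched))  # 0.001, and sched.base_lr is now 0.001
```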

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, e.g. momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters, used to look up their individual learning rates and weight decays. Learning rates and weight decays may be adjusted per parameter via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters, used to look up their individual learning rates and weight decays. Learning rates and weight decays may be adjusted per parameter via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().

class LANS(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, lower_bound=None, upper_bound=None, aggregate_num=4, use_fused_step=True, **kwargs)[source]

Bases: mxnet.optimizer.optimizer.Optimizer

LANS Optimizer.

Based on ‘Accelerated Large Batch Optimization of BERT Pretraining in 54 minutes’ (http://arxiv.org/abs/2006.13484).

Parameters
  • learning_rate (float, default 0.001) – The initial learning rate. If None, the optimization will use the learning rate from lr_scheduler. If not None, it will overwrite the learning rate in lr_scheduler. If None and lr_scheduler is also None, then it will be set to 0.01 by default.

  • beta1 (float, default 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, default 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, default 1e-6) – Small value to avoid division by 0.

  • lower_bound (float, default None) – Lower limit of the weight norm.

  • upper_bound (float, default None) – Upper limit of the weight norm.

  • aggregate_num (int, default 4) – Number of weights aggregated into one list and passed to the optimizer in a single optimization step.

  • use_fused_step (bool, default True) – Whether to use fused kernels for the optimizer. When use_fused_step=False, step is called; otherwise, fused_step is called.
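One LANS update for a single weight block can be sketched in pure Python, assuming the per-block update described in the cited paper: the gradient is first normalized by its norm, Adam-style moments are bias-corrected, and the step combines a momentum direction and a plain-gradient direction, each scaled by its own trust ratio and weighted by beta1 and 1 - beta1. The helper names, the list-of-floats representation, the epsilon placement and the treatment of wd as the base-class weight decay are assumptions; this is not MXNet's fused kernel.

```python
import math


def _norm(xs):
    return math.sqrt(sum(x * x for x in xs))


def _trust(weight_norm, vec):
    """Trust ratio weight_norm / ||vec||, falling back to 1.0 if either is zero."""
    vec_norm = _norm(vec)
    return weight_norm / vec_norm if weight_norm > 0 and vec_norm > 0 else 1.0


def lans_step(weight, grad, mean, var, t, lr=0.001, beta1=0.9, beta2=0.999,
              epsilon=1e-6, wd=0.0, lower_bound=None, upper_bound=None):
    """Hypothetical single LANS update; returns (new_weight, new_mean, new_var)."""
    g_norm = _norm(grad)
    # Per-block gradient normalization.
    g = [x / g_norm for x in grad] if g_norm > 0 else list(grad)
    mean = [beta1 * m + (1 - beta1) * x for m, x in zip(mean, g)]
    var = [beta2 * v + (1 - beta2) * x * x for v, x in zip(var, g)]
    m_hat = [m / (1 - beta1 ** t) for m in mean]
    v_hat = [v / (1 - beta2 ** t) for v in var]
    # Momentum direction r and normalized-gradient direction c, with weight decay.
    r = [m / (math.sqrt(v) + epsilon) + wd * w
         for m, v, w in zip(m_hat, v_hat, weight)]
    c = [x / (math.sqrt(v) + epsilon) + wd * w
         for x, v, w in zip(g, v_hat, weight)]
    # Weight norm clipped to [lower_bound, upper_bound].
    w_norm = _norm(weight)
    if lower_bound is not None:
        w_norm = max(w_norm, lower_bound)
    if upper_bound is not None:
        w_norm = min(w_norm, upper_bound)
    tr, tc = _trust(w_norm, r), _trust(w_norm, c)
    new_weight = [w - lr * beta1 * tr * ri - lr * (1 - beta1) * tc * ci
                  for w, ri, ci in zip(weight, r, c)]
    return new_weight, mean, var
```

When r and c point in the same direction, as on a first step, the combined step length is about lr times the clipped weight norm, mirroring the LAMB-style trust ratio.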

Methods

create_state(index, weight)

Creates auxiliary state for a given weight.

fused_step(indices, weights, grads, states)

Perform a fused optimization step using gradients and states.

step(indices, weights, grads, states)

Perform an optimization step using gradients and states.

update_multi_precision(indices, weights, …)

Overrides Optimizer.update_multi_precision().

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, e.g. momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which is then used in the update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

fused_step(indices, weights, grads, states)[source]

Perform a fused optimization step using gradients and states. Fused kernel is used for update.

Parameters
  • indices (list of int) – List of unique indices of the parameters, used to look up their individual learning rates and weight decays. Learning rates and weight decays may be adjusted per parameter via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().

step(indices, weights, grads, states)[source]

Perform an optimization step using gradients and states.

Parameters
  • indices (list of int) – List of unique indices of the parameters, used to look up their individual learning rates and weight decays. Learning rates and weight decays may be adjusted per parameter via set_lr_mult() and set_wd_mult(), respectively.

  • weights (list of NDArray) – List of parameters to be updated.

  • grads (list of NDArray) – List of gradients of the objective with respect to the corresponding parameters.

  • states (List of any obj) – List of states returned by create_state().

update_multi_precision(indices, weights, grads, states)[source]

Overrides Optimizer.update_multi_precision().