Source code for mxnet.optimizer.lars

# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""LARS optimizer."""
from __future__ import absolute_import
import numpy
from ..ndarray import (zeros, clip, array,
                       multi_sum_sq, multi_lars,
                       norm as NDnorm,
                       where, ones_like)
from ..ndarray import (sgd_update, sgd_mom_update,
                       mp_sgd_update, mp_sgd_mom_update,
                       preloaded_multi_sgd_update, preloaded_multi_sgd_mom_update,
                       preloaded_multi_mp_sgd_update, preloaded_multi_mp_sgd_mom_update)
from .optimizer import Optimizer, register
from .utils import _flatten_list

__all__ = ['LARS']


@register
class LARS(Optimizer):
    """The LARS optimizer from 'Large Batch Training of Convolutional Networks'
    (https://arxiv.org/abs/1708.03888).

    Behaves mostly like SGD with momentum and weight decay, but scales the
    learning rate adaptively for each layer:

    .. code-block::

        w_norm = L2norm(weights)
        g_norm = L2norm(gradients)
        if w_norm > 0 and g_norm > 0:
            lr_layer = lr * w_norm / (g_norm + weight_decay * w_norm + epsilon)
        else:
            lr_layer = lr

    Parameters
    ----------
    learning_rate : float, default 0.1
        The initial learning rate. If None, the optimization will use the
        learning rate from ``lr_scheduler``. If not None, it will overwrite
        the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
        is also None, then it will be set to 0.01 by default.
    momentum : float, default 0.
        The momentum value.
    eta : float, default 0.001
        LARS coefficient used to scale the learning rate.
    epsilon : float, default 1e-8
        Small value to avoid division by 0.
    lazy_update : bool, default False
        If True, lazy updates are applied if the storage types of weight and
        grad are both ``row_sparse``.
    aggregate_num : int, default 1
        Number of weights to be aggregated in a list.
        They are passed to the optimizer for a single optimization step.
    use_fused_step : bool, default True
        Whether or not to use fused kernels for the optimizer.
        When use_fused_step=False, step is called,
        otherwise, fused_step is called.
    """
    def __init__(self, learning_rate=0.1, momentum=0.0, eta=0.001, epsilon=1e-8,
                 lazy_update=False, use_fused_step=True, aggregate_num=1, **kwargs):
        super(LARS, self).__init__(learning_rate=learning_rate,
                                   use_fused_step=use_fused_step,
                                   aggregate_num=aggregate_num,
                                   **kwargs)
        if not self.use_fused_step:
            assert not lazy_update, \
                'When use_fused_step is set to False, lazy_update has to be turned off.'
        if lazy_update:
            assert not self.multi_precision, \
                'When lazy_update is set to True, multi_precision has to be turned off.'
        self.lazy_update = lazy_update
        self.momentum = momentum
        self.eta = eta
        self.epsilon = epsilon

    def create_state(self, index, weight):
        momentum = None
        if self.momentum != 0.0:
            stype = weight.stype if self.lazy_update else 'default'
            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
        return momentum

    def _l2norm(self, v, rescale=False):
        """L2 Norm implementation"""
        v = v.astype('float32')
        if rescale:
            v *= self.rescale_grad
        norm = NDnorm(v)
        return norm

    def _get_lars(self, index, weight, grad, wd):
        """Returns a scaling factor for the learning rate for this layer"""
        lars = 1.0
        name = self.idx2name[index] if index in self.idx2name else str(index)
        if name.endswith('gamma') or name.endswith('beta') or name.endswith('bias'):
            return lars
        w_norm = self._l2norm(weight)
        g_norm = self._l2norm(grad, rescale=True)
        # calculate lars_trust_ratio
        ratio = w_norm / g_norm
        # becomes NaN if ratio == NaN or 0, otherwise 0
        nan_or_zero = 1 - ratio / ratio
        lars = self.eta * w_norm / (g_norm + wd * w_norm + self.epsilon)
        lars = where(nan_or_zero, ones_like(lars), lars)
        return lars.asscalar()
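
    # Note on the guard in _get_lars above: ``ratio / ratio`` is 1 for any finite,
    # non-zero ratio, so ``nan_or_zero`` is 0 in the normal case; if either norm is
    # zero (or NaN), the division yields NaN and ``where`` falls back to a scaling
    # factor of 1, matching the ``else: lr_layer = lr`` branch in the class docstring.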

    def step(self, indices, weights, grads, states):
        """Perform an optimization step using gradients and states.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters into the individual learning rates
            and weight decays. Learning rates and weight decay may be set via
            `set_lr_mult()` and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to the parameters.
        states : List of any obj
            List of state returned by `create_state()`.
        """
        for index, weight, grad, state in zip(indices, weights, grads, states):
            self._update_count(index)
            lr = self._get_lr(index)
            wd = self._get_wd(index)

            # compute lars
            # clip grad + wd * weight is performed after computing lars
            lars = self._get_lars(index, weight, grad, wd)
            lr *= lars

            # preprocess grad
            grad *= self.rescale_grad
            if self.clip_gradient is not None:
                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
            grad += wd * weight

            # update mom
            mom = state
            if mom is not None:
                mom[:] *= self.momentum
                mom[:] -= lr * grad
            else:
                mom = -lr * grad

            # update weight
            weight[:] += mom

    def fused_step(self, indices, weights, grads, states):
        """Perform a fused optimization step using gradients and states.
        Fused kernel is used for update.

        Parameters
        ----------
        indices : list of int
            List of unique indices of the parameters into the individual learning rates
            and weight decays. Learning rates and weight decay may be set via
            `set_lr_mult()` and `set_wd_mult()`, respectively.
        weights : list of NDArray
            List of parameters to be updated.
        grads : list of NDArray
            List of gradients of the objective with respect to the parameters.
        states : List of any obj
            List of state returned by `create_state()`.
        """
        aggregate = self.aggregate_num > 1
        for weight, grad in zip(weights, grads):
            aggregate = (aggregate and
                         weight.stype == 'default' and
                         grad.stype == 'default')
        self._update_count(indices)
        lrs = self._get_lrs(indices)
        wds = self._get_wds(indices)

        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = self.momentum
        if self.clip_gradient is not None:
            kwargs['clip_gradient'] = self.clip_gradient

        if aggregate:
            nb_params = len(indices)
            names = [self.idx2name[i] if i in self.idx2name else str(i) for i in indices]
            lars_idx = [i for i in range(nb_params)
                        if not (names[i].endswith('gamma') or names[i].endswith('beta') or
                                names[i].endswith('bias'))]
            nb_lars = len(lars_idx)
            no_lars_idx = [i for i in range(nb_params)
                           if (names[i].endswith('gamma') or names[i].endswith('beta') or
                               names[i].endswith('bias'))]
            cur_ctx = weights[0].context
            full_idx = lars_idx + no_lars_idx
            new_lrs = array([lrs[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
            new_wds = array([wds[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
            new_weights = [weights[i] for i in full_idx]
            new_grads = [grads[i] for i in full_idx]
            new_states = [states[i] for i in full_idx]
            if nb_lars > 0:
                w_sum_sq = multi_sum_sq(*new_weights[:nb_lars], num_arrays=nb_lars)
                g_sum_sq = multi_sum_sq(*new_grads[:nb_lars], num_arrays=nb_lars)
                multi_lars(new_lrs[:nb_lars], w_sum_sq, g_sum_sq, new_wds[:nb_lars],
                           eta=self.eta, eps=self.epsilon, rescale_grad=self.rescale_grad,
                           out=new_lrs[:nb_lars])
            # Same as the usual SGD update, but using the preloaded sgd functions
            multi_precision = self.multi_precision and weights[0].dtype == numpy.float16
            if not multi_precision:
                if self.momentum > 0:
                    preloaded_multi_sgd_mom_update(
                        *(_flatten_list(zip(new_weights, new_grads, new_states)) +
                          [new_lrs, new_wds]), out=new_weights,
                        num_weights=len(new_weights), **kwargs)
                else:
                    preloaded_multi_sgd_update(
                        *(_flatten_list(zip(new_weights, new_grads)) +
                          [new_lrs, new_wds]), out=new_weights,
                        num_weights=len(new_weights), **kwargs)
            else:
                # use the reordered states so weights32/moms stay aligned with new_weights
                new_states = list(zip(*new_states))
                weights32, moms = new_states
                if self.momentum > 0:
                    preloaded_multi_mp_sgd_mom_update(
                        *(_flatten_list(zip(new_weights, new_grads, moms, weights32)) +
                          [new_lrs, new_wds]), out=new_weights,
                        num_weights=len(new_weights), **kwargs)
                else:
                    preloaded_multi_mp_sgd_update(
                        *(_flatten_list(zip(new_weights, new_grads, weights32)) +
                          [new_lrs, new_wds]), out=new_weights,
                        num_weights=len(new_weights), **kwargs)
        else:
            for i, (index, weight, grad, state) in enumerate(zip(indices, weights,
                                                                 grads, states)):
                wd = wds[i]
                lr = lrs[i]
                lr *= self._get_lars(index, weight, grad, wd)
                multi_precision = self.multi_precision and weights[0].dtype == numpy.float16
                if not multi_precision:
                    mom = state
                    if state is not None:
                        sgd_mom_update(weight, grad, mom, out=weight,
                                       lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
                    else:
                        sgd_update(weight, grad, out=weight,
                                   lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
                else:
                    weight32, mom = state
                    if mom is not None:
                        mp_sgd_mom_update(weight, grad, mom, weight32, out=weight,
                                          lr=lr, wd=wd, **kwargs)
                    else:
                        mp_sgd_update(weight, grad, weight32, out=weight,
                                      lr=lr, wd=wd, **kwargs)
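
    # Note on the aggregated path in fused_step above: LARS-eligible parameters are
    # grouped first (full_idx), multi_sum_sq computes their squared L2 norms in one
    # kernel call, multi_lars rewrites new_lrs[:nb_lars] in place with the per-layer
    # learning rates, and the preloaded_multi_*sgd* kernels then apply the update for
    # the whole group using those preloaded learning rates and weight decays.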

    def update_multi_precision(self, indices, weights, grads, states):
        """Override update_multi_precision.
        """
        if self.use_fused_step:
            self.update(indices, weights, grads, states)
        else:
            super(LARS, self).update_multi_precision(indices, weights, grads, states)
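

# The snippet below is an illustrative usage sketch, not part of the original module.
# It drives the optimizer registered above (available under the name 'lars') through
# gluon.Trainer; the toy layer size, batch shape and hyperparameters are arbitrary.
if __name__ == '__main__':
    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(2)                     # toy single-layer model
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'lars',
                            {'learning_rate': 0.1, 'momentum': 0.9, 'eta': 0.001})
    data = mx.nd.random.uniform(shape=(4, 8))   # dummy batch of 4 samples
    label = mx.nd.random.uniform(shape=(4, 2))
    loss_fn = gluon.loss.L2Loss()
    with autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    trainer.step(batch_size=4)                  # one LARS-scaled SGD-with-momentum step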