Source code for mxnet.initializer

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Weight initializer."""

import re
import logging
import warnings
import json
from math import sqrt
import numpy as np
from .base import string_types
from .ndarray import NDArray, load
from . import random
from . import registry
from . import ndarray
from .util import is_np_array
from . import numpy as _mx_np  # pylint: disable=reimported


# inherit str for backward compatibility
class InitDesc(str):
    """Descriptor for the initialization pattern.

    Parameters
    ----------
    name : str
        Name of variable.
    attrs : dict of str to str
        Attributes of this variable taken from ``Symbol.attr_dict``.
    global_init : Initializer
        Global initializer to fall back to.
    """
    def __new__(cls, name, attrs=None, global_init=None):
        ret = super(InitDesc, cls).__new__(cls, name)
        ret.attrs = attrs or {}
        ret.global_init = global_init
        return ret


class Initializer(object):
    """The base class of an initializer."""
    def __init__(self, **kwargs):
        self._kwargs = kwargs
        self._verbose = False
        self._print_func = None

    def set_verbosity(self, verbose=False, print_func=None):
        """Switch on/off verbose mode

        Parameters
        ----------
        verbose : bool
            switch on/off verbose mode
        print_func : function
            A function that computes statistics of initialized arrays.
            Takes an `NDArray` and returns an `str`. Defaults to the
            root-mean-square value ``str((ndarray.norm(x)/sqrt(x.size)).asscalar())``.
        """
        self._verbose = verbose
        if print_func is None:
            def asum_stat(x):
                """returns norm(x)/sqrt(size(x))."""
                return str((ndarray.norm(x)/sqrt(x.size)).asscalar())
            print_func = asum_stat
        self._print_func = print_func
        return self

    def _verbose_print(self, desc, init, arr):
        """Internal verbose print function

        Parameters
        ----------
        desc : InitDesc or str
            name of the array
        init : str
            initializer pattern
        arr : NDArray
            initialized array
        """
        if self._verbose and self._print_func:
            logging.info('Initialized %s as %s: %s',
                         desc, init, self._print_func(arr))

    def dumps(self):
        """Saves the initializer to string

        Returns
        -------
        str
            JSON formatted string that describes the initializer.

        Examples
        --------
        >>> # Create initializer and retrieve its parameters
        ...
        >>> init = mx.init.Normal(0.5)
        >>> init.dumps()
        '["normal", {"sigma": 0.5}]'
        >>> init = mx.init.Xavier(factor_type="in", magnitude=2.34)
        >>> init.dumps()
        '["xavier", {"rnd_type": "uniform", "magnitude": 2.34, "factor_type": "in"}]'
        """
        return json.dumps([self.__class__.__name__.lower(), self._kwargs])

    def __call__(self, desc, arr):
        """Initialize an array

        Parameters
        ----------
        desc : InitDesc
            Initialization pattern descriptor.
        arr : NDArray
            The array to be initialized.
        """
        if not isinstance(desc, InitDesc):
            self._legacy_init(desc, arr)
            return

        if desc.global_init is None:
            desc.global_init = self
        init = desc.attrs.get('__init__', "")

        if init:
            # when calling Variable initializer
            create(init)._init_weight(desc, arr)
            self._verbose_print(desc, init, arr)
        else:
            # register nnvm::FSetInputVariableAttrs in the backend for new patterns
            # don't add new cases here.
            if desc.endswith('weight'):
                self._init_weight(desc, arr)
                self._verbose_print(desc, 'weight', arr)
            elif desc.endswith('bias'):
                self._init_bias(desc, arr)
                self._verbose_print(desc, 'bias', arr)
            elif desc.endswith('gamma'):
                self._init_gamma(desc, arr)
                self._verbose_print(desc, 'gamma', arr)
            elif desc.endswith('beta'):
                self._init_beta(desc, arr)
                self._verbose_print(desc, 'beta', arr)
            elif desc.endswith('min'):
                self._init_zero(desc, arr)
                self._verbose_print(desc, 'min', arr)
            elif desc.endswith('max'):
                self._init_one(desc, arr)
                self._verbose_print(desc, 'max', arr)
            elif desc.endswith('weight_quantize'):
                self._init_quantized_weight(desc, arr)
                self._verbose_print(desc, 'weight_quantize', arr)
            elif desc.endswith('bias_quantize'):
                self._init_quantized_bias(desc, arr)
                self._verbose_print(desc, 'bias_quantize', arr)
            else:
                self._init_default(desc, arr)

    def _legacy_init(self, name, arr):
        """Legacy initialization method.

        Parameters
        ----------
        name : str
            Name of corresponding NDArray.
        arr : NDArray
            NDArray to be initialized.
        """
        warnings.warn(
            "\033[91mCalling initializer with init(str, NDArray) has been deprecated. "
            "Please use init(mx.init.InitDesc(...), NDArray) instead.\033[0m",
            DeprecationWarning, stacklevel=3)
        if not isinstance(name, string_types):
            raise TypeError('name must be string')
        if not isinstance(arr, NDArray):
            raise TypeError('arr must be NDArray')
        if name.startswith('upsampling'):
            self._init_bilinear(name, arr)
        elif name.startswith('stn_loc') and name.endswith('weight'):
            self._init_zero(name, arr)
        elif name.startswith('stn_loc') and name.endswith('bias'):
            self._init_loc_bias(name, arr)
        elif name.endswith('bias'):
            self._init_bias(name, arr)
        elif name.endswith('gamma'):
            self._init_gamma(name, arr)
        elif name.endswith('beta'):
            self._init_beta(name, arr)
        elif name.endswith('weight'):
            self._init_weight(name, arr)
        elif name.endswith("moving_mean"):
            self._init_zero(name, arr)
        elif name.endswith("moving_var"):
            self._init_one(name, arr)
        elif name.endswith("moving_inv_var"):
            self._init_zero(name, arr)
        elif name.endswith("moving_avg"):
            self._init_zero(name, arr)
        elif name.endswith('min'):
            self._init_zero(name, arr)
        elif name.endswith('max'):
            self._init_one(name, arr)
        else:
            self._init_default(name, arr)

    def _init_bilinear(self, _, arr):
        weight = np.zeros(np.prod(arr.shape), dtype='float32')
        shape = arr.shape
        f = np.ceil(shape[3] / 2.)
        c = (2 * f - 1 - f % 2) / (2. * f)
        for i in range(np.prod(shape)):
            x = i % shape[3]
            y = (i // shape[3]) % shape[2]
            weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
        arr[:] = weight.reshape(shape)

    def _init_loc_bias(self, _, arr):
        shape = arr.shape
        assert(shape[0] == 6)
        arr[:] = np.array([1.0, 0, 0, 0, 1.0, 0])

    def _init_zero(self, _, arr):
        arr[:] = 0.0

    def _init_one(self, _, arr):
        arr[:] = 1.0

    def _init_bias(self, _, arr):
        arr[:] = 0.0

    def _init_quantized_bias(self, _, arr):
        arr[:] = 0

    def _init_gamma(self, _, arr):
        arr[:] = 1.0

    def _init_beta(self, _, arr):
        arr[:] = 0.0

    def _init_weight(self, name, arr):
        """Abstract method to initialize weight."""
        raise NotImplementedError("Must override it")

    def _init_quantized_weight(self, _, arr):
        _arr = random.randint(-127, 127, dtype='int32').asnumpy()
        arr[:] = np.int8(_arr)

    def _init_default(self, name, _):
        raise ValueError(
            f'Unknown initialization pattern for {name}. '
            'Default initialization is now limited to '
            '"weight", "bias", "gamma" (1.0), and "beta" (0.0). '
            'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern')

    def __eq__(self, other):
        if not isinstance(other, Initializer):
            return NotImplemented
        # pylint: disable=unidiomatic-typecheck
        return type(self) is type(other) and self._kwargs == other._kwargs


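# A minimal sketch (hypothetical names, not part of the original module) of how
# the name-suffix dispatch in Initializer.__call__ behaves for a subclass that
# only overrides _init_weight; biases still fall through to the base _init_bias:
#
# >>> class ConstantWeight(Initializer):
# ...     def _init_weight(self, _, arr):
# ...         arr[:] = 0.5
# >>> init = ConstantWeight()
# >>> w = mx.nd.empty((2, 2))
# >>> init(InitDesc('fc1_weight'), w)   # suffix 'weight' -> ConstantWeight._init_weight
# >>> b = mx.nd.empty((2,))
# >>> init(InitDesc('fc1_bias'), b)     # suffix 'bias' -> Initializer._init_bias, fills 0.0
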
# pylint: disable=invalid-name
_register = registry.get_register_func(Initializer, 'initializer')
alias = registry.get_alias_func(Initializer, 'initializer')
create = registry.get_create_func(Initializer, 'initializer')
# pylint: enable=invalid-name

def register(klass):
    """Registers a custom initializer.

    Custom initializers can be created by extending `mx.init.Initializer` and
    implementing the required functions like `_init_weight` and `_init_bias`.
    The created initializer must be registered using `mx.init.register` before
    it can be called by name.

    Parameters
    ----------
    klass : class
        A subclass of `mx.init.Initializer` that needs to be registered as a
        custom initializer.

    Example
    -------
    >>> # Create and register a custom initializer that
    ... # initializes weights to 0.1 and biases to 1.
    ...
    >>> @mx.init.register
    ... @alias('myinit')
    ... class CustomInit(mx.init.Initializer):
    ...     def __init__(self):
    ...         super(CustomInit, self).__init__()
    ...     def _init_weight(self, _, arr):
    ...         arr[:] = 0.1
    ...     def _init_bias(self, _, arr):
    ...         arr[:] = 1
    ...
    >>> # block is an instance of 'mxnet.gluon.Block'
    ...
    >>> block.initialize(CustomInit())
    """
    return _register(klass)


class Load(object):
    """Initializes variables by loading data from file or dict.

    **Note** Load will drop ``arg:`` or ``aux:`` from name and
    initialize the variables that match with the prefix dropped.

    Parameters
    ----------
    param: str or dict of str->`NDArray`
        Parameter file or dict mapping name to NDArray.
    default_init: Initializer
        Default initializer when name is not found in `param`.
    verbose: bool
        Flag for enabling logging of source when initializing.
    """
    def __init__(self, param, default_init=None, verbose=False):
        if isinstance(param, str):
            param = load(param)
        assert isinstance(param, dict)
        self.param = {}
        for name, arr in param.items():
            if name.startswith('arg:') or name.startswith('aux:'):
                self.param[name[4:]] = arr
            else:
                self.param[name] = arr
        self.default_init = default_init
        self.verbose = verbose

    def __call__(self, name, arr):
        if name in self.param:
            assert arr.shape == self.param[name].shape, \
                f'Parameter {name} cannot be initialized from loading. ' + \
                f'Shape mismatch, target {str(arr.shape)} vs loaded {self.param[name].shape}'
            arr[:] = self.param[name]
            if self.verbose:
                logging.info('Initialized %s by loading', name)
        else:
            assert self.default_init is not None, \
                f"Cannot initialize {name}. Not found in loaded param " + \
                "and no default Initializer is provided."
            self.default_init(name, arr)
            if self.verbose:
                logging.info('Initialized %s by default', name)


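# Illustrative sketch of Load with an in-memory dict (names are hypothetical).
# Keys carrying the 'arg:'/'aux:' prefixes used by the legacy checkpoint format
# are stripped before matching; anything missing falls back to default_init:
#
# >>> params = {'arg:fc1_weight': mx.nd.ones((2, 2)), 'aux:bn_moving_mean': mx.nd.zeros((2,))}
# >>> loader = Load(params, default_init=Uniform(0.07))
# >>> w = mx.nd.empty((2, 2))
# >>> loader('fc1_weight', w)   # found under the stripped key, copied from the dict
# >>> b = mx.nd.empty((4,))
# >>> loader('fc1_bias', b)     # not in the dict, initialized by default_init
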
class Mixed(object):
    """Initialize parameters using multiple initializers.

    Parameters
    ----------
    patterns: list of str
        List of regular expressions matching parameter names.
    initializers: list of Initializer
        List of initializers corresponding to `patterns`.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize biases to zero
    ... # and every other parameter to random values with uniform distribution.
    ...
    >>> init = mx.initializer.Mixed(['bias', '.*'], [mx.init.Zero(), mx.init.Uniform(0.1)])
    >>> block.initialize(init)
    >>>
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected1_weight
    [[ 0.0097627   0.01856892  0.04303787]]
    fullyconnected1_bias
    [ 0.]
    """
    def __init__(self, patterns, initializers):
        assert len(patterns) == len(initializers)
        self.map = list(zip([re.compile(p) for p in patterns], initializers))

    def __call__(self, name, arr):
        for prog, init in self.map:
            if prog.match(name):
                init(name, arr)
                return
        raise ValueError(f'Parameter name {name} did not match any pattern. Consider '
                         'adding a ".*" pattern at the end with a default Initializer.')


@register
@alias("zeros")
class Zero(Initializer):
    """Initializes weights to zero.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize weights to zero.
    ...
    >>> init = mx.initializer.Zero()
    >>> module.initialize(init)
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected0_weight
    [[ 0.  0.  0.]]
    """
    def __init__(self):
        super(Zero, self).__init__()

    def _init_weight(self, _, arr):
        arr[:] = 0


@register
@alias("ones")
class One(Initializer):
    """Initializes weights to one.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize weights to one.
    ...
    >>> init = mx.initializer.One()
    >>> module.initialize(init)
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected0_weight
    [[ 1.  1.  1.]]
    """
    def __init__(self):
        super(One, self).__init__()

    def _init_weight(self, _, arr):
        arr[:] = 1


@register
class Constant(Initializer):
    """Initializes the weights to a given value.

    The value passed in can be a scalar or an NDArray that matches the shape
    of the parameter to be set.

    Parameters
    ----------
    value : float, NDArray
        Value to set.
    """
    def __init__(self, value):
        super(Constant, self).__init__(value=value)
        self.value = value

    def _init_weight(self, _, arr):
        arr[:] = self.value

    def dumps(self):
        val = self._kwargs['value']
        if not np.isscalar(val):
            self._kwargs['value'] = val.tolist() if isinstance(val, np.ndarray) else val.asnumpy().tolist()
        return json.dumps([self.__class__.__name__.lower(), self._kwargs])


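# Round-trip sketch of the dumps() override above (outputs illustrative):
# an array-valued 'value' is converted to a plain list so the result stays
# JSON-serializable.
#
# >>> mx.init.Constant(0.3).dumps()
# '["constant", {"value": 0.3}]'
# >>> mx.init.Constant(mx.nd.array([1, 2])).dumps()
# '["constant", {"value": [1.0, 2.0]}]'
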
@register
class Uniform(Initializer):
    """Initializes weights with random values uniformly sampled from a given range.

    Parameters
    ----------
    scale : float, optional
        The bound on the range of the generated random values.
        Values are generated from the range [-`scale`, `scale`].
        Default scale is 0.07.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize weights
    >>> # to random values uniformly sampled between -0.1 and 0.1.
    ...
    >>> init = mx.init.Uniform(0.1)
    >>> module.initialize(init)
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected0_weight
    [[ 0.01360891 -0.02144304  0.08511933]]
    """
    def __init__(self, scale=0.07):
        super(Uniform, self).__init__(scale=scale)
        self.scale = scale

    def _init_weight(self, _, arr):
        uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
        uniform_fn(-self.scale, self.scale, arr.shape, dtype=arr.dtype, out=arr)


@register
class Normal(Initializer):
    """Initializes weights with random values sampled from a normal distribution
    with a mean of zero and standard deviation of `sigma`.

    Parameters
    ----------
    sigma : float, optional
        Standard deviation of the normal distribution.
        Default standard deviation is 0.01.

    Example
    -------
    >>> # Given 'block', an instance of 'mxnet.gluon.Block', initialize weights
    >>> # to random values sampled from a normal distribution.
    ...
    >>> init = mx.init.Normal(0.5)
    >>> module.initialize(init)
    >>> for dictionary in module.get_params():
    ...     for key in dictionary:
    ...         print(key)
    ...         print(dictionary[key].asnumpy())
    ...
    fullyconnected0_weight
    [[-0.3214761  -0.12660924  0.53789419]]
    """
    def __init__(self, sigma=0.01):
        super(Normal, self).__init__(sigma=sigma)
        self.sigma = sigma

    def _init_weight(self, _, arr):
        normal_fn = _mx_np.random.normal if is_np_array() else random.normal
        normal_fn(0, self.sigma, arr.shape, dtype=arr.dtype, out=arr)


@register
class Orthogonal(Initializer):
    """Initialize weight as orthogonal matrix.

    This initializer implements *Exact solutions to the nonlinear dynamics of
    learning in deep linear neural networks*, available at
    https://arxiv.org/abs/1312.6120.

    Parameters
    ----------
    scale : float, optional
        Scaling factor of weight.
    rand_type : str, optional
        Use "uniform" or "normal" random number to initialize weight.
    """
    def __init__(self, scale=1.414, rand_type="uniform"):
        super(Orthogonal, self).__init__(scale=scale, rand_type=rand_type)
        self.scale = scale
        self.rand_type = rand_type

    def _init_weight(self, _, arr):
        nout = arr.shape[0]
        nin = np.prod(arr.shape[1:])
        if self.rand_type == "uniform":
            tmp = random.uniform(-1.0, 1.0, shape=(nout, nin)).asnumpy()
        elif self.rand_type == "normal":
            tmp = random.normal(0.0, 1.0, shape=(nout, nin)).asnumpy()
        u, _, v = np.linalg.svd(tmp, full_matrices=False)  # pylint: disable=invalid-name
        if u.shape == tmp.shape:
            res = u
        else:
            res = v
        res = self.scale * res.reshape(arr.shape)
        arr[:] = res


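# Quick orthogonality check (illustrative sketch): with scale=1.0 and a square
# weight, the initialized matrix W satisfies W.dot(W.T) ~= I up to float error.
#
# >>> init = mx.init.Orthogonal(scale=1.0)
# >>> w = mx.nd.empty((4, 4))
# >>> init(mx.init.InitDesc('fc_weight'), w)
# >>> m = w.asnumpy()
# >>> np.allclose(m.dot(m.T), np.eye(4), atol=1e-5)
# True
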
@register
class Xavier(Initializer):
    """Returns an initializer performing "Xavier" initialization for weights.

    This initializer is designed to keep the scale of gradients roughly the same
    in all layers.

    With the default `rnd_type` ``'uniform'`` and `factor_type` ``'avg'``,
    the initializer fills the weights with random numbers in the range
    of :math:`[-c, c]`, where :math:`c = \\sqrt{\\frac{3.}{0.5 * (n_{in} + n_{out})}}`.
    :math:`n_{in}` is the number of neurons feeding into weights, and :math:`n_{out}` is
    the number of neurons the result is fed to.

    If `rnd_type` is ``'uniform'`` and `factor_type` is ``'in'``,
    then :math:`c = \\sqrt{\\frac{3.}{n_{in}}}`.
    Similarly when `factor_type` is ``'out'``, :math:`c = \\sqrt{\\frac{3.}{n_{out}}}`.

    If `rnd_type` is ``'gaussian'`` and `factor_type` is ``'avg'``, the initializer
    fills the weights with numbers from a normal distribution with a standard deviation
    of :math:`\\sqrt{\\frac{3.}{0.5 * (n_{in} + n_{out})}}`.

    Parameters
    ----------
    rnd_type: str, optional
        Random generator type, can be ``'gaussian'`` or ``'uniform'``.
    factor_type: str, optional
        Can be ``'avg'``, ``'in'``, or ``'out'``.
    magnitude: float, optional
        Scale of random number.
    """
    def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
        super(Xavier, self).__init__(rnd_type=rnd_type,
                                     factor_type=factor_type,
                                     magnitude=magnitude)
        self.rnd_type = rnd_type
        self.factor_type = factor_type
        self.magnitude = float(magnitude)

    def _init_weight(self, name, arr):
        shape = arr.shape
        hw_scale = 1.
        if len(shape) < 2:
            raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at'
                             ' least 2D.'.format(name))
        if len(shape) > 2:
            hw_scale = np.prod(shape[2:])
        fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale
        factor = 1.
        if self.factor_type == "avg":
            factor = (fan_in + fan_out) / 2.0
        elif self.factor_type == "in":
            factor = fan_in
        elif self.factor_type == "out":
            factor = fan_out
        else:
            raise ValueError("Incorrect factor type")
        scale = np.sqrt(self.magnitude / factor)
        if self.rnd_type == "uniform":
            uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
            uniform_fn(-scale, scale, arr.shape, dtype=arr.dtype, out=arr)
        elif self.rnd_type == "gaussian":
            normal_fn = _mx_np.random.normal if is_np_array() else random.normal
            normal_fn(0, scale, arr.shape, dtype=arr.dtype, out=arr)
        else:
            raise ValueError("Unknown random type")


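# Worked example of the scale computation above: for a convolution weight of
# shape (64, 32, 3, 3) with the defaults rnd_type='uniform', factor_type='avg',
# magnitude=3:
#
#   hw_scale = 3 * 3 = 9
#   fan_in   = 32 * 9 = 288,  fan_out = 64 * 9 = 576
#   factor   = (288 + 576) / 2 = 432
#   scale    = sqrt(3 / 432) = 1 / 12 ~= 0.0833
#
# so the weights are drawn uniformly from [-0.0833, 0.0833].
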
@register
class MSRAPrelu(Xavier):
    """Initialize the weight according to a MSRA paper.

    This initializer implements *Delving Deep into Rectifiers: Surpassing
    Human-Level Performance on ImageNet Classification*, available at
    https://arxiv.org/abs/1502.01852.

    This initializer is proposed for initialization related to ReLU activations;
    it makes some changes on top of the Xavier method.

    Parameters
    ----------
    factor_type: str, optional
        Can be ``'avg'``, ``'in'``, or ``'out'``.
    slope: float, optional
        initial slope of any PReLU (or similar) nonlinearities.
    """
    def __init__(self, factor_type="avg", slope=0.25):
        magnitude = 2. / (1 + slope ** 2)
        super(MSRAPrelu, self).__init__("gaussian", factor_type, magnitude)
        self._kwargs = {'factor_type': factor_type, 'slope': slope}


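# Numerically, the default slope=0.25 gives magnitude = 2 / (1 + 0.25**2)
# ~= 1.882, and with rnd_type='gaussian', factor_type='avg' the standard
# deviation becomes sqrt(1.882 / ((fan_in + fan_out) / 2)). With slope=0 and
# factor_type='in' this reduces to the familiar sqrt(2 / fan_in) rule from
# the paper above.
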
@register
class Bilinear(Initializer):
    """Initialize weight for upsampling layers."""
    def __init__(self):
        super(Bilinear, self).__init__()

    def _init_weight(self, _, arr):
        weight = np.zeros(np.prod(arr.shape), dtype='float32')
        shape = arr.shape
        f = np.ceil(shape[3] / 2.)
        c = (2 * f - 1 - f % 2) / (2. * f)
        for i in range(np.prod(shape)):
            x = i % shape[3]
            y = (i // shape[3]) % shape[2]
            weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
        arr[:] = weight.reshape(shape)


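# Worked example of the kernel built above for a (1, 1, 4, 4) weight:
#   f = ceil(4 / 2) = 2,  c = (2*2 - 1 - 2%2) / (2*2) = 0.75
#   1-D profile over x in {0, 1, 2, 3}: 1 - |x/2 - 0.75| = [0.25, 0.75, 0.75, 0.25]
# and the 4x4 kernel is the outer product of that profile with itself. Note
# that f and c are derived from shape[3] only, so a square kernel is assumed.
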
@register
class LSTMBias(Initializer):
    """Initialize all biases of an LSTMCell to 0.0 except for
    the forget gate whose bias is set to a custom value.

    Parameters
    ----------
    forget_bias: float, default 1.0
        bias for the forget gate. Jozefowicz et al. 2015 recommends
        setting this to 1.0.
    """
    def __init__(self, forget_bias=1.0):
        super(LSTMBias, self).__init__(forget_bias=forget_bias)
        self.forget_bias = forget_bias

    def _init_weight(self, name, arr):
        arr[:] = 0.0
        # in the case of LSTMCell the forget gate is the second
        # gate of the 4 LSTM gates, we modify the corresponding values.
        num_hidden = int(arr.shape[0] / 4)
        arr[num_hidden:2*num_hidden] = self.forget_bias


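# Layout sketch: the fused LSTM bias has shape (4 * num_hidden,), and the
# forget gate occupies the second block of num_hidden entries, so for
# num_hidden=2 and the default forget_bias=1.0 the initialized array reads
# [0, 0, 1, 1, 0, 0, 0, 0].
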
@register
class RNNFused(Initializer):
    """Initialize RNN fused parameter with bias part initialized to 0.0 and
    weight initialized with random values uniformly sampled from a given range.

    Parameters
    ----------
    mode : {'gru', 'lstm', 'rnn_relu', 'rnn_tanh'}, required
        the type of RNN to compute
    num_layers : int (non-negative), required
        number of stacked layers
    state_size : int (non-negative), required
        size of the state for each layer
    bidirectional : boolean, optional, default=0
        whether to use bidirectional recurrent layers
    projection_size : int or None, optional, default='None'
        size of the projection
    scale : float, optional
        The bound on the range of the generated random values for weights.
        Values are generated from the range [-`scale`, `scale`].
        Default scale is 0.07.
    """
    def __init__(self, mode, num_layers, state_size, bidirectional=False,
                 projection_size=None, i2h_weight_initializer=None,
                 h2h_weight_initializer=None, i2h_bias_initializer=None,
                 h2h_bias_initializer=None, h2r_weight_initializer=None):
        super(RNNFused, self).__init__(mode=mode, num_layers=num_layers,
                                       state_size=state_size,
                                       bidirectional=bidirectional,
                                       projection_size=projection_size,
                                       i2h_weight_initializer=i2h_weight_initializer,
                                       h2h_weight_initializer=h2h_weight_initializer,
                                       i2h_bias_initializer=i2h_bias_initializer,
                                       h2h_bias_initializer=h2h_bias_initializer,
                                       h2r_weight_initializer=h2r_weight_initializer)
        self.gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
        self.num_layers = num_layers
        self.num_hidden = state_size
        self.dir = 2 if bidirectional else 1
        self.projection_size = projection_size
        self._i2h_weight_initializer = i2h_weight_initializer
        self._h2h_weight_initializer = h2h_weight_initializer
        self._i2h_bias_initializer = i2h_bias_initializer
        self._h2h_bias_initializer = h2h_bias_initializer
        self._h2r_weight_initializer = h2r_weight_initializer

    # pylint: disable=too-many-nested-blocks
    def _init_weight(self, name, arr):
        arr_len = arr.shape[0]
        size = self.num_hidden * self.dir * self.gates
        if not self.projection_size:
            # second layer size
            size2 = (self.num_hidden * self.dir + self.num_hidden + 2) * size
            input_size = (arr_len - (self.num_layers - 1) * size2) // \
                size - 2 - self.num_hidden
        else:
            # second layer size
            size2 = (self.projection_size * self.dir + self.projection_size + 2) * size
            size_projection = self.projection_size * self.num_hidden * self.num_layers * self.dir
            input_size = (arr_len - size_projection - (self.num_layers - 1) * size2) // \
                size - 2 - self.projection_size
        begin = 0
        if not self.projection_size:
            for param in ['weight', 'bias']:
                for layer_num in range(self.num_layers):
                    for _ in range(self.dir):
                        for connect in ['i2h', 'h2h']:
                            num_inputs = input_size
                            if layer_num != 0:
                                num_inputs = self.num_hidden * self.dir
                            if connect == 'h2h':
                                num_inputs = self.num_hidden
                            shape0 = self.gates * self.num_hidden
                            if param == 'weight':
                                cur_len = shape0 * num_inputs
                            else:
                                cur_len = shape0
                            self._init_util(param, connect, arr[begin:begin+cur_len])
                            begin += cur_len
        else:
            for param in ['weight', 'bias']:
                for layer_num in range(self.num_layers):
                    for _ in range(self.dir):
                        for connect in ['i2h', 'h2h', 'h2r']:
                            if connect != 'h2r' or param != 'bias':
                                if connect == 'h2r':
                                    cur_len = self.projection_size * self.num_hidden
                                else:
                                    num_inputs = input_size
                                    if layer_num != 0:
                                        num_inputs = self.projection_size * self.dir
                                    if connect == 'h2h':
                                        num_inputs = self.projection_size
                                    shape0 = self.gates * self.num_hidden
                                    if param == 'weight':
                                        cur_len = shape0 * num_inputs
                                    else:
                                        cur_len = shape0
                                self._init_util(param, connect, arr[begin:begin+cur_len])
                                begin += cur_len

    def _init_util(self, param, connect, arr):
        name = "_{}_{}_initializer".format(connect, param)
        init = getattr(self, name)
        create(init)(InitDesc(name, {'__init__': init}), arr)

    def set_initializer(self, init):
        self._i2h_weight_initializer = \
            init if not self._i2h_weight_initializer else 'uniform'
        self._h2h_weight_initializer = \
            init if not self._h2h_weight_initializer else 'uniform'
        self._i2h_bias_initializer = \
            init if not self._i2h_bias_initializer else 'zero'
        self._h2h_bias_initializer = \
            init if not self._h2h_bias_initializer else 'zero'
        self._h2r_weight_initializer = \
            init if not self._h2r_weight_initializer else 'uniform'
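
# Worked size example for the no-projection branch of _init_weight above
# (illustrative numbers): a single-layer, unidirectional LSTM with
# state_size=8 and input_size=4 has gates=4, dir=1 and size = 8 * 1 * 4 = 32,
# so the fused parameter holds
#   i2h weight 32*4 + h2h weight 32*8 + i2h bias 32 + h2h bias 32
#   = 32 * (4 + 8 + 2) = 448
# elements, laid out as all weights first and then all biases; the code
# recovers input_size as 448 // 32 - 2 - 8 = 4.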