gluon.nn

Gluon provides a large number of built-in neural network layers in the following two modules:

mxnet.gluon.nn

Neural network layers.

mxnet.gluon.contrib.nn

Contributed neural network modules.

We group all layers in these two modules according to their categories.

Sequential containers

nn.Sequential

Stacks Blocks sequentially.

nn.HybridSequential

Stacks HybridBlocks sequentially.

Basic Layers

nn.Dense

Just your regular densely-connected NN layer.

nn.Activation

Applies an activation function to input.

nn.Dropout

Applies Dropout to the input.

nn.Flatten

Flattens the input to two dimensional.

nn.Lambda

Wraps an operator or an expression as a Block object.

nn.HybridLambda

Wraps an operator or an expression as a HybridBlock object.

Convolutional Layers

nn.Conv1D

1D convolution layer (e.g. temporal convolution).

nn.Conv2D

2D convolution layer (e.g. spatial convolution over images).

nn.Conv3D

3D convolution layer (e.g. spatial convolution over volumes).

nn.Conv1DTranspose

Transposed 1D convolution layer (sometimes called Deconvolution).

nn.Conv2DTranspose

Transposed 2D convolution layer (sometimes called Deconvolution).

nn.Conv3DTranspose

Transposed 3D convolution layer (sometimes called Deconvolution).

Pooling Layers

nn.MaxPool1D

Max pooling operation for one dimensional data.

nn.MaxPool2D

Max pooling operation for two dimensional (spatial) data.

nn.MaxPool3D

Max pooling operation for 3D data (spatial or spatio-temporal).

nn.AvgPool1D

Average pooling operation for temporal data.

nn.AvgPool2D

Average pooling operation for spatial data.

nn.AvgPool3D

Average pooling operation for 3D data (spatial or spatio-temporal).

nn.GlobalMaxPool1D

Global max pooling operation for one dimensional (temporal) data.

nn.GlobalMaxPool2D

Global max pooling operation for two dimensional (spatial) data.

nn.GlobalMaxPool3D

Global max pooling operation for 3D data (spatial or spatio-temporal).

nn.GlobalAvgPool1D

Global average pooling operation for temporal data.

nn.GlobalAvgPool2D

Global average pooling operation for spatial data.

nn.GlobalAvgPool3D

Global average pooling operation for 3D data (spatial or spatio-temporal).

nn.ReflectionPad2D

Pads the input tensor using the reflection of the input boundary.

Normalization Layers

nn.BatchNorm

Batch normalization layer (Ioffe and Szegedy, 2014).

nn.InstanceNorm

Applies instance normalization to the n-dimensional input array.

nn.LayerNorm

Applies layer normalization to the n-dimensional input array.

Embedding Layers

nn.Embedding

Turns non-negative integers (indexes/tokens) into dense vectors of fixed size.

Advanced Activation Layers

nn.LeakyReLU

Leaky version of a Rectified Linear Unit.

nn.PReLU

Parametric leaky version of a Rectified Linear Unit.

nn.ELU

Exponential Linear Unit (ELU)

nn.SELU

Scaled Exponential Linear Unit (SELU)

nn.Swish

Swish Activation function

API Reference

Neural network layers.

Classes

Activation(activation, **kwargs)

Applies an activation function to input.

AvgPool1D([pool_size, strides, padding, …])

Average pooling operation for temporal data.

AvgPool2D([pool_size, strides, padding, …])

Average pooling operation for spatial data.

AvgPool3D([pool_size, strides, padding, …])

Average pooling operation for 3D data (spatial or spatio-temporal).

BatchNorm([axis, momentum, epsilon, center, …])

Batch normalization layer (Ioffe and Szegedy, 2014).

Block([prefix, params])

Base class for all neural network layers and models.

Conv1D(channels, kernel_size[, strides, …])

1D convolution layer (e.g. temporal convolution).

Conv1DTranspose(channels, kernel_size[, …])

Transposed 1D convolution layer (sometimes called Deconvolution).

Conv2D(channels, kernel_size[, strides, …])

2D convolution layer (e.g. spatial convolution over images).

Conv2DTranspose(channels, kernel_size[, …])

Transposed 2D convolution layer (sometimes called Deconvolution).

Conv3D(channels, kernel_size[, strides, …])

3D convolution layer (e.g. spatial convolution over volumes).

Conv3DTranspose(channels, kernel_size[, …])

Transposed 3D convolution layer (sometimes called Deconvolution).

Dense(units[, activation, use_bias, …])

Just your regular densely-connected NN layer.

Dropout(rate[, axes])

Applies Dropout to the input.

ELU([alpha])

Exponential Linear Unit (ELU)

Embedding(input_dim, output_dim[, dtype, …])

Turns non-negative integers (indexes/tokens) into dense vectors of fixed size.

Flatten(**kwargs)

Flattens the input to two dimensional.

GELU(**kwargs)

Gaussian Error Linear Unit (GELU)

GlobalAvgPool1D([layout])

Global average pooling operation for temporal data.

GlobalAvgPool2D([layout])

Global average pooling operation for spatial data.

GlobalAvgPool3D([layout])

Global average pooling operation for 3D data (spatial or spatio-temporal).

GlobalMaxPool1D([layout])

Global max pooling operation for one dimensional (temporal) data.

GlobalMaxPool2D([layout])

Global max pooling operation for two dimensional (spatial) data.

GlobalMaxPool3D([layout])

Global max pooling operation for 3D data (spatial or spatio-temporal).

GroupNorm([num_groups, epsilon, center, …])

Applies group normalization to the n-dimensional input array.

HybridBlock([prefix, params])

HybridBlock supports forwarding with both Symbol and NDArray.

HybridLambda(function[, prefix])

Wraps an operator or an expression as a HybridBlock object.

HybridSequential([prefix, params])

Stacks HybridBlocks sequentially.

InstanceNorm([axis, epsilon, center, scale, …])

Applies instance normalization to the n-dimensional input array.

Lambda(function[, prefix])

Wraps an operator or an expression as a Block object.

LayerNorm([axis, epsilon, center, scale, …])

Applies layer normalization to the n-dimensional input array.

LeakyReLU(alpha, **kwargs)

Leaky version of a Rectified Linear Unit.

MaxPool1D([pool_size, strides, padding, …])

Max pooling operation for one dimensional data.

MaxPool2D([pool_size, strides, padding, …])

Max pooling operation for two dimensional (spatial) data.

MaxPool3D([pool_size, strides, padding, …])

Max pooling operation for 3D data (spatial or spatio-temporal).

PReLU([alpha_initializer, in_channels])

Parametric leaky version of a Rectified Linear Unit.

ReflectionPad2D([padding])

Pads the input tensor using the reflection of the input boundary.

SELU(**kwargs)

Scaled Exponential Linear Unit (SELU)

Sequential([prefix, params])

Stacks Blocks sequentially.

Swish([beta])

Swish Activation function

SymbolBlock(outputs, inputs[, params])

Construct block from symbol.

class mxnet.gluon.nn.Activation(activation, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Applies an activation function to input.

Parameters

activation (str) – Name of activation function to use. See Activation() for available choices.

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.AvgPool1D(pool_size=2, strides=None, padding=0, layout='NCW', ceil_mode=False, count_include_pad=True, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Average pooling operation for temporal data.

Parameters
  • pool_size (int) – Size of the average pooling windows.

  • strides (int, or None) – Factor by which to downscale. E.g. 2 will halve the input size. If None, it will default to pool_size.

  • padding (int) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • layout (str, default 'NCW') – Dimension ordering of data and out (‘NCW’ or ‘NWC’). ‘N’, ‘C’, ‘W’ stand for batch, channel, and width (time) dimensions respectively. padding is applied on the ‘W’ dimension.

  • ceil_mode (bool, default False) – When True, will use ceil instead of floor to compute the output shape.

  • count_include_pad (bool, default True) – When ‘False’, will exclude padding elements when computing the average value.

Inputs:
  • data: 3D input tensor with shape (batch_size, in_channels, width) when layout is NCW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 3D output tensor with shape (batch_size, channels, out_width) when layout is NCW. out_width is calculated as:

    out_width = floor((width+2*padding-pool_size)/strides)+1
    

    When ceil_mode is True, ceil will be used instead of floor in this equation.
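The output-width formula and the effect of count_include_pad can be checked by hand with a small pure-Python sketch. It does not call MXNet: avg_pool1d is a hypothetical helper name, it covers floor mode only (ceil_mode=False), and it assumes padding is smaller than pool_size so that every window overlaps the real input.

```python
def avg_pool1d(values, pool_size=2, strides=None, padding=0,
               count_include_pad=True):
    """Average-pool a 1D list of numbers (hypothetical helper, floor mode only)."""
    if strides is None:
        strides = pool_size  # documented default: strides falls back to pool_size
    width = len(values)
    # out_width = floor((width + 2*padding - pool_size) / strides) + 1
    out_width = (width + 2 * padding - pool_size) // strides + 1
    padded = [0.0] * padding + [float(v) for v in values] + [0.0] * padding
    out = []
    for i in range(out_width):
        start = i * strides
        window = padded[start:start + pool_size]
        if count_include_pad:
            denom = pool_size
        else:
            # Count only the positions that fall inside the original input.
            lo = max(start, padding)
            hi = min(start + pool_size, padding + width)
            denom = hi - lo
        out.append(sum(window) / denom)
    return out


no_pad = avg_pool1d([2, 4, 6, 8])
with_pad = avg_pool1d([2, 4, 6, 8], padding=1)
exclude_pad = avg_pool1d([2, 4, 6, 8], padding=1, count_include_pad=False)
```

With padding=1 the first and last windows each contain one padded zero, so with_pad and exclude_pad differ only at the borders.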

class mxnet.gluon.nn.AvgPool2D(pool_size=(2, 2), strides=None, padding=0, ceil_mode=False, layout='NCHW', count_include_pad=True, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Average pooling operation for spatial data.

Parameters
  • pool_size (int or list/tuple of 2 ints) – Size of the average pooling windows.

  • strides (int, list/tuple of 2 ints, or None) – Factor by which to downscale. E.g. 2 will halve the input size. If None, it will default to pool_size.

  • padding (int or list/tuple of 2 ints) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • layout (str, default 'NCHW') – Dimension ordering of data and out (‘NCHW’ or ‘NHWC’). ‘N’, ‘C’, ‘H’, ‘W’ stand for batch, channel, height, and width dimensions respectively. padding is applied on the ‘H’ and ‘W’ dimensions.

  • ceil_mode (bool, default False) – When True, will use ceil instead of floor to compute the output shape.

  • count_include_pad (bool, default True) – When ‘False’, will exclude padding elements when computing the average value.

Inputs:
  • data: 4D input tensor with shape (batch_size, in_channels, height, width) when layout is NCHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 4D output tensor with shape (batch_size, channels, out_height, out_width) when layout is NCHW. out_height and out_width are calculated as:

    out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1
    out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1
    

    When ceil_mode is True, ceil will be used instead of floor in this equation.
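As a quick way to sanity-check layer configurations, the two formulas above can be evaluated in plain Python. This is only a sketch of the arithmetic, not MXNet code; pool2d_out_size and _pair are hypothetical helper names.

```python
import math


def _pair(value):
    """Expand an int to (int, int); pass 2-element sequences through."""
    if isinstance(value, int):
        return (value, value)
    return tuple(value)


def pool2d_out_size(height, width, pool_size=(2, 2), strides=None,
                    padding=0, ceil_mode=False):
    """Return (out_height, out_width) following the documented formulas."""
    pool_size = _pair(pool_size)
    # Documented default: strides falls back to pool_size when it is None.
    strides = pool_size if strides is None else _pair(strides)
    padding = _pair(padding)
    rounding = math.ceil if ceil_mode else math.floor
    out_height = rounding((height + 2 * padding[0] - pool_size[0]) / strides[0]) + 1
    out_width = rounding((width + 2 * padding[1] - pool_size[1]) / strides[1]) + 1
    return out_height, out_width


floor_size = pool2d_out_size(5, 7)
ceil_size = pool2d_out_size(5, 7, ceil_mode=True)
```

For a 5x7 input with the default 2x2 window, floor mode drops the incomplete last window in each dimension while ceil mode keeps it.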

class mxnet.gluon.nn.AvgPool3D(pool_size=(2, 2, 2), strides=None, padding=0, ceil_mode=False, layout='NCDHW', count_include_pad=True, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Average pooling operation for 3D data (spatial or spatio-temporal).

Parameters
  • pool_size (int or list/tuple of 3 ints) – Size of the average pooling windows.

  • strides (int, list/tuple of 3 ints, or None) – Factor by which to downscale. E.g. 2 will halve the input size. If None, it will default to pool_size.

  • padding (int or list/tuple of 3 ints) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • layout (str, default 'NCDHW') – Dimension ordering of data and out (‘NCDHW’ or ‘NDHWC’). ‘N’, ‘C’, ‘D’, ‘H’, ‘W’ stand for batch, channel, depth, height, and width dimensions respectively. padding is applied on the ‘D’, ‘H’ and ‘W’ dimensions.

  • ceil_mode (bool, default False) – When True, will use ceil instead of floor to compute the output shape.

  • count_include_pad (bool, default True) – When ‘False’, will exclude padding elements when computing the average value.

Inputs:
  • data: 5D input tensor with shape (batch_size, in_channels, depth, height, width) when layout is NCDHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 5D output tensor with shape (batch_size, channels, out_depth, out_height, out_width) when layout is NCDHW. out_depth, out_height and out_width are calculated as:

    out_depth = floor((depth+2*padding[0]-pool_size[0])/strides[0])+1
    out_height = floor((height+2*padding[1]-pool_size[1])/strides[1])+1
    out_width = floor((width+2*padding[2]-pool_size[2])/strides[2])+1
    

    When ceil_mode is True, ceil will be used instead of floor in this equation.
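The phrase "for other layouts shape is permuted accordingly" can be made concrete with a short plain-Python sketch that locates the ‘D’, ‘H’ and ‘W’ axes from the layout string. It is not MXNet code: pool3d_out_shape is a hypothetical helper, it handles floor mode only, and it accepts only int arguments for pool_size, strides and padding.

```python
def pool3d_out_shape(in_shape, layout="NCDHW", pool_size=2, strides=None,
                     padding=0):
    """Floor-mode output shape of a 3D pooling layer (hypothetical helper)."""
    if strides is None:
        strides = pool_size  # documented default
    out_shape = list(in_shape)
    for axis_name in "DHW":
        axis = layout.index(axis_name)  # position of this axis in the layout
        size = in_shape[axis]
        out_shape[axis] = (size + 2 * padding - pool_size) // strides + 1
    return tuple(out_shape)


channels_first = pool3d_out_shape((2, 3, 8, 16, 16), layout="NCDHW")
channels_last = pool3d_out_shape((2, 8, 16, 16, 3), layout="NDHWC")
```

Only the depth, height and width entries change; the batch and channel entries are copied through wherever the layout places them.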

class mxnet.gluon.nn.BatchNorm(axis=1, momentum=0.9, epsilon=1e-05, center=True, scale=True, use_global_stats=False, beta_initializer='zeros', gamma_initializer='ones', running_mean_initializer='zeros', running_variance_initializer='ones', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Batch normalization layer (Ioffe and Szegedy, 2014). Normalizes the input at each batch, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1.

Parameters
  • axis (int, default 1) – The axis that should be normalized. This is typically the channels (C) axis. For instance, after a Conv2D layer with layout=’NCHW’, set axis=1 in BatchNorm. If layout=’NHWC’, then set axis=3.

  • momentum (float, default 0.9) – Momentum for the moving average.

  • epsilon (float, default 1e-5) – Small float added to variance to avoid dividing by zero.

  • center (bool, default True) – If True, add offset of beta to normalized tensor. If False, beta is ignored.

  • scale (bool, default True) – If True, multiply by gamma. If False, gamma is not used. When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

  • use_global_stats (bool, default False) – If True, use the global moving statistics instead of the statistics of the current batch. This turns batch normalization into a fixed scale-and-shift operator. If False, use the statistics of the current batch.

  • beta_initializer (str or Initializer, default ‘zeros’) – Initializer for the beta weight.

  • gamma_initializer (str or Initializer, default ‘ones’) – Initializer for the gamma weight.

  • running_mean_initializer (str or Initializer, default ‘zeros’) – Initializer for the running mean.

  • running_variance_initializer (str or Initializer, default ‘ones’) – Initializer for the running variance.

  • in_channels (int, default 0) – Number of channels (feature maps) in input data. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.
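To see how these parameters interact, the training-time computation for a single channel can be sketched in plain Python. This is an illustration, not the MXNet implementation: batch_norm_1ch is a hypothetical helper, and the running-statistics update rule (running = momentum * running + (1 - momentum) * batch) is an assumption about how momentum is applied.

```python
def batch_norm_1ch(values, gamma=1.0, beta=0.0, epsilon=1e-5,
                   running_mean=0.0, running_var=1.0, momentum=0.9):
    """Normalize one channel of a batch (hypothetical helper, training mode)."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased batch variance
    # epsilon keeps the division well-defined when var is zero.
    out = [gamma * (v - mean) / (var + epsilon) ** 0.5 + beta for v in values]
    # Assumed moving-average update for the running statistics.
    new_running_mean = momentum * running_mean + (1 - momentum) * mean
    new_running_var = momentum * running_var + (1 - momentum) * var
    return out, new_running_mean, new_running_var


out, run_mean, run_var = batch_norm_1ch([1.0, 2.0, 3.0, 4.0])
scaled, _, _ = batch_norm_1ch([1.0, 2.0, 3.0, 4.0], gamma=2.0, beta=1.0)
```

With the defaults (gamma=1, beta=0) the output has mean close to 0 and standard deviation close to 1; gamma and beta then rescale and shift that normalized output.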

Methods

cast(dtype)

Cast this Block to use another data type.

hybrid_forward(F, x, gamma, beta, …)

Overrides to construct symbolic graph for this Block.

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

cast(dtype)[source]

Cast this Block to use another data type.

Parameters

dtype (str or numpy.dtype) – The new data type.

hybrid_forward(F, x, gamma, beta, running_mean, running_var)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.Block(prefix=None, params=None)[source]

Bases: object

Base class for all neural network layers and models. Your models should subclass this class.

Block can be nested recursively in a tree structure. You can create and assign child Block as regular attributes:

import mxnet as mx
from mxnet.gluon import Block, nn
from mxnet import ndarray as F

class Model(Block):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # use name_scope to give child Blocks appropriate names.
        with self.name_scope():
            self.dense0 = nn.Dense(20)
            self.dense1 = nn.Dense(20)

    def forward(self, x):
        x = F.relu(self.dense0(x))
        return F.relu(self.dense1(x))

model = Model()
model.initialize(ctx=mx.cpu(0))
model(F.zeros((10, 10), ctx=mx.cpu(0)))

Methods

apply(fn)

Applies fn recursively to every child block as well as self.

cast(dtype)

Cast this Block to use another data type.

collect_params([select])

Returns a ParameterDict containing the Parameters of this Block and all of its children (default). It can also return only the Parameters whose names match the given regular expressions.

forward(*args)

Overrides to implement forward computation using NDArray.

hybridize([active])

Please refer to the description of HybridBlock hybridize().

initialize([init, ctx, verbose, force_reinit])

Initializes Parameters of this Block and its children.

load_parameters(filename[, ctx, …])

Load parameters from file previously saved by save_parameters.

load_params(filename[, ctx, allow_missing, …])

[Deprecated] Please use load_parameters.

name_scope()

Returns a name space object managing a child Block and parameter names.

register_child(block[, name])

Registers block as a child of self.

register_forward_hook(hook)

Registers a forward hook on the block.

register_forward_pre_hook(hook)

Registers a forward pre-hook on the block.

register_op_hook(callback[, monitor_all])

Install callback monitor.

save_parameters(filename[, deduplicate])

Save parameters to file.

save_params(filename)

[Deprecated] Please use save_parameters. Note that if you want to load from SymbolBlock later, please use export instead.

summary(*inputs)

Print the summary of the model’s output and parameters.

Attributes

name

Name of this Block, without ‘_’ in the end.

params

Returns this Block’s parameter dictionary (does not include its children’s parameters).

prefix

Prefix of this Block.

Child Block assigned this way will be registered and collect_params() will collect their Parameters recursively. You can also manually register child blocks with register_child().
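The registration-by-assignment behaviour can be pictured with a toy class written in plain Python. This is a sketch of the idea only and makes no claim about how Gluon implements it: ToyBlock, add_param and the naming scheme are all hypothetical.

```python
class ToyBlock:
    """Toy stand-in for a Block: children register on attribute assignment."""

    def __init__(self, name):
        # Bypass our own __setattr__ while setting up bookkeeping fields.
        object.__setattr__(self, "_children", {})
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "name", name)

    def __setattr__(self, key, value):
        if isinstance(value, ToyBlock):
            self._children[key] = value  # automatic child registration
        object.__setattr__(self, key, value)

    def add_param(self, pname, value):
        # Hypothetical naming scheme: "<block name>_<parameter name>".
        self._params[self.name + "_" + pname] = value

    def collect_params(self):
        """Gather this block's parameters, then recurse into the children."""
        collected = dict(self._params)
        for child in self._children.values():
            collected.update(child.collect_params())
        return collected


model = ToyBlock("model")
model.add_param("scale", 1.0)
model.dense0 = ToyBlock("dense0")  # registered as a child by assignment
model.dense1 = ToyBlock("dense1")
model.dense0.add_param("weight", [0.0])
model.dense1.add_param("weight", [0.0])
names = sorted(model.collect_params())
```

Here names contains the parent's own parameter plus one parameter from each child, which mirrors the recursive collection described above.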

Parameters
  • prefix (str) – Prefix acts like a name space. All child blocks created in the parent block’s name_scope() will have the parent block’s prefix in their name. Please refer to the naming tutorial for more info on prefix and naming.

  • params (ParameterDict or None) –

    ParameterDict for sharing weights with the new Block. For example, if you want dense1 to share dense0’s weights, you can do:

    dense0 = nn.Dense(20)
    dense1 = nn.Dense(20, params=dense0.collect_params())
    

apply(fn)[source]

Applies fn recursively to every child block as well as self.

Parameters

fn (callable) – Function to be applied to each submodule, of form fn(block).

Returns

this block

cast(dtype)[source]

Cast this Block to use another data type.

Parameters

dtype (str or numpy.dtype) – The new data type.

collect_params(select=None)[source]

Returns a ParameterDict containing the Parameters of this Block and all of its children (default). It can also return only the Parameters whose names match the given regular expressions.

For example, collect the specified parameters in [‘conv1_weight’, ‘conv1_bias’, ‘fc_weight’, ‘fc_bias’]:

model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')

To collect all parameters whose names end with ‘weight’ or ‘bias’, use a regular expression:

model.collect_params('.*weight|.*bias')

Parameters

select (str) – Regular expressions used to select parameters by name.

Returns

The selected ParameterDict
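The two select patterns shown above can be tried against a list of names with Python's re module. This sketch does not call MXNet: select_names is a hypothetical helper, the parameter names are made up for illustration, and it assumes names are tested with re.match semantics (the pattern is anchored at the start of the name).

```python
import re


def select_names(names, select):
    """Return the names matched by the select pattern (hypothetical helper)."""
    pattern = re.compile(select)
    # Assumption: re.match semantics, i.e. matching starts at the first character.
    return [name for name in names if pattern.match(name)]


names = ["conv1_weight", "conv1_bias", "bn_gamma", "fc_weight", "fc_bias"]
explicit = select_names(names, "conv1_weight|conv1_bias|fc_weight|fc_bias")
by_suffix = select_names(names, ".*weight|.*bias")
```

Both patterns select the same four names here and leave bn_gamma out; the second form scales better when a model has many layers.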

forward(*args)[source]

Overrides to implement forward computation using NDArray. Only accepts positional arguments.

Parameters

*args (list of NDArray) – Input tensors.

hybridize(active=True, **kwargs)[source]

Please refer to the description of HybridBlock hybridize().

initialize(init=<mxnet.initializer.Uniform object>, ctx=None, verbose=False, force_reinit=False)[source]

Initializes Parameters of this Block and its children. Equivalent to block.collect_params().initialize(...)

Parameters
  • init (Initializer) – Global default Initializer to be used when Parameter.init() is None. Otherwise, Parameter.init() takes precedence.

  • ctx (Context or list of Context) – Keeps a copy of Parameters on one or many context(s).

  • verbose (bool, default False) – Whether to verbosely print out details on initialization.

  • force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.

load_parameters(filename, ctx=None, allow_missing=False, ignore_extra=False, cast_dtype=False, dtype_source='current')[source]

Load parameters from file previously saved by save_parameters.

Parameters
  • filename (str) – Path to parameter file.

  • ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.

  • allow_missing (bool, default False) – Whether to silently skip loading parameters not represented in the file.

  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.

  • cast_dtype (bool, default False) – Cast the data type of the NDArray loaded from the checkpoint to the dtype provided by the Parameter if any.

  • dtype_source (str, default 'current') – Must be one of {‘current’, ‘saved’}. Only valid if cast_dtype=True; specifies the source of the dtype used for casting the parameters.
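The difference between allow_missing and ignore_extra is easiest to see on plain dictionaries. The sketch below is a toy model of the two flags, not the MXNet loader: load_into is a hypothetical helper, and the KeyError it raises is an arbitrary choice for illustration.

```python
def load_into(block_params, file_params, allow_missing=False,
              ignore_extra=False):
    """Toy model of the two flags on dicts of name -> value (hypothetical)."""
    missing = [name for name in block_params if name not in file_params]
    extra = [name for name in file_params if name not in block_params]
    if missing and not allow_missing:
        raise KeyError("parameters missing from file: %s" % missing)
    if extra and not ignore_extra:
        raise KeyError("parameters in file not present in block: %s" % extra)
    loaded = dict(block_params)
    for name, value in file_params.items():
        if name in loaded:
            loaded[name] = value  # overwrite only parameters the block owns
    return loaded


block = {"weight": 0, "bias": 0}
partial = load_into(block, {"weight": 1}, allow_missing=True)
trimmed = load_into(block, {"weight": 1, "bias": 2, "unused": 3},
                    ignore_extra=True)
```

allow_missing concerns parameters the Block has but the file lacks; ignore_extra concerns the opposite direction.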

References

Saving and Loading Gluon Models

load_params(filename, ctx=None, allow_missing=False, ignore_extra=False)[source]

[Deprecated] Please use load_parameters.

Load parameters from file.

Parameters
  • filename (str) – Path to parameter file.

  • ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.

  • allow_missing (bool, default False) – Whether to silently skip loading parameters not represented in the file.

  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.

property name

Name of this Block, without ‘_’ in the end.

name_scope()[source]

Returns a name space object managing a child Block and parameter names. Should be used within a with statement:

with self.name_scope():
    self.dense = nn.Dense(20)

Please refer to the naming tutorial for more info on prefix and naming.

property params

Returns this Block’s parameter dictionary (does not include its children’s parameters).

property prefix

Prefix of this Block.

register_child(block, name=None)[source]

Registers block as a child of self. Blocks assigned to self as attributes will be registered automatically.

register_forward_hook(hook)[source]

Registers a forward hook on the block.

The hook function is called immediately after forward(). It should not modify the input or output.

Parameters

hook (callable) – The forward hook function of form hook(block, input, output) -> None.

Return type

mxnet.gluon.utils.HookHandle

register_forward_pre_hook(hook)[source]

Registers a forward pre-hook on the block.

The hook function is called immediately before forward(). It should not modify the input or output.

Parameters

hook (callable) – The forward pre-hook function of form hook(block, input) -> None.

Return type

mxnet.gluon.utils.HookHandle
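The calling order of pre-hooks and hooks around forward() can be illustrated with a toy layer in plain Python. This is a sketch of the mechanism, not the Gluon implementation: ToyLayer and ToyHookHandle are hypothetical classes, and the detach() method on the handle is assumed here for illustration.

```python
class ToyHookHandle:
    """Toy handle that can remove a previously registered hook."""

    def __init__(self, hooks, key):
        self._hooks, self._key = hooks, key

    def detach(self):
        self._hooks.pop(self._key, None)


class ToyLayer:
    """Toy layer that doubles its input and runs hooks around forward()."""

    def __init__(self):
        self._pre_hooks = {}
        self._hooks = {}
        self._next_key = 0

    def _register(self, hooks, hook):
        key = self._next_key
        self._next_key += 1
        hooks[key] = hook
        return ToyHookHandle(hooks, key)

    def register_forward_pre_hook(self, hook):
        return self._register(self._pre_hooks, hook)

    def register_forward_hook(self, hook):
        return self._register(self._hooks, hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        for hook in list(self._pre_hooks.values()):
            hook(self, (x,))        # called immediately before forward()
        out = self.forward(x)
        for hook in list(self._hooks.values()):
            hook(self, (x,), out)   # called immediately after forward()
        return out


events = []
layer = ToyLayer()
pre = layer.register_forward_pre_hook(
    lambda block, inputs: events.append(("pre", inputs)))
post = layer.register_forward_hook(
    lambda block, inputs, output: events.append(("post", output)))
result = layer(3)
post.detach()   # the forward hook no longer fires
layer(5)
```

After the second call, events holds one "pre" and one "post" entry for the first call and only a "pre" entry for the second, because the forward hook was detached in between.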

register_op_hook(callback, monitor_all=False)[source]

Install callback monitor.

Parameters
  • callback (function) – Takes a string and an NDArrayHandle.

  • monitor_all (bool, default False) – If true, monitor both input and output, otherwise monitor output only.

save_parameters(filename, deduplicate=False)[source]

Save parameters to file.

Saved parameters can only be loaded with load_parameters. Note that this method only saves parameters, not model structure. If you want to save model structures, please use HybridBlock.export().

Parameters
  • filename (str) – Path to file.

  • deduplicate (bool, default False) – If True, save shared parameters only once. Otherwise, if a Block contains multiple sub-blocks that share parameters, each of the shared parameters will be separately saved for every sub-block.

References

Saving and Loading Gluon Models

save_params(filename)[source]

[Deprecated] Please use save_parameters. Note that if you want to load from SymbolBlock later, please use export instead.

Save parameters to file.

Parameters

filename (str) – Path to file.

summary(*inputs)[source]

Print the summary of the model’s output and parameters.

The network must have been initialized, and must not have been hybridized.

Parameters

inputs (object) – Any input that the model supports. For any tensor in the input, only mxnet.ndarray.NDArray is supported.

class mxnet.gluon.nn.Conv1D(channels, kernel_size, strides=1, padding=0, dilation=1, groups=1, layout='NCW', activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Conv

1D convolution layer (e.g. temporal convolution).

This layer creates a convolution kernel that is convolved with the layer input over a single spatial (or temporal) dimension to produce a tensor of outputs. If use_bias is True, a bias vector is created and added to the outputs. Finally, if activation is not None, it is applied to the outputs as well.

If in_channels is not specified, Parameter initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Parameters
  • channels (int) – The dimensionality of the output space, i.e. the number of output channels (filters) in the convolution.

  • kernel_size (int or tuple/list of 1 int) – Specifies the dimensions of the convolution window.

  • strides (int or tuple/list of 1 int) – Specify the strides of the convolution.

  • padding (int or a tuple/list of 1 int) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • dilation (int or tuple/list of 1 int) – Specifies the dilation rate to use for dilated convolution.

  • groups (int) – Controls the connections between inputs and outputs. At groups=1, all inputs are convolved to all outputs. At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

  • layout (str, default 'NCW') – Dimension ordering of data and weight. Only supports ‘NCW’ layout for now. ‘N’, ‘C’, ‘W’ stand for batch, channel, and width (time) dimensions respectively. Convolution is applied on the ‘W’ dimension.

  • in_channels (int, default 0) – The number of input channels to this layer. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

  • activation (str) – Activation function to use. See Activation(). If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool) – Whether the layer uses a bias vector.

  • weight_initializer (str or Initializer) – Initializer for the weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.
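The effect of groups on the size of the layer can be worked out with a little arithmetic. The sketch below is plain Python rather than MXNet code; the helper names are hypothetical, and the weight layout (channels, in_channels // groups, kernel_size) is an assumption based on each group seeing only its own share of the input channels.

```python
def conv1d_weight_shape(channels, in_channels, kernel_size, groups=1):
    """Assumed weight shape of a grouped 1D convolution (hypothetical helper)."""
    if channels % groups or in_channels % groups:
        raise ValueError("channels and in_channels must be divisible by groups")
    # Each output channel only sees in_channels // groups input channels.
    return (channels, in_channels // groups, kernel_size)


def conv1d_num_weights(channels, in_channels, kernel_size, groups=1,
                       use_bias=True):
    """Number of trainable scalars: kernel entries plus optional bias."""
    out_ch, in_per_group, k = conv1d_weight_shape(
        channels, in_channels, kernel_size, groups)
    count = out_ch * in_per_group * k
    if use_bias:
        count += channels  # one bias value per output channel
    return count


dense_count = conv1d_num_weights(16, 8, 3)
grouped_count = conv1d_num_weights(16, 8, 3, groups=2)
```

Going from groups=1 to groups=2 halves the kernel entries (384 to 192 here) while the bias vector stays at 16 values.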

Inputs:
  • data: 3D input tensor with shape (batch_size, in_channels, width) when layout is NCW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 3D output tensor with shape (batch_size, channels, out_width) when layout is NCW. out_width is calculated as:

    out_width = floor((width+2*padding-dilation*(kernel_size-1)-1)/strides)+1
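The formula above is easy to evaluate in plain Python when planning a network. The function below is only a sketch of that arithmetic (conv1d_out_width is a hypothetical helper, not part of MXNet) and takes int arguments only.

```python
def conv1d_out_width(width, kernel_size, strides=1, padding=0, dilation=1):
    """Output width of a 1D convolution per the documented formula."""
    # dilation * (kernel_size - 1) + 1 is the span covered by the dilated kernel.
    return (width + 2 * padding - dilation * (kernel_size - 1) - 1) // strides + 1


plain = conv1d_out_width(10, 3)                # no padding: width shrinks
same = conv1d_out_width(10, 3, padding=1)      # padding=1 keeps the width
strided = conv1d_out_width(10, 3, strides=2)   # stride 2 roughly halves it
dilated = conv1d_out_width(10, 3, dilation=2)  # dilated kernel spans 5 points
```

A kernel of size 3 with padding=1 and strides=1 preserves the input width, which is a common choice when stacking layers.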
    
class mxnet.gluon.nn.Conv1DTranspose(channels, kernel_size, strides=1, padding=0, output_padding=0, dilation=1, groups=1, layout='NCW', activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Conv

Transposed 1D convolution layer (sometimes called Deconvolution).

The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input while maintaining a connectivity pattern that is compatible with said convolution.

If in_channels is not specified, Parameter initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Parameters
  • channels (int) – The dimensionality of the output space, i.e. the number of output channels (filters) in the convolution.

  • kernel_size (int or tuple/list of 1 int) – Specifies the dimensions of the convolution window.

  • strides (int or tuple/list of 1 int) – Specify the strides of the convolution.

  • padding (int or a tuple/list of 1 int) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • output_padding (int or a tuple/list of 1 int) – Controls the amount of implicit zero-paddings on both sides of the output for output_padding number of points for each dimension.

  • dilation (int or tuple/list of 1 int) – Controls the spacing between the kernel points; also known as the à trous algorithm.

  • groups (int) – Controls the connections between inputs and outputs. At groups=1, all inputs are convolved to all outputs. At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

  • layout (str, default 'NCW') – Dimension ordering of data and weight. Only supports ‘NCW’ layout for now. ‘N’, ‘C’, ‘W’ stand for batch, channel, and width (time) dimensions respectively. Convolution is applied on the ‘W’ dimension.

  • in_channels (int, default 0) – The number of input channels to this layer. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

  • activation (str) – Activation function to use. See Activation(). If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool) – Whether the layer uses a bias vector.

  • weight_initializer (str or Initializer) – Initializer for the weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.

Inputs:
  • data: 3D input tensor with shape (batch_size, in_channels, width) when layout is NCW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 3D output tensor with shape (batch_size, channels, out_width) when layout is NCW. out_width is calculated as:

    out_width = (width-1)*strides-2*padding+kernel_size+output_padding
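A short plain-Python check shows how this formula relates to the forward convolution it mirrors. Both functions are hypothetical helpers written for this sketch (not MXNet calls), take int arguments only, and assume dilation=1.

```python
def conv1d_out_width(width, kernel_size, strides=1, padding=0):
    """Forward 1D convolution output width (dilation assumed to be 1)."""
    return (width + 2 * padding - kernel_size) // strides + 1


def conv1d_transpose_out_width(width, kernel_size, strides=1, padding=0,
                               output_padding=0):
    """Transposed 1D convolution output width per the documented formula."""
    return (width - 1) * strides - 2 * padding + kernel_size + output_padding


# Widths 7 and 8 both shrink to 4 under the same strided convolution ...
down_from_7 = conv1d_out_width(7, 3, strides=2, padding=1)
down_from_8 = conv1d_out_width(8, 3, strides=2, padding=1)
# ... so going back up from 4 needs output_padding to pick between them.
up_to_7 = conv1d_transpose_out_width(4, 3, strides=2, padding=1)
up_to_8 = conv1d_transpose_out_width(4, 3, strides=2, padding=1,
                                     output_padding=1)
```

Because the floor in the forward formula maps several input widths to the same output width, output_padding is what lets the transposed layer recover the larger of the candidate widths.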
    
class mxnet.gluon.nn.Conv2D(channels, kernel_size, strides=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, layout='NCHW', activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Conv

2D convolution layer (e.g. spatial convolution over images).

This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. If use_bias is True, a bias vector is created and added to the outputs. Finally, if activation is not None, it is applied to the outputs as well.

If in_channels is not specified, Parameter initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Parameters
  • channels (int) – The dimensionality of the output space, i.e. the number of output channels (filters) in the convolution.

  • kernel_size (int or tuple/list of 2 int) – Specifies the dimensions of the convolution window.

  • strides (int or tuple/list of 2 int) – Specify the strides of the convolution.

  • padding (int or a tuple/list of 2 int) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • dilation (int or tuple/list of 2 int) – Specifies the dilation rate to use for dilated convolution.

  • groups (int) – Controls the connections between inputs and outputs. At groups=1, all inputs are convolved to all outputs. At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

  • layout (str, default 'NCHW') – Dimension ordering of data and weight. Only supports ‘NCHW’ and ‘NHWC’ layout for now. ‘N’, ‘C’, ‘H’, ‘W’ stand for batch, channel, height, and width dimensions respectively. Convolution is applied on the ‘H’ and ‘W’ dimensions.

  • in_channels (int, default 0) – The number of input channels to this layer. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

  • activation (str) – Activation function to use. See Activation(). If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool) – Whether the layer uses a bias vector.

  • weight_initializer (str or Initializer) – Initializer for the weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.

Inputs:
  • data: 4D input tensor with shape (batch_size, in_channels, height, width) when layout is NCHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 4D output tensor with shape (batch_size, channels, out_height, out_width) when layout is NCHW. out_height and out_width are calculated as:

    out_height = floor((height+2*padding[0]-dilation[0]*(kernel_size[0]-1)-1)/strides[0])+1
    out_width = floor((width+2*padding[1]-dilation[1]*(kernel_size[1]-1)-1)/strides[1])+1
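The output-size formula above can be checked without running MXNet. The following is a minimal sketch in plain Python; the helper name conv2d_output_shape is hypothetical and is not part of the Gluon API.

```python
def conv2d_output_shape(height, width, kernel_size, strides=(1, 1),
                        padding=(0, 0), dilation=(1, 1)):
    """Return (out_height, out_width) from the formula documented for Conv2D."""
    def one_axis(size, k, s, p, d):
        # Integer floor division plays the role of floor(.../strides).
        return (size + 2 * p - d * (k - 1) - 1) // s + 1

    out_h = one_axis(height, kernel_size[0], strides[0], padding[0], dilation[0])
    out_w = one_axis(width, kernel_size[1], strides[1], padding[1], dilation[1])
    return out_h, out_w


# A 3x3 kernel with padding 1 and stride 1 preserves the spatial size.
print(conv2d_output_shape(32, 32, (3, 3), padding=(1, 1)))  # (32, 32)
# The same kernel with stride 2 halves it.
print(conv2d_output_shape(32, 32, (3, 3), strides=(2, 2), padding=(1, 1)))  # (16, 16)
```

The helper only takes tuples; the layer itself also accepts a single int for each of these arguments.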
    
class mxnet.gluon.nn.Conv2DTranspose(channels, kernel_size, strides=(1, 1), padding=(0, 0), output_padding=(0, 0), dilation=(1, 1), groups=1, layout='NCHW', activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Conv

Transposed 2D convolution layer (sometimes called Deconvolution).

The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input while maintaining a connectivity pattern that is compatible with said convolution.

If in_channels is not specified, Parameter initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Parameters
  • channels (int) – The dimensionality of the output space, i.e. the number of output channels (filters) in the convolution.

  • kernel_size (int or tuple/list of 2 int) – Specifies the dimensions of the convolution window.

  • strides (int or tuple/list of 2 int) – Specify the strides of the convolution.

  • padding (int or tuple/list of 2 int) – If padding is non-zero, the input is implicitly zero-padded on both sides by padding points.

  • output_padding (int or tuple/list of 2 int) – Additional size added to the output along each spatial dimension, as it appears in the out_height and out_width formulas below.

  • dilation (int or tuple/list of 2 int) – Controls the spacing between the kernel points; also known as the à trous algorithm.

  • groups (int) – Controls the connections between inputs and outputs. At groups=1, all inputs are convolved to all outputs. At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

  • layout (str, default 'NCHW') – Dimension ordering of data and weight. Only the ‘NCHW’ and ‘NHWC’ layouts are supported for now. ‘N’, ‘C’, ‘H’, ‘W’ stand for the batch, channel, height, and width dimensions respectively. Convolution is applied on the ‘H’ and ‘W’ dimensions.

  • in_channels (int, default 0) – The number of input channels to this layer. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

  • activation (str) – Activation function to use. See Activation(). If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool) – Whether the layer uses a bias vector.

  • weight_initializer (str or Initializer) – Initializer for the weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.

Inputs:
  • data: 4D input tensor with shape (batch_size, in_channels, height, width) when layout is NCHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 4D output tensor with shape (batch_size, channels, out_height, out_width) when layout is NCHW. out_height and out_width are calculated as:

    out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
    out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
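As a worked instance of the transposed-convolution formula above, here is a small sketch in plain Python. The helper name conv2d_transpose_output_shape is hypothetical, and it follows the documented formula exactly, so dilation is not modelled.

```python
def conv2d_transpose_output_shape(height, width, kernel_size, strides=(1, 1),
                                  padding=(0, 0), output_padding=(0, 0)):
    """Return (out_height, out_width) from the formula documented for Conv2DTranspose."""
    def one_axis(size, k, s, p, op):
        return (size - 1) * s - 2 * p + k + op

    out_h = one_axis(height, kernel_size[0], strides[0], padding[0], output_padding[0])
    out_w = one_axis(width, kernel_size[1], strides[1], padding[1], output_padding[1])
    return out_h, out_w


# A 4x4 kernel with stride 2 and padding 1 doubles the spatial size.
print(conv2d_transpose_output_shape(16, 16, (4, 4), strides=(2, 2), padding=(1, 1)))  # (32, 32)
```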
    
class mxnet.gluon.nn.Conv3D(channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, layout='NCDHW', activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Conv

3D convolution layer (e.g. spatial convolution over volumes).

This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. If use_bias is True, a bias vector is created and added to the outputs. Finally, if activation is not None, it is applied to the outputs as well.

If in_channels is not specified, Parameter initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Parameters
  • channels (int) – The dimensionality of the output space, i.e. the number of output channels (filters) in the convolution.

  • kernel_size (int or tuple/list of 3 int) – Specifies the dimensions of the convolution window.

  • strides (int or tuple/list of 3 int) – Specifies the strides of the convolution.

  • padding (int or tuple/list of 3 int) – If padding is non-zero, the input is implicitly zero-padded on both sides by padding points.

  • dilation (int or tuple/list of 3 int) – Specifies the dilation rate to use for dilated convolution.

  • groups (int) – Controls the connections between inputs and outputs. At groups=1, all inputs are convolved to all outputs. At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

  • layout (str, default 'NCDHW') – Dimension ordering of data and weight. Only the ‘NCDHW’ and ‘NDHWC’ layouts are supported for now. ‘N’, ‘C’, ‘H’, ‘W’, ‘D’ stand for the batch, channel, height, width and depth dimensions respectively. Convolution is applied on the ‘D’, ‘H’ and ‘W’ dimensions.

  • in_channels (int, default 0) – The number of input channels to this layer. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

  • activation (str) – Activation function to use. See Activation(). If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool) – Whether the layer uses a bias vector.

  • weight_initializer (str or Initializer) – Initializer for the weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.

Inputs:
  • data: 5D input tensor with shape (batch_size, in_channels, depth, height, width) when layout is NCDHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 5D output tensor with shape (batch_size, channels, out_depth, out_height, out_width) when layout is NCDHW. out_depth, out_height and out_width are calculated as:

    out_depth = floor((depth+2*padding[0]-dilation[0]*(kernel_size[0]-1)-1)/strides[0])+1
    out_height = floor((height+2*padding[1]-dilation[1]*(kernel_size[1]-1)-1)/strides[1])+1
    out_width = floor((width+2*padding[2]-dilation[2]*(kernel_size[2]-1)-1)/strides[2])+1
    
class mxnet.gluon.nn.Conv3DTranspose(channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), output_padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, layout='NCDHW', activation=None, use_bias=True, weight_initializer=None, bias_initializer='zeros', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Conv

Transposed 3D convolution layer (sometimes called Deconvolution).

The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input while maintaining a connectivity pattern that is compatible with said convolution.

If in_channels is not specified, Parameter initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Parameters
  • channels (int) – The dimensionality of the output space, i.e. the number of output channels (filters) in the convolution.

  • kernel_size (int or tuple/list of 3 int) – Specifies the dimensions of the convolution window.

  • strides (int or tuple/list of 3 int) – Specify the strides of the convolution.

  • padding (int or tuple/list of 3 int) – If padding is non-zero, the input is implicitly zero-padded on both sides by padding points.

  • output_padding (int or tuple/list of 3 int) – Additional size added to the output along each spatial dimension, as it appears in the out_depth, out_height and out_width formulas below.

  • dilation (int or tuple/list of 3 int) – Controls the spacing between the kernel points; also known as the à trous algorithm.

  • groups (int) – Controls the connections between inputs and outputs. At groups=1, all inputs are convolved to all outputs. At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels, and producing half the output channels, and both subsequently concatenated.

  • layout (str, default 'NCDHW') – Dimension ordering of data and weight. Only the ‘NCDHW’ and ‘NDHWC’ layouts are supported for now. ‘N’, ‘C’, ‘H’, ‘W’, ‘D’ stand for the batch, channel, height, width and depth dimensions respectively. Convolution is applied on the ‘D’, ‘H’ and ‘W’ dimensions.

  • in_channels (int, default 0) – The number of input channels to this layer. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

  • activation (str) – Activation function to use. See Activation(). If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool) – Whether the layer uses a bias vector.

  • weight_initializer (str or Initializer) – Initializer for the weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.

Inputs:
  • data: 5D input tensor with shape (batch_size, in_channels, depth, height, width) when layout is NCDHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 5D output tensor with shape (batch_size, channels, out_depth, out_height, out_width) when layout is NCDHW. out_depth, out_height and out_width are calculated as:

    out_depth = (depth-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
    out_height = (height-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
    out_width = (width-1)*strides[2]-2*padding[2]+kernel_size[2]+output_padding[2]
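One way to read the two sets of 3D formulas together: a strided Conv3D can map several input sizes to the same output size, and output_padding selects which of them a matching Conv3DTranspose restores. The sketch below illustrates this in plain Python, one axis at a time; the helper names are hypothetical and the same kernel, stride and padding are assumed on all three axes.

```python
def conv_out(size, kernel, stride, pad, dilation=1):
    # Forward (Conv3D) size along one axis, per the documented formula.
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose_out(size, kernel, stride, pad, output_pad=0):
    # Transposed (Conv3DTranspose) size along one axis, per the documented formula.
    return (size - 1) * stride - 2 * pad + kernel + output_pad


kernel, stride, pad = 3, 2, 1
original = (8, 7, 8)  # (depth, height, width)

down = tuple(conv_out(s, kernel, stride, pad) for s in original)
print(down)  # (4, 4, 4)

# Sizes 7 and 8 both map to 4, so output_padding decides which one comes back.
up_default = tuple(conv_transpose_out(s, kernel, stride, pad) for s in down)
up_matched = tuple(conv_transpose_out(s, kernel, stride, pad, op)
                   for s, op in zip(down, (1, 0, 1)))
print(up_default)  # (7, 7, 7)
print(up_matched)  # (8, 7, 8)
```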
    
class mxnet.gluon.nn.Dense(units, activation=None, use_bias=True, flatten=True, dtype='float32', weight_initializer=None, bias_initializer='zeros', in_units=0, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Just your regular densely-connected NN layer.

Dense implements the operation: output = activation(dot(input, weight) + bias) where activation is the element-wise activation function passed as the activation argument, weight is a weights matrix created by the layer, and bias is a bias vector created by the layer (only applicable if use_bias is True).
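The operation above can be spelled out on nested lists. This is only a sketch of the formula as written, not the layer's implementation: the function name dense_forward is hypothetical, and the weight is assumed to be laid out as (in_units, units) so that dot(input, weight) applies directly.

```python
def dense_forward(inputs, weight, bias=None, activation=None):
    """Compute activation(dot(inputs, weight) + bias) row by row.

    inputs: batch_size x in_units, weight: in_units x units (assumed layout).
    """
    units = len(weight[0])
    outputs = []
    for row in inputs:
        out_row = []
        for j in range(units):
            value = sum(row[i] * weight[i][j] for i in range(len(row)))
            if bias is not None:
                value += bias[j]
            if activation is not None:
                value = activation(value)  # applied element-wise
            out_row.append(value)
        outputs.append(out_row)
    return outputs


def relu(v):
    return max(0.0, v)


x = [[1.0, 2.0]]                            # batch_size=1, in_units=2
w = [[1.0, -1.0, 0.5], [0.0, 1.0, -2.0]]    # in_units=2, units=3
b = [0.5, 0.0, 0.0]
print(dense_forward(x, w, b, relu))  # [[1.5, 1.0, 0.0]]
```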

Note

The input must be a tensor with rank 2. Use flatten to convert it to rank 2 manually if necessary.

Methods

hybrid_forward(F, x, weight[, bias])

Overrides to construct symbolic graph for this Block.

Parameters
  • units (int) – Dimensionality of the output space.

  • activation (str) – Activation function to use. See help on Activation layer. If you don’t specify anything, no activation is applied (i.e. “linear” activation: a(x) = x).

  • use_bias (bool, default True) – Whether the layer uses a bias vector.

  • flatten (bool, default True) – Whether the input tensor should be flattened. If true, all but the first axis of input data are collapsed together. If false, all but the last axis of input data are kept the same, and the transformation applies on the last axis.

  • dtype (str or np.dtype, default 'float32') – Data type of the output.

  • weight_initializer (str or Initializer) – Initializer for the kernel weights matrix.

  • bias_initializer (str or Initializer) – Initializer for the bias vector.

  • in_units (int, optional) – Size of the input data. If not specified, initialization will be deferred to the first time forward is called and in_units will be inferred from the shape of input data.

  • prefix (str or None) – See document of Block.

  • params (ParameterDict or None) – See document of Block.

Inputs:
  • data: if flatten is True, data should be a tensor with shape (batch_size, x1, x2, …, xn), where x1 * x2 * … * xn is equal to in_units. If flatten is False, data should have shape (x1, x2, …, xn, in_units).

Outputs:
  • out: if flatten is True, out will be a tensor with shape (batch_size, units). If flatten is False, out will have shape (x1, x2, …, xn, units).
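The shape rules for the two flatten modes can be summarised in a few lines of plain Python. The helper name dense_shapes is hypothetical; it returns the in_units value that would be inferred together with the output shape described above.

```python
def dense_shapes(data_shape, units, flatten=True):
    """Return (inferred in_units, output shape) for a Dense layer."""
    if flatten:
        # All axes except the first are collapsed together.
        in_units = 1
        for dim in data_shape[1:]:
            in_units *= dim
        return in_units, (data_shape[0], units)
    # Only the last axis is transformed; the others are kept as they are.
    return data_shape[-1], tuple(data_shape[:-1]) + (units,)


print(dense_shapes((32, 3, 28, 28), 10, flatten=True))   # (2352, (32, 10))
print(dense_shapes((32, 3, 28, 28), 10, flatten=False))  # (28, (32, 3, 28, 10))
```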

hybrid_forward(F, x, weight, bias=None)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.Dropout(rate, axes=(), **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Applies Dropout to the input.

Dropout consists in randomly setting a fraction rate of input units to 0 at each update during training time, which helps prevent overfitting.

Parameters
  • rate (float) – Fraction of the input units to drop. Must be a number between 0 and 1.

  • axes (tuple of int, default ()) – The axes on which dropout mask is shared. If empty, regular dropout is applied.

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.
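A minimal sketch of regular dropout (empty axes) on a flat list of values follows. The function name dropout is hypothetical. Rescaling the kept units by 1/(1 - rate), often called inverted dropout, is an assumption of this sketch; this section does not state how the layer scales its output.

```python
import random


def dropout(values, rate, training=True, rng=random):
    """Zero each value with probability `rate` during training."""
    if not 0 <= rate < 1:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0:
        # Outside training the input passes through unchanged.
        return list(values)
    keep_scale = 1.0 / (1.0 - rate)  # assumed inverted-dropout rescaling
    return [0.0 if rng.random() < rate else v * keep_scale for v in values]


rng = random.Random(0)  # seeded generator for a repeatable mask
out = dropout([1.0] * 1000, rate=0.5, rng=rng)
print(sum(1 for v in out if v == 0.0), "of", len(out), "units dropped")
print(dropout([1.0, 2.0], rate=0.5, training=False))  # [1.0, 2.0]
```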

References

Dropout: A Simple Way to Prevent Neural Networks from Overfitting

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.ELU(alpha=1.0, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Exponential Linear Unit (ELU)

“Fast and Accurate Deep Network Learning by Exponential Linear Units”, Clevert et al, 2016 https://arxiv.org/abs/1511.07289 Published as a conference paper at ICLR 2016

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

Parameters

alpha (float) – The alpha parameter as described by Clevert et al, 2016

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.
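For reference, the ELU function from Clevert et al. is x for positive inputs and alpha * (exp(x) - 1) otherwise. A scalar sketch in plain Python, with the hypothetical name elu:

```python
import math


def elu(x, alpha=1.0):
    """Exponential Linear Unit applied to a single value."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


print(elu(2.0))     # 2.0
print(elu(-1.0))    # about -0.632
print(elu(-100.0))  # approaches -alpha for very negative inputs
```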

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.Embedding(input_dim, output_dim, dtype='float32', weight_initializer=None, sparse_grad=False, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Turns non-negative integers (indexes/tokens) into dense vectors of fixed size, e.g. [4, 20] -> [[0.25, 0.1], [0.6, -0.2]].

Note

If sparse_grad is set to True, the gradient w.r.t. weight will be sparse. Only a subset of optimizers support sparse gradients, including SGD, AdaGrad and Adam. By default, lazy updates are turned on, which may perform differently from standard updates. For more details, please check the Optimization API at: https://mxnet.incubator.apache.org/api/python/optimization/optimization.html

Methods

hybrid_forward(F, x, weight)

Overrides to construct symbolic graph for this Block.

Parameters
  • input_dim (int) – Size of the vocabulary, i.e. maximum integer index + 1.

  • output_dim (int) – Dimension of the dense embedding.

  • dtype (str or np.dtype, default 'float32') – Data type of output embeddings.

  • weight_initializer (Initializer) – Initializer for the embeddings matrix.

  • sparse_grad (bool) – If True, gradient w.r.t. weight will be a ‘row_sparse’ NDArray.

Inputs:
  • data: (N-1)-D tensor with shape: (x1, x2, …, xN-1).

Outputs:
  • out: N-D tensor with shape: (x1, x2, …, xN-1, output_dim).
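Conceptually, the layer is a row lookup into an (input_dim, output_dim) weight table. The sketch below reproduces the [4, 20] example in plain Python; the function name embedding_lookup and the table contents are illustrative assumptions.

```python
def embedding_lookup(indices, weight):
    """Replace every integer index with the matching row of `weight`."""
    if isinstance(indices, (list, tuple)):
        # Recurse so that arbitrarily nested index tensors keep their shape.
        return [embedding_lookup(i, weight) for i in indices]
    return list(weight[indices])


input_dim, output_dim = 21, 2  # vocabulary size = maximum index + 1
weight = [[0.0] * output_dim for _ in range(input_dim)]
weight[4] = [0.25, 0.1]
weight[20] = [0.6, -0.2]

print(embedding_lookup([4, 20], weight))    # [[0.25, 0.1], [0.6, -0.2]]
print(embedding_lookup([[4], [20]], weight))  # shape (2, 1) -> (2, 1, 2)
```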

hybrid_forward(F, x, weight)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.Flatten(**kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Flattens the input to two dimensional.

Inputs:
  • data: input tensor with arbitrary shape (N, x1, x2, …, xn)

Output:
  • out: 2D tensor with shape: (N, x1 * x2 * … * xn)
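The shape change can be illustrated on nested lists: the batch axis is kept and every other axis is collapsed. The helper names below are hypothetical.

```python
def flatten_sample(x):
    """Collapse one (possibly nested) sample into a flat list."""
    if not isinstance(x, list):
        return [x]
    flat = []
    for item in x:
        flat.extend(flatten_sample(item))
    return flat


def flatten_batch(batch):
    """Keep the first (batch) axis and flatten everything after it."""
    return [flatten_sample(sample) for sample in batch]


data = [[[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]]]  # shape (2, 2, 3)
print(flatten_batch(data))  # [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]
```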

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.GELU(**kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Gaussian Error Linear Unit (GELU)

“Gaussian Error Linear Units (GELUs)”, Hendrycks et al, 2016 https://arxiv.org/abs/1606.08415

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.
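The cited paper defines GELU as x * Φ(x), where Φ is the standard normal cumulative distribution function. A scalar sketch of that exact form using math.erf is shown below; the name gelu is hypothetical, and whether the layer evaluates this form or an approximation is not stated in this section.

```python
import math


def gelu(x):
    """Exact GELU: x times the standard normal CDF at x."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for value in (-1.0, 0.0, 1.0, 3.0):
    print(value, gelu(value))
```

For large positive inputs the result approaches x, and for large negative inputs it approaches 0.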

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.GlobalAvgPool1D(layout='NCW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Global average pooling operation for temporal data.

Parameters

layout (str, default 'NCW') – Dimension ordering of data and out (‘NCW’ or ‘NWC’). ‘N’, ‘C’, ‘W’ stand for the batch, channel, and width (time) dimensions respectively. Pooling is applied on the ‘W’ dimension.

Inputs:
  • data: 3D input tensor with shape (batch_size, in_channels, width) when layout is NCW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 3D output tensor with shape (batch_size, channels, 1).
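Global average pooling reduces the whole width axis of each channel to a single mean. A sketch on nested lists, assuming the NCW layout and using the hypothetical name global_avg_pool1d:

```python
def global_avg_pool1d(data):
    """data has shape (batch_size, channels, width); returns (batch_size, channels, 1)."""
    return [[[sum(channel) / len(channel)] for channel in sample]
            for sample in data]


data = [[[1.0, 2.0, 3.0],
         [4.0, 6.0, 8.0]]]  # shape (1, 2, 3)
print(global_avg_pool1d(data))  # [[[2.0], [6.0]]]
```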

class mxnet.gluon.nn.GlobalAvgPool2D(layout='NCHW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Global average pooling operation for spatial data.

Parameters

layout (str, default 'NCHW') – Dimension ordering of data and out (‘NCHW’ or ‘NHWC’). ‘N’, ‘C’, ‘H’, ‘W’ stands for batch, channel, height, and width dimensions respectively.

Inputs:
  • data: 4D input tensor with shape (batch_size, in_channels, height, width) when layout is NCHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 4D output tensor with shape (batch_size, channels, 1, 1) when layout is NCHW.

class mxnet.gluon.nn.GlobalAvgPool3D(layout='NCDHW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Global average pooling operation for 3D data (spatial or spatio-temporal).

Parameters

layout (str, default 'NCDHW') – Dimension ordering of data and out (‘NCDHW’ or ‘NDHWC’). ‘N’, ‘C’, ‘H’, ‘W’, ‘D’ stand for the batch, channel, height, width and depth dimensions respectively. Pooling is applied on the ‘D’, ‘H’ and ‘W’ dimensions.

Inputs:
  • data: 5D input tensor with shape (batch_size, in_channels, depth, height, width) when layout is NCDHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 5D output tensor with shape (batch_size, channels, 1, 1, 1) when layout is NCDHW.

class mxnet.gluon.nn.GlobalMaxPool1D(layout='NCW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Global max pooling operation for one dimensional (temporal) data.

Parameters

layout (str, default 'NCW') – Dimension ordering of data and out (‘NCW’ or ‘NWC’). ‘N’, ‘C’, ‘W’ stands for batch, channel, and width (time) dimensions respectively. Pooling is applied on the W dimension.

Inputs:
  • data: 3D input tensor with shape (batch_size, in_channels, width) when layout is NCW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 3D output tensor with shape (batch_size, channels, 1) when layout is NCW.

class mxnet.gluon.nn.GlobalMaxPool2D(layout='NCHW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Global max pooling operation for two dimensional (spatial) data.

Parameters

layout (str, default 'NCHW') – Dimension ordering of data and out (‘NCHW’ or ‘NHWC’). ‘N’, ‘C’, ‘H’, ‘W’ stand for the batch, channel, height, and width dimensions respectively. Pooling is applied on the ‘H’ and ‘W’ dimensions.

Inputs:
  • data: 4D input tensor with shape (batch_size, in_channels, height, width) when layout is NCHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 4D output tensor with shape (batch_size, channels, 1, 1) when layout is NCHW.
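Global max pooling keeps only the largest value of each channel's spatial plane. A sketch on nested lists, assuming the NCHW layout and using the hypothetical name global_max_pool2d:

```python
def global_max_pool2d(data):
    """data has shape (batch_size, channels, height, width); returns (batch_size, channels, 1, 1)."""
    return [[[[max(max(row) for row in channel)]] for channel in sample]
            for sample in data]


data = [[[[1, 5], [3, 2]],
         [[-1, -7], [-3, -2]]]]  # shape (1, 2, 2, 2)
print(global_max_pool2d(data))  # [[[[5]], [[-1]]]]
```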

class mxnet.gluon.nn.GlobalMaxPool3D(layout='NCDHW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Global max pooling operation for 3D data (spatial or spatio-temporal).

Parameters

layout (str, default 'NCDHW') – Dimension ordering of data and out (‘NCDHW’ or ‘NDHWC’). ‘N’, ‘C’, ‘H’, ‘W’, ‘D’ stand for the batch, channel, height, width and depth dimensions respectively. Pooling is applied on the ‘D’, ‘H’ and ‘W’ dimensions.

Inputs:
  • data: 5D input tensor with shape (batch_size, in_channels, depth, height, width) when layout is NCDHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 5D output tensor with shape (batch_size, channels, 1, 1, 1) when layout is NCDHW.

class mxnet.gluon.nn.GroupNorm(num_groups=1, epsilon=1e-05, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', prefix=None, params=None)[source]

Bases: mxnet.gluon.block.HybridBlock

Applies group normalization to the n-dimensional input array. This operator takes an n-dimensional input array where the two leftmost axes are batch and channel respectively:

\[x = x.reshape((N, num_groups, C // num_groups, ...))\]
\[axis = (2, ...)\]
\[out = \frac{x - mean[x, axis]}{\sqrt{Var[x, axis] + \epsilon}} * gamma + beta\]

Methods

hybrid_forward(F, data, gamma, beta)

Overrides to construct symbolic graph for this Block.

Parameters
  • num_groups (int, default 1) – Number of groups to separate the channel axis into.

  • epsilon (float, default 1e-5) – Small float added to variance to avoid dividing by zero.

  • center (bool, default True) – If True, add offset of beta to normalized tensor. If False, beta is ignored.

  • scale (bool, default True) – If True, multiply by gamma. If False, gamma is not used.

  • beta_initializer (str or Initializer, default ‘zeros’) – Initializer for the beta weight.

  • gamma_initializer (str or Initializer, default ‘ones’) – Initializer for the gamma weight.

Inputs:
  • data: input tensor with shape (N, C, …).

Outputs:
  • out: output tensor with the same shape as data.
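The normalization formula can be traced by hand for a 3D input. The sketch below is plain Python restricted to shape (N, C, W), with gamma fixed to 1 and beta to 0 (the default initializers); the function name group_norm_ncw is hypothetical. With a single group, the 2 x 3 x 4 input holding 0 to 23 gives about -1.5932543 as its first entry.

```python
import math


def group_norm_ncw(x, num_groups=1, epsilon=1e-5):
    """Group normalization for nested lists of shape (N, C, W); gamma=1, beta=0."""
    out = []
    for sample in x:
        group_size = len(sample) // num_groups  # channels per group
        normalized = []
        for g in range(num_groups):
            group = sample[g * group_size:(g + 1) * group_size]
            values = [v for channel in group for v in channel]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            scale = 1.0 / math.sqrt(var + epsilon)
            normalized.extend([[(v - mean) * scale for v in channel]
                               for channel in group])
        out.append(normalized)
    return out


x = [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
     [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]
y = group_norm_ncw(x)
print(round(y[0][0][0], 4))  # -1.5933
```

Each sample is normalized independently, so both samples in this input produce the same output values.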

References

Group Normalization

Examples

>>> # Input of shape (2, 3, 4)
>>> x = mx.nd.array([[[ 0,  1,  2,  3],
                      [ 4,  5,  6,  7],
                      [ 8,  9, 10, 11]],
                     [[12, 13, 14, 15],
                      [16, 17, 18, 19],
                      [20, 21, 22, 23]]])
>>> # Group normalization is calculated with the above formula
>>> layer = GroupNorm()
>>> layer.initialize(ctx=mx.cpu(0))
>>> layer(x)
[[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
  [-0.4345239 -0.1448413  0.1448413  0.4345239]
  [ 0.7242065  1.0138891  1.3035717  1.5932543]]
 [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
  [-0.4345239 -0.1448413  0.1448413  0.4345239]
  [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
<NDArray 2x3x4 @cpu(0)>

hybrid_forward(F, data, gamma, beta)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.HybridBlock(prefix=None, params=None)[source]

Bases: mxnet.gluon.block.Block

HybridBlock supports forwarding with both Symbol and NDArray.

HybridBlock is similar to Block, with a few differences:

import mxnet as mx
from mxnet.gluon import HybridBlock, nn

class Model(HybridBlock):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # use name_scope to give child Blocks appropriate names.
        with self.name_scope():
            self.dense0 = nn.Dense(20)
            self.dense1 = nn.Dense(20)

    def hybrid_forward(self, F, x):
        x = F.relu(self.dense0(x))
        return F.relu(self.dense1(x))

model = Model()
model.initialize(ctx=mx.cpu(0))
model.hybridize()
model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))

Methods

cast(dtype)

Cast this Block to use another data type.

export(path[, epoch, remove_amp_cast])

Export HybridBlock to json format that can be loaded by gluon.SymbolBlock.imports, mxnet.mod.Module or the C++ interface.

forward(x, *args)

Defines the forward computation.

hybrid_forward(F, x, *args, **kwargs)

Overrides to construct symbolic graph for this Block.

hybridize([active, backend, backend_opts])

Activates or deactivates HybridBlocks recursively.

infer_shape(*args)

Infers shape of Parameters from inputs.

infer_type(*args)

Infers data type of Parameters from inputs.

optimize_for(x, *args[, backend, backend_opts])

Partitions the current HybridBlock and optimizes it for a given backend without executing a forward pass.

register_child(block[, name])

Registers block as a child of self.

register_op_hook(callback[, monitor_all])

Install op hook for block recursively.

Forward computation in HybridBlock must be static to work with Symbols, i.e. you cannot call NDArray.asnumpy(), NDArray.shape, NDArray.dtype, NDArray indexing (x[i]) etc. on tensors. Also, you cannot use branching or loop logic that is based on non-constant expressions like random numbers or intermediate results, since they change the graph structure for each iteration.

Before activating with hybridize(), HybridBlock works just like normal Block. After activation, HybridBlock will create a symbolic graph representing the forward computation and cache it. On subsequent forwards, the cached graph will be used instead of hybrid_forward().

Please see the references for a detailed tutorial.

References

Hybrid - Faster training and easy deployment

cast(dtype)[source]

Cast this Block to use another data type.

Parameters

dtype (str or numpy.dtype) – The new data type.

export(path, epoch=0, remove_amp_cast=True)[source]

Export HybridBlock to json format that can be loaded by gluon.SymbolBlock.imports, mxnet.mod.Module or the C++ interface.

Note

When there is only one input, it will be named data. When there is more than one input, they will be named data0, data1, etc.

Parameters
  • path (str) – Path to save model. Two files path-symbol.json and path-xxxx.params will be created, where xxxx is the 4-digit epoch number.

  • epoch (int) – Epoch number of saved model.

forward(x, *args)[source]

Defines the forward computation. Arguments can be either NDArray or Symbol.

hybrid_forward(F, x, *args, **kwargs)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

hybridize(active=True, backend=None, backend_opts=None, **kwargs)[source]

Activates or deactivates HybridBlocks recursively. Has no effect on non-hybrid children.

Parameters
  • active (bool, default True) – Whether to turn hybrid on or off.

  • backend (str) – The name of backend, as registered in SubgraphBackendRegistry, default None

  • backend_opts (dict, optional) – User-specified options to pass to the backend for partitioning. Passed on to the PrePartition and PostPartition functions of SubgraphProperty.

  • static_alloc (bool, default False) – Statically allocate memory to improve speed. Memory usage may increase.

  • static_shape (bool, default False) – Optimize for invariant input shapes between iterations. Must also set static_alloc to True. Change of input shapes is still allowed but slower.

infer_shape(*args)[source]

Infers shape of Parameters from inputs.

infer_type(*args)[source]

Infers data type of Parameters from inputs.

optimize_for(x, *args, backend=None, backend_opts=None, **kwargs)[source]

Partitions the current HybridBlock and optimizes it for a given backend without executing a forward pass. Modifies the HybridBlock in-place.

Immediately partitions a HybridBlock using the specified backend. Combines the work done in the hybridize API with part of the work done in the forward pass without calling the CachedOp. Can be used in place of hybridize; afterwards, export can be called or inference can be run. See example/extensions/lib_subgraph/README.md for more details.

Examples

# partition and then export to file
block.optimize_for(x, backend='myPart')
block.export('partitioned')

# partition and then run inference
block.optimize_for(x, backend='myPart')
block(x)

Parameters
  • x (NDArray) – first input to model

  • *args (NDArray) – other inputs to model

  • backend (str) – The name of backend, as registered in SubgraphBackendRegistry, default None

  • backend_opts (dict, optional) – User-specified options to pass to the backend for partitioning. Passed on to the PrePartition and PostPartition functions of SubgraphProperty.

  • static_alloc (bool, default False) – Statically allocate memory to improve speed. Memory usage may increase.

  • static_shape (bool, default False) – Optimize for invariant input shapes between iterations. Must also set static_alloc to True. Change of input shapes is still allowed but slower.

register_child(block, name=None)[source]

Registers block as a child of self. Blocks assigned to self as attributes will be registered automatically.

register_op_hook(callback, monitor_all=False)[source]

Install op hook for block recursively.

Parameters
  • callback (function) – Takes a string and a NDArrayHandle.

  • monitor_all (bool, default False) – If true, monitor both input and output, otherwise monitor output only.

class mxnet.gluon.nn.HybridLambda(function, prefix=None)[source]

Bases: mxnet.gluon.block.HybridBlock

Wraps an operator or an expression as a HybridBlock object.

Parameters
  • function (str or function) –

    Function used in lambda must be one of the following:

    1. The name of an operator that is available in both symbol and ndarray. For example:

      block = HybridLambda('tanh')

    2. A function that conforms to def function(F, data, *args). For example:

      block = HybridLambda(lambda F, x: F.LeakyReLU(x, slope=0.1))
      

Inputs:
  • *args: one or more input data. First argument must be symbol or ndarray. Their shapes depend on the function.

Outputs:
  • *outputs: one or more output data. Their shapes depend on the function.

Methods

hybrid_forward(F, x, *args)

Overrides to construct symbolic graph for this Block.

hybrid_forward(F, x, *args)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.HybridSequential(prefix=None, params=None)[source]

Bases: mxnet.gluon.block.HybridBlock

Stacks HybridBlocks sequentially.

Example:

net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
    net.add(nn.Dense(10, activation='relu'))
    net.add(nn.Dense(20))
net.hybridize()

Methods

add(*blocks)

Adds block on top of the stack.

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

add(*blocks)[source]

Adds block on top of the stack.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.InstanceNorm(axis=1, epsilon=1e-05, center=True, scale=False, beta_initializer='zeros', gamma_initializer='ones', in_channels=0, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Applies instance normalization to the n-dimensional input array. This operator takes an n-dimensional input array where (n>2) and normalizes the input using the following formula:

\[ \begin{align}\begin{aligned}\bar{C} = \{i \mid i \neq 0, i \neq axis\}\\out = \frac{x - mean[data, \bar{C}]}{ \sqrt{Var[data, \bar{C}] + \epsilon}} * gamma + beta\end{aligned}\end{align} \]
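As a rough illustration of the normalization described here, the following plain-Python sketch normalizes each (sample, channel) slice of a 3D input independently. It is not MXNet's implementation: the function name instance_norm is hypothetical, gamma and beta are simplified to scalars, and epsilon is added to the variance before the square root, as the epsilon parameter description states.

```python
import math


def instance_norm(x, gamma=1.0, beta=0.0, epsilon=1e-5):
    """Normalize a nested list of shape (N, C, W) per sample and per channel.

    Hypothetical helper for illustration; gamma and beta are scalars here.
    """
    out = []
    for sample in x:  # axis 0 (batch) is never reduced
        normed_sample = []
        for channel in sample:  # axis 1 is the excluded (channel) axis
            mean = sum(channel) / len(channel)
            var = sum((v - mean) ** 2 for v in channel) / len(channel)
            denom = math.sqrt(var + epsilon)
            normed_sample.append(
                [(v - mean) / denom * gamma + beta for v in channel]
            )
        out.append(normed_sample)
    return out


# Same input as the Examples section of this class: shape (2, 1, 2).
x = [[[1.1, 2.2]], [[3.3, 4.4]]]
y = instance_norm(x)
```

With this input each slice holds two values, so every normalized entry comes out close to -1 or +1, in line with the values listed under Examples.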

Methods

hybrid_forward(F, x, gamma, beta)

Overrides to construct symbolic graph for this Block.

Parameters
  • axis (int, default 1) – The axis that will be excluded in the normalization process. This is typically the channels (C) axis. For instance, after a Conv2D layer with layout=’NCHW’, set axis=1 in InstanceNorm. If layout=’NHWC’, then set axis=3. Data will be normalized along axes excluding the first axis and the axis given.

  • epsilon (float, default 1e-5) – Small float added to variance to avoid dividing by zero.

  • center (bool, default True) – If True, add offset of beta to normalized tensor. If False, beta is ignored.

  • scale (bool, default False) – If True, multiply by gamma. If False, gamma is not used. When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

  • beta_initializer (str or Initializer, default ‘zeros’) – Initializer for the beta weight.

  • gamma_initializer (str or Initializer, default ‘ones’) – Initializer for the gamma weight.

  • in_channels (int, default 0) – Number of channels (feature maps) in input data. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

References

Instance Normalization: The Missing Ingredient for Fast Stylization

Examples

>>> # Input of shape (2,1,2)
>>> x = mx.nd.array([[[ 1.1,  2.2]],
...                 [[ 3.3,  4.4]]])
>>> # Instance normalization is calculated with the above formula
>>> layer = InstanceNorm()
>>> layer.initialize(ctx=mx.cpu(0))
>>> layer(x)
[[[-0.99998355  0.99998331]]
 [[-0.99998319  0.99998361]]]
<NDArray 2x1x2 @cpu(0)>

hybrid_forward(F, x, gamma, beta)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.Lambda(function, prefix=None)[source]

Bases: mxnet.gluon.block.Block

Wraps an operator or an expression as a Block object.

Parameters
  • function (str or function) –

    Function used in lambda must be one of the following:

    1. The name of an operator that is available in ndarray. For example:

      block = Lambda('tanh')

    2. A function that conforms to def function(*args). For example:

      block = Lambda(lambda x: nd.LeakyReLU(x, slope=0.1))
      

  • Inputs

    • *args: one or more input data. Their shapes depend on the function.

  • Output

    • outputs: one or more output data. Their shapes depend on the function.

Methods

forward(*args)

Overrides to implement forward computation using NDArray.

forward(*args)[source]

Overrides to implement forward computation using NDArray. Only accepts positional arguments.

Parameters

*args (list of NDArray) – Input tensors.

class mxnet.gluon.nn.LayerNorm(axis=-1, epsilon=1e-05, center=True, scale=True, beta_initializer='zeros', gamma_initializer='ones', in_channels=0, prefix=None, params=None)[source]

Bases: mxnet.gluon.block.HybridBlock

Applies layer normalization to the n-dimensional input array. This operator takes an n-dimensional input array and normalizes the input using the given axis:

\[out = \frac{x - mean[data, axis]}{ \sqrt{Var[data, axis] + \epsilon}} * gamma + beta\]
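The formula above can be sketched in plain Python for a 2D input normalized along its last axis (the default axis=-1). This is only an illustration, not MXNet's implementation: the function name layer_norm is hypothetical, and gamma and beta are simplified to scalars.

```python
import math


def layer_norm(rows, gamma=1.0, beta=0.0, epsilon=1e-5):
    """Normalize each row of a 2D nested list along its last axis.

    Hypothetical helper for illustration; gamma and beta are scalars here.
    """
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        denom = math.sqrt(var + epsilon)  # epsilon sits inside the square root
        out.append([(v - mean) / denom * gamma + beta for v in row])
    return out


# Same input as the Examples section of this class: shape (2, 5).
y = layer_norm([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]])
```

Each row is normalized with its own mean and variance, so the two rows produce different scalings even though they share some values.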

Methods

hybrid_forward(F, data, gamma, beta)

Overrides to construct symbolic graph for this Block.

Parameters
  • axis (int, default -1) – The axis that should be normalized. This is typically the axis of the channels.

  • epsilon (float, default 1e-5) – Small float added to variance to avoid dividing by zero.

  • center (bool, default True) – If True, add offset of beta to normalized tensor. If False, beta is ignored.

  • scale (bool, default True) – If True, multiply by gamma. If False, gamma is not used.

  • beta_initializer (str or Initializer, default ‘zeros’) – Initializer for the beta weight.

  • gamma_initializer (str or Initializer, default ‘ones’) – Initializer for the gamma weight.

  • in_channels (int, default 0) – Number of channels (feature maps) in input data. If not specified, initialization will be deferred to the first time forward is called and in_channels will be inferred from the shape of input data.

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

References

Layer Normalization

Examples

>>> # Input of shape (2, 5)
>>> x = mx.nd.array([[1, 2, 3, 4, 5], [1, 1, 2, 2, 2]])
>>> # Layer normalization is calculated with the above formula
>>> layer = LayerNorm()
>>> layer.initialize(ctx=mx.cpu(0))
>>> layer(x)
[[-1.41421    -0.707105    0.          0.707105    1.41421   ]
 [-1.2247195  -1.2247195   0.81647956  0.81647956  0.81647956]]
<NDArray 2x5 @cpu(0)>

hybrid_forward(F, data, gamma, beta)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • data (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.LeakyReLU(alpha, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Leaky version of a Rectified Linear Unit.

It allows a small gradient when the unit is not active:

\[\begin{split}f\left(x\right) = \left\{ \begin{array}{lr} \alpha x & : x \lt 0 \\ x & : x \geq 0 \\ \end{array} \right.\\\end{split}\]
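For a single scalar, the piecewise definition above amounts to the following plain-Python sketch. The function name leaky_relu is hypothetical, and raising ValueError for a negative alpha is an assumption made here to mirror the documented "Must be >= 0" constraint.

```python
def leaky_relu(x, alpha):
    """Return x for x >= 0 and alpha * x otherwise (scalar illustration)."""
    if alpha < 0:
        # Assumption: reject negative slopes, per the documented constraint.
        raise ValueError("alpha must be >= 0")
    return x if x >= 0 else alpha * x


positive = leaky_relu(2.0, alpha=0.1)   # unchanged, since the unit is active
negative = leaky_relu(-2.0, alpha=0.1)  # scaled by alpha, since it is not
```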

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

Parameters

alpha (float) – slope coefficient for the negative half axis. Must be >= 0.

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.MaxPool1D(pool_size=2, strides=None, padding=0, layout='NCW', ceil_mode=False, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Max pooling operation for one dimensional data.

Parameters
  • pool_size (int) – Size of the max pooling windows.

  • strides (int, or None) – Factor by which to downscale. E.g. 2 will halve the input size. If None, it will default to pool_size.

  • padding (int) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • layout (str, default 'NCW') – Dimension ordering of data and out (‘NCW’ or ‘NWC’). ‘N’, ‘C’, ‘W’ stand for batch, channel, and width (time) dimensions respectively. Pooling is applied on the W dimension.

  • ceil_mode (bool, default False) – When True, will use ceil instead of floor to compute the output shape.

Inputs:
  • data: 3D input tensor with shape (batch_size, in_channels, width) when layout is NCW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 3D output tensor with shape (batch_size, channels, out_width) when layout is NCW. out_width is calculated as:

    out_width = floor((width+2*padding-pool_size)/strides)+1
    

    When ceil_mode is True, ceil will be used instead of floor in this equation.
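The output-size rule above can be checked with a small helper. This is a sketch of the stated formula only; the function name pool_output_size is hypothetical, and any extra clamping a backend might apply is not modeled.

```python
import math


def pool_output_size(size, pool_size, strides=None, padding=0, ceil_mode=False):
    """Apply out = round((size + 2*padding - pool_size) / strides) + 1.

    round is floor by default and ceil when ceil_mode is True.
    """
    if strides is None:
        strides = pool_size  # documented default: strides falls back to pool_size
    span = (size + 2 * padding - pool_size) / strides
    rounded = math.ceil(span) if ceil_mode else math.floor(span)
    return rounded + 1


default_case = pool_output_size(10, pool_size=2)                     # 5
floor_case = pool_output_size(8, pool_size=3, strides=2)             # 3
ceil_case = pool_output_size(8, pool_size=3, strides=2, ceil_mode=True)  # 4
```

The last two calls differ only in ceil_mode: the window does not tile the input evenly, so ceil keeps one extra, partially covered output position.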

class mxnet.gluon.nn.MaxPool2D(pool_size=(2, 2), strides=None, padding=0, layout='NCHW', ceil_mode=False, **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Max pooling operation for two dimensional (spatial) data.

Parameters
  • pool_size (int or list/tuple of 2 ints) – Size of the max pooling windows.

  • strides (int, list/tuple of 2 ints, or None) – Factor by which to downscale. E.g. 2 will halve the input size. If None, it will default to pool_size.

  • padding (int or list/tuple of 2 ints) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • layout (str, default 'NCHW') – Dimension ordering of data and out (‘NCHW’ or ‘NHWC’). ‘N’, ‘C’, ‘H’, ‘W’ stand for batch, channel, height, and width dimensions respectively. Pooling is applied on the ‘H’ and ‘W’ dimensions.

  • ceil_mode (bool, default False) – When True, will use ceil instead of floor to compute the output shape.

Inputs:
  • data: 4D input tensor with shape (batch_size, in_channels, height, width) when layout is NCHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 4D output tensor with shape (batch_size, channels, out_height, out_width) when layout is NCHW. out_height and out_width are calculated as:

    out_height = floor((height+2*padding[0]-pool_size[0])/strides[0])+1
    out_width = floor((width+2*padding[1]-pool_size[1])/strides[1])+1
    

    When ceil_mode is True, ceil will be used instead of floor in this equation.
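To make the windowing concrete, here is a naive plain-Python sketch that max-pools a single H x W plane, i.e. one (sample, channel) slice of an NCHW tensor. It is not MXNet's implementation: the function name max_pool_2d is hypothetical, and padding and ceil_mode are left out for brevity.

```python
def max_pool_2d(plane, pool_size=(2, 2), strides=None):
    """Max-pool one H x W plane given as nested lists (no padding, floor mode)."""
    ph, pw = pool_size
    sh, sw = strides if strides is not None else pool_size
    height, width = len(plane), len(plane[0])
    # Output sizes follow the floor formula above with padding = 0.
    out_h = (height - ph) // sh + 1
    out_w = (width - pw) // sw + 1
    return [
        [
            max(
                plane[i * sh + di][j * sw + dj]
                for di in range(ph)
                for dj in range(pw)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


plane = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]
pooled = max_pool_2d(plane)                      # 2 x 2 result
overlapping = max_pool_2d(plane, strides=(1, 1))  # 3 x 3 result
```

With the default strides the windows do not overlap and the 4 x 4 plane shrinks to 2 x 2; a stride of 1 makes the windows overlap and yields 3 x 3.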

class mxnet.gluon.nn.MaxPool3D(pool_size=(2, 2, 2), strides=None, padding=0, ceil_mode=False, layout='NCDHW', **kwargs)[source]

Bases: mxnet.gluon.nn.conv_layers._Pooling

Max pooling operation for 3D data (spatial or spatio-temporal).

Parameters
  • pool_size (int or list/tuple of 3 ints) – Size of the max pooling windows.

  • strides (int, list/tuple of 3 ints, or None) – Factor by which to downscale. E.g. 2 will halve the input size. If None, it will default to pool_size.

  • padding (int or list/tuple of 3 ints) – If padding is non-zero, then the input is implicitly zero-padded on both sides for padding number of points.

  • layout (str, default 'NCDHW') – Dimension ordering of data and out (‘NCDHW’ or ‘NDHWC’). ‘N’, ‘C’, ‘H’, ‘W’, ‘D’ stand for batch, channel, height, width and depth dimensions respectively. Pooling is applied on the ‘D’, ‘H’ and ‘W’ dimensions.

  • ceil_mode (bool, default False) – When True, will use ceil instead of floor to compute the output shape.

Inputs:
  • data: 5D input tensor with shape (batch_size, in_channels, depth, height, width) when layout is NCDHW. For other layouts shape is permuted accordingly.

Outputs:
  • out: 5D output tensor with shape (batch_size, channels, out_depth, out_height, out_width) when layout is NCDHW. out_depth, out_height and out_width are calculated as:

    out_depth = floor((depth+2*padding[0]-pool_size[0])/strides[0])+1
    out_height = floor((height+2*padding[1]-pool_size[1])/strides[1])+1
    out_width = floor((width+2*padding[2]-pool_size[2])/strides[2])+1
    

    When ceil_mode is True, ceil will be used instead of floor in this equation.

class mxnet.gluon.nn.PReLU(alpha_initializer=<mxnet.initializer.Constant object>, in_channels=1, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Parametric leaky version of a Rectified Linear Unit. See the paper at https://arxiv.org/abs/1502.01852.

It learns a gradient when the unit is not active:

\[\begin{split}f\left(x\right) = \left\{ \begin{array}{lr} \alpha x & : x \lt 0 \\ x & : x \geq 0 \\ \end{array} \right.\\\end{split}\]

where alpha is a learned parameter.

Methods

hybrid_forward(F, x, alpha)

Overrides to construct symbolic graph for this Block.

Parameters
  • alpha_initializer (Initializer) – Initializer for the alpha parameter.

  • in_channels (int, default 1) – Number of channels (alpha parameters) to learn. Can either be 1 or n where n is the size of the second dimension of the input tensor.

  • Inputs

    • data: input tensor with arbitrary shape.

  • Outputs

    • out: output tensor with the same shape as data.
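The in_channels option can be pictured with a plain-Python sketch over an (N, C) input, where alpha holds either one shared slope or one slope per channel (the second dimension). This is an illustration under those assumptions, not MXNet's implementation; the function name prelu is hypothetical and alpha is passed in as fixed values rather than learned.

```python
def prelu(batch, alpha):
    """Apply PReLU to a nested list of shape (N, C).

    alpha is a list of length 1 (shared slope) or C (one slope per channel).
    """
    channels = len(batch[0])
    if len(alpha) not in (1, channels):
        raise ValueError("alpha must have length 1 or match the channel count")
    out = []
    for sample in batch:
        row = []
        for c, value in enumerate(sample):
            slope = alpha[0] if len(alpha) == 1 else alpha[c]
            row.append(value if value >= 0 else slope * value)
        out.append(row)
    return out


shared = prelu([[-1.0, 2.0, -3.0]], alpha=[0.25])
per_channel = prelu([[-1.0, 2.0, -4.0]], alpha=[0.5, 0.5, 0.25])
```

In the first call every negative entry is scaled by the same slope; in the second each channel uses its own.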

hybrid_forward(F, x, alpha)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.ReflectionPad2D(padding=0, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Pads the input tensor using the reflection of the input boundary.

Parameters

padding (int) – An integer padding size

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

Inputs:
  • data: input tensor with the shape \((N, C, H_{in}, W_{in})\).

Outputs:
  • out: output tensor with the shape \((N, C, H_{out}, W_{out})\), where

    \[ \begin{align}\begin{aligned}H_{out} = H_{in} + 2 \cdot padding\\W_{out} = W_{in} + 2 \cdot padding\end{aligned}\end{align} \]
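The shape rule above, and what reflection means at the border, can be sketched on a single H x W plane in plain Python. This assumes the common reflect convention in which the edge element itself is not repeated; the function names reflect_index and reflection_pad_2d are hypothetical, and this is not MXNet's implementation.

```python
def reflect_index(i, size):
    """Map an out-of-range index back into [0, size) by mirroring at the edges."""
    if i < 0:
        return -i
    if i >= size:
        return 2 * (size - 1) - i
    return i


def reflection_pad_2d(plane, padding):
    """Reflection-pad one H x W plane (nested lists) by `padding` on every side."""
    height, width = len(plane), len(plane[0])
    if padding >= height or padding >= width:
        # A single mirror step only works when padding is smaller than each side.
        raise ValueError("padding must be smaller than both height and width")
    return [
        [
            plane[reflect_index(r, height)][reflect_index(c, width)]
            for c in range(-padding, width + padding)
        ]
        for r in range(-padding, height + padding)
    ]


plane = [[1, 2, 3], [4, 5, 6]]
padded = reflection_pad_2d(plane, 1)
```

The 2 x 3 plane becomes 4 x 5, matching H_out = H_in + 2 * padding and W_out = W_in + 2 * padding.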

Examples

>>> m = nn.ReflectionPad2D(3)
>>> input = mx.nd.random.normal(shape=(16, 3, 224, 224))
>>> output = m(input)

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.SELU(**kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Scaled Exponential Linear Unit (SELU)

“Self-Normalizing Neural Networks”, Klambauer et al, 2017 https://arxiv.org/abs/1706.02515

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.Sequential(prefix=None, params=None)[source]

Bases: mxnet.gluon.block.Block

Stacks Blocks sequentially.

Example:

net = nn.Sequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
    net.add(nn.Dense(10, activation='relu'))
    net.add(nn.Dense(20))

Methods

add(*blocks)

Adds block on top of the stack.

forward(x)

Overrides to implement forward computation using NDArray.

hybridize([active])

Activates or deactivates HybridBlocks recursively.

add(*blocks)[source]

Adds block on top of the stack.

forward(x)[source]

Overrides to implement forward computation using NDArray. Only accepts positional arguments.

Parameters

*args (list of NDArray) – Input tensors.

hybridize(active=True, **kwargs)[source]

Activates or deactivates HybridBlocks recursively. Has no effect on non-hybrid children.

Parameters
  • active (bool, default True) – Whether to turn hybrid on or off.

  • **kwargs (string) – Additional flags for hybridized operator.

class mxnet.gluon.nn.Swish(beta=1.0, **kwargs)[source]

Bases: mxnet.gluon.block.HybridBlock

Swish Activation function

https://arxiv.org/pdf/1710.05941.pdf

Methods

hybrid_forward(F, x)

Overrides to construct symbolic graph for this Block.

Parameters

beta (float) – swish(x) = x * sigmoid(beta*x)
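For a single scalar, the definition above can be written out as the following plain-Python sketch. The function name swish is used only for illustration, and no overflow protection is included for very large negative beta * x.

```python
import math


def swish(x, beta=1.0):
    """Compute x * sigmoid(beta * x) for a scalar x."""
    sigmoid = 1.0 / (1.0 + math.exp(-beta * x))
    return x * sigmoid


at_zero = swish(0.0)          # 0.0, because the x factor is zero
at_one = swish(1.0)           # sigmoid(1) scaled by 1
flat = swish(2.0, beta=0.0)   # sigmoid(0) = 0.5, so the result is x / 2
```

Setting beta to 0 turns the function into the linear map x / 2, while larger beta values make the sigmoid gate sharper.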

Inputs:
  • data: input tensor with arbitrary shape.

Outputs:
  • out: output tensor with the same shape as data.

hybrid_forward(F, x)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

class mxnet.gluon.nn.SymbolBlock(outputs, inputs, params=None)[source]

Bases: mxnet.gluon.block.HybridBlock

Construct block from symbol. This is useful for using pre-trained models as feature extractors. For example, you may want to extract the output from the fc2 layer in AlexNet.

Parameters
  • outputs (Symbol or list of Symbol) – The desired output for SymbolBlock.

  • inputs (Symbol or list of Symbol) – The Variables in output’s argument that should be used as inputs.

  • params (ParameterDict) – Parameter dictionary for arguments and auxiliary states of outputs that are not inputs.

Methods

cast(dtype)

Cast this Block to use another data type.

forward(x, *args)

Defines the forward computation.

hybrid_forward(F, x, *args, **kwargs)

Overrides to construct symbolic graph for this Block.

imports(symbol_file, input_names[, …])

Import model previously saved by gluon.HybridBlock.export or Module.save_checkpoint as a gluon.SymbolBlock for use in Gluon.

Examples

>>> # To extract the feature from fc1 and fc2 layers of AlexNet:
>>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
...                                          prefix='model_')
>>> inputs = mx.sym.var('data')
>>> out = alexnet(inputs)
>>> internals = out.get_internals()
>>> print(internals.list_outputs())
['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
>>> outputs = [internals['model_dense0_relu_fwd_output'],
...            internals['model_dense1_relu_fwd_output']]
>>> # Create SymbolBlock that shares parameters with alexnet
>>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
>>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
>>> print(feat_model(x))

cast(dtype)[source]

Cast this Block to use another data type.

Parameters

dtype (str or numpy.dtype) – The new data type.

forward(x, *args)[source]

Defines the forward computation. Arguments can be either NDArray or Symbol.

hybrid_forward(F, x, *args, **kwargs)[source]

Overrides to construct symbolic graph for this Block.

Parameters
  • x (Symbol or NDArray) – The first input tensor.

  • *args (list of Symbol or list of NDArray) – Additional input tensors.

static imports(symbol_file, input_names, param_file=None, ctx=None)[source]

Import model previously saved by gluon.HybridBlock.export or Module.save_checkpoint as a gluon.SymbolBlock for use in Gluon.

Parameters
  • symbol_file (str) – Path to symbol file.

  • input_names (list of str) – List of input variable names

  • param_file (str, optional) – Path to parameter file.

  • ctx (Context, default None) – The context to initialize gluon.SymbolBlock on.

Returns

gluon.SymbolBlock loaded from symbol and parameter files.

Return type

gluon.SymbolBlock

Examples

>>> net1 = gluon.model_zoo.vision.resnet18_v1(
...     prefix='resnet', pretrained=True)
>>> net1.hybridize()
>>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
>>> out1 = net1(x)
>>> net1.export('net1', epoch=1)
>>>
>>> net2 = gluon.SymbolBlock.imports(
...     'net1-symbol.json', ['data'], 'net1-0001.params')
>>> out2 = net2(x)