Gluon Package

Overview

The Gluon package is a high-level interface for MXNet, designed to be easy to use while keeping most of the flexibility of the low-level API. Gluon supports both imperative and symbolic programming, making it easy to train complex models imperatively in Python and then deploy them as a symbolic graph in C++ and Scala.

Parameter

Parameter A Container holding parameters (weights) of Blocks.
Constant A constant parameter for holding immutable tensors.
ParameterDict A dictionary managing a set of parameters.

Containers

Block Base class for all neural network layers and models.
HybridBlock HybridBlock supports forwarding with both Symbol and NDArray.
SymbolBlock Construct block from symbol.
nn.Sequential Stacks Blocks sequentially.
nn.HybridSequential Stacks HybridBlocks sequentially.

Trainer

Trainer Applies an Optimizer on a set of Parameters.

Utilities

split_data Splits an NDArray into num_slice slices along batch_axis.
split_and_load Splits an NDArray into len(ctx_list) slices along batch_axis and loads each slice to one context in ctx_list.
clip_global_norm Rescales NDArrays so that their global 2-norm (the square root of the sum of their squared 2-norms) is smaller than max_norm.

API Reference

Neural network module.

class mxnet.gluon.Block(prefix=None, params=None)

Base class for all neural network layers and models. Your models should subclass this class.

Block can be nested recursively in a tree structure. You can create and assign child Block as regular attributes:

import mxnet as mx
from mxnet.gluon import Block, nn
from mxnet import ndarray as F

class Model(Block):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # use name_scope to give child Blocks appropriate names.
        with self.name_scope():
            self.dense0 = nn.Dense(20)
            self.dense1 = nn.Dense(20)

    def forward(self, x):
        x = F.relu(self.dense0(x))
        return F.relu(self.dense1(x))

model = Model()
model.initialize(ctx=mx.cpu(0))
model(F.zeros((10, 10), ctx=mx.cpu(0)))

Child Blocks assigned this way are registered automatically, and collect_params() collects their Parameters recursively. You can also register child Blocks manually with register_child().

Parameters:
  • prefix (str) – Prefix acts like a namespace. All child blocks created in the parent block's name_scope() will have the parent block's prefix in their names. Please refer to the naming tutorial for more info on prefix and naming.
  • params (ParameterDict or None) –

    ParameterDict for sharing weights with the new Block. For example, if you want dense1 to share dense0's weights, you can do:

    dense0 = nn.Dense(20)
    dense1 = nn.Dense(20, params=dense0.collect_params())
    
__call__(*args)

Calls forward. Only accepts positional arguments.

__setattr__(name, value)

Registers parameters.

cast(dtype)

Cast this Block to use another data type.

Parameters: dtype (str or numpy.dtype) – The new data type.
collect_params(select=None)

Returns a ParameterDict containing the Parameters of this Block and all of its children (default). It can also return a selected ParameterDict containing only the Parameters whose names match a given regular expression.

For example, to collect the specified parameters in ['conv1_weight', 'conv1_bias', 'fc_weight', 'fc_bias']:

model.collect_params('conv1_weight|conv1_bias|fc_weight|fc_bias')

Or, to collect all parameters whose names end with 'weight' or 'bias', use a regular expression:

model.collect_params('.*weight|.*bias')
Parameters: select (str) – A regular expression used to select Parameters by name.
Returns: The selected ParameterDict.
Return type: ParameterDict
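The selection rule can be illustrated without MXNet. The sketch below filters a list of parameter names the way collect_params(select) is described above. select_params and the sample names are hypothetical. The sketch also assumes, as we understand the MXNet 1.x implementation, that each full (prefixed) name is tested with re.match, so the pattern is anchored at the start of the name.

```python
import re

def select_params(names, select=None):
    """Mimic collect_params(select): keep names matched by the regex (sketch)."""
    if not select:
        # No pattern: every parameter is collected.
        return list(names)
    pattern = re.compile(select)
    # Assumption: re.match, i.e. the match must start at the beginning of the name.
    return [name for name in names if pattern.match(name)]

names = ['model_dense0_weight', 'model_dense0_bias',
         'model_dense1_weight', 'model_batchnorm0_gamma']
print(select_params(names, '.*weight|.*bias'))
# ['model_dense0_weight', 'model_dense0_bias', 'model_dense1_weight']
```

Because real names carry the Block's prefix, a leading '.*' keeps a pattern independent of that prefix under this assumption.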
forward(*args)

Override to implement the forward computation using NDArray. Only accepts positional arguments.

Parameters: *args (list of NDArray) – Input tensors.
hybridize(active=True, **kwargs)

Activates or deactivates HybridBlocks recursively. Has no effect on non-hybrid children.

Parameters:
  • active (bool, default True) – Whether to turn hybrid on or off.
  • **kwargs (string) – Additional flags for hybridized operator.
initialize(init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False)

Initializes Parameters of this Block and its children. Equivalent to block.collect_params().initialize(...).

Parameters:
  • init (Initializer) – Global default Initializer to be used when Parameter.init() is None. Otherwise, Parameter.init() takes precedence.
  • ctx (Context or list of Context) – Keeps a copy of Parameters on one or many context(s).
  • verbose (bool, default False) – Whether to verbosely print out details on initialization.
  • force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.
load_parameters(filename, ctx=None, allow_missing=False, ignore_extra=False)

Load parameters from file previously saved by save_parameters.

Parameters:
  • filename (str) – Path to parameter file.
  • ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.
  • allow_missing (bool, default False) – Whether to silently skip loading parameters not present in the file.
  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.

References

Saving and Loading Gluon Models
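The allow_missing and ignore_extra flags amount to set comparisons between the parameter names in the Block and the names stored in the file. Here is a plain-Python sketch. check_loaded_names is a hypothetical helper, and it raises ValueError where the real method raises its own error.

```python
def check_loaded_names(block_names, file_names, allow_missing=False, ignore_extra=False):
    """Sketch of the allow_missing / ignore_extra checks (hypothetical helper).

    Returns the names that would actually be loaded.
    """
    block_names, file_names = set(block_names), set(file_names)
    missing = block_names - file_names   # in the Block, absent from the file
    extra = file_names - block_names     # in the file, absent from the Block
    if missing and not allow_missing:
        raise ValueError('missing from file: %s' % sorted(missing))
    if extra and not ignore_extra:
        raise ValueError('not present in Block: %s' % sorted(extra))
    return sorted(block_names & file_names)

block = ['dense0_weight', 'dense0_bias']
saved = ['dense0_weight', 'dense0_bias', 'dense1_weight']
print(check_loaded_names(block, saved, ignore_extra=True))
# ['dense0_bias', 'dense0_weight']
```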

load_params(filename, ctx=None, allow_missing=False, ignore_extra=False)

[Deprecated] Please use load_parameters.

Load parameters from file.

Parameters:
  • filename (str) – Path to parameter file.
  • ctx (Context or list of Context, default cpu()) – Context(s) to initialize loaded parameters on.
  • allow_missing (bool, default False) – Whether to silently skip loading parameters not present in the file.
  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this Block.
name

Name of this Block, without the trailing '_'.

name_scope()

Returns a name-scope object that manages the names of child Blocks and Parameters. Should be used within a with statement:

with self.name_scope():
    self.dense = nn.Dense(20)

Please refer to the naming tutorial for more info on prefix and naming.

params

Returns this Block's parameter dictionary (does not include its children's parameters).

prefix

Prefix of this Block.

register_child(block, name=None)

Registers block as a child of self. Blocks assigned to self as attributes will be registered automatically.

save_parameters(filename)

Save parameters to file.

Saved parameters can only be loaded with load_parameters. Note that this method only saves parameters, not model structure. If you want to save model structures, please use HybridBlock.export().

Parameters: filename (str) – Path to file.

References

Saving and Loading Gluon Models

save_params(filename)

[Deprecated] Please use save_parameters. Note that if you want to load from SymbolBlock later, please use export instead.

Save parameters to file.

Parameters: filename (str) – Path to file.
class mxnet.gluon.Constant(name, value)

A constant parameter for holding immutable tensors. Constants are ignored by autograd and Trainer, so their values will not change during training. You can still update their values manually with the set_data method.

Constants can be created with either:

import mxnet as mx
const = mx.gluon.Constant('const', [[1,2],[3,4]])

or:

from mxnet import gluon

class MyBlock(gluon.Block):
    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)
        self.const = self.params.get_constant('const', [[1,2],[3,4]])
Parameters:
  • name (str) – Name of the parameter.
  • value (array-like) – Initial value for the constant.
exception mxnet.gluon.DeferredInitializationError

Error for unfinished deferred initialization.

class mxnet.gluon.HybridBlock(prefix=None, params=None)

HybridBlock supports forwarding with both Symbol and NDArray.

HybridBlock is similar to Block, with a few differences:

import mxnet as mx
from mxnet.gluon import HybridBlock, nn

class Model(HybridBlock):
    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # use name_scope to give child Blocks appropriate names.
        with self.name_scope():
            self.dense0 = nn.Dense(20)
            self.dense1 = nn.Dense(20)

    def hybrid_forward(self, F, x):
        x = F.relu(self.dense0(x))
        return F.relu(self.dense1(x))

model = Model()
model.initialize(ctx=mx.cpu(0))
model.hybridize()
model(mx.nd.zeros((10, 10), ctx=mx.cpu(0)))

Forward computation in HybridBlock must be static to work with Symbols. That is, you cannot call NDArray.asnumpy(), NDArray.shape, NDArray.dtype, NDArray indexing (x[i]), etc. on tensors. You also cannot use branching or loop logic that depends on non-constant expressions such as random numbers or intermediate results, since these would change the graph structure on each iteration.
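The restriction on data-dependent branching can be illustrated with NumPy standing in for NDArray. This is a conceptual sketch, not Gluon code. dynamic_forward and static_forward are hypothetical. In a real hybrid_forward you would use an operator-level select (for example a where operator from the F namespace) rather than a Python if.

```python
import numpy as np

def dynamic_forward(x):
    # Python control flow on a data value: which branch runs depends on x,
    # so a traced graph would differ from call to call.
    if x.sum() > 0:
        return x * 2
    return -x

def static_forward(x):
    # Same result, but both branches are always computed and the choice is an
    # element-wise select, so the sequence of operations never changes.
    cond = np.full(x.shape, x.sum() > 0)
    return np.where(cond, x * 2, -x)

for x in (np.array([1.0, 2.0]), np.array([1.0, -3.0])):
    print(dynamic_forward(x), static_forward(x))
```

The trade-off is that the static version always evaluates both branches. That extra work is the price of a graph that can be cached after hybridize().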

Before activating with hybridize(), HybridBlock works just like normal Block. After activation, HybridBlock will create a symbolic graph representing the forward computation and cache it. On subsequent forwards, the cached graph will be used instead of hybrid_forward().

See the references for a detailed tutorial.

References

Hybrid - Faster training and easy deployment

__setattr__(name, value)

Registers parameters.

export(path, epoch=0)

Export HybridBlock to JSON format that can be loaded by SymbolBlock.imports, mxnet.mod.Module or the C++ interface.

Note

When there is only one input, it will be named data. When there is more than one input, they will be named data0, data1, etc.

Parameters:
  • path (str) – Path to save model. Two files, path-symbol.json and path-xxxx.params, will be created, where xxxx is the 4-digit epoch number.
  • epoch (int) – Epoch number of saved model.
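The naming rules stated for export (output file names and input variable names) can be written down as two small functions. Both are hypothetical helpers that only reproduce the documented conventions. They do not call MXNet.

```python
def export_filenames(path, epoch=0):
    """File names HybridBlock.export is documented to write (sketch)."""
    # '%04d' zero-pads the epoch to the documented 4 digits.
    return '%s-symbol.json' % path, '%s-%04d.params' % (path, epoch)

def export_input_names(num_inputs):
    """Documented input names: 'data' for one input, else 'data0', 'data1', ..."""
    if num_inputs == 1:
        return ['data']
    return ['data%d' % i for i in range(num_inputs)]

print(export_filenames('net1', epoch=1))  # ('net1-symbol.json', 'net1-0001.params')
print(export_input_names(2))              # ['data0', 'data1']
```

These are the names you would pass to SymbolBlock.imports when loading the exported model.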
forward(x, *args)

Defines the forward computation. Arguments can be either NDArray or Symbol.

hybrid_forward(F, x, *args, **kwargs)

Override to construct the symbolic graph for this Block.

Parameters:
  • x (Symbol or NDArray) – The first input tensor.
  • *args (list of Symbol or list of NDArray) – Additional input tensors.
infer_shape(*args)

Infers shape of Parameters from inputs.

infer_type(*args)

Infers data type of Parameters from inputs.

class mxnet.gluon.Parameter(name, grad_req='write', shape=None, dtype=numpy.float32, lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False, differentiable=True)

A Container holding parameters (weights) of Blocks.

Parameter holds a copy of the parameter on each Context after it is initialized with Parameter.initialize(...). If grad_req is not 'null', it will also hold a gradient array on each Context:

import mxnet as mx

ctx = mx.gpu(0)  # requires a GPU; use mx.cpu(0) otherwise
x = mx.nd.zeros((16, 100), ctx=ctx)
w = mx.gluon.Parameter('fc_weight', shape=(64, 100), init=mx.init.Xavier())
b = mx.gluon.Parameter('fc_bias', shape=(64,), init=mx.init.Zero())
w.initialize(ctx=ctx)
b.initialize(ctx=ctx)
# Fully connected layer: output shape is (16, 64).
out = mx.nd.FullyConnected(x, w.data(ctx), b.data(ctx), num_hidden=64)
Parameters:
  • name (str) – Name of this parameter.
  • grad_req ({'write', 'add', 'null'}, default 'write') –

    Specifies how to update gradient to grad arrays.

    • 'write' means the gradient is written to the grad NDArray on every backward pass, overwriting its previous value.
    • 'add' means the gradient is added to the grad NDArray on every backward pass. You need to manually call zero_grad() to clear the gradient buffer before each iteration when using this option.
    • 'null' means a gradient is not requested for this parameter; gradient arrays will not be allocated.
  • shape (tuple of int, default None) – Shape of this parameter. By default shape is not specified. Parameter with unknown shape can be used for Symbol API, but init will throw an error when using NDArray API.
  • dtype (numpy.dtype or str, default 'float32') – Data type of this parameter. For example, numpy.float32 or 'float32'.
  • lr_mult (float, default 1.0) – Learning rate multiplier. Learning rate will be multiplied by lr_mult when updating this parameter with optimizer.
  • wd_mult (float, default 1.0) – Weight decay multiplier (L2 regularizer coefficient). Works similarly to lr_mult.
  • init (Initializer, default None) – Initializer of this parameter. Will use the global initializer by default.
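The sketch below shows the difference between 'write' and 'add' with NumPy. It applies the same gradient twice without clearing the buffer in between. backward_into is a hypothetical helper standing in for what autograd does to a grad array. It is not an MXNet function.

```python
import numpy as np

def backward_into(grad_buffer, new_grad, grad_req):
    """Apply one backward pass's gradient to a buffer (sketch of grad_req)."""
    if grad_req == 'write':
        grad_buffer[:] = new_grad          # overwrite the previous gradient
    elif grad_req == 'add':
        grad_buffer += new_grad            # accumulate across passes
    else:
        raise ValueError("'null' allocates no buffer, so there is nothing to update")
    return grad_buffer

step_grad = np.array([1.0, 2.0])
write_buf, add_buf = np.zeros(2), np.zeros(2)
for _ in range(2):                          # two backward passes, no zero_grad() in between
    backward_into(write_buf, step_grad, 'write')
    backward_into(add_buf, step_grad, 'add')
print(write_buf, add_buf)                   # [1. 2.] [2. 4.]
```

With 'add', the buffer keeps growing until zero_grad() resets it. This is why that option requires clearing the buffer before each iteration.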
grad_req

{'write', 'add', 'null'} – This can be set before or after initialization. Setting grad_req to 'null' with x.grad_req = 'null' saves memory and computation when you don't need the gradient w.r.t. x.

lr_mult

float – Local learning rate multiplier for this Parameter. The actual learning rate is calculated with learning_rate * lr_mult. You can set it with param.lr_mult = 2.0

wd_mult

float – Local weight decay multiplier for this Parameter.
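The following worked example shows how the two multipliers scale a single update. It assumes plain SGD without momentum, with weight decay entering as an L2 term added to the gradient. sgd_step is a hypothetical helper, not the MXNet optimizer. Other optimizers apply the multipliers inside their own update rules.

```python
def sgd_step(weight, grad, learning_rate, wd, lr_mult=1.0, wd_mult=1.0):
    """One plain-SGD update with per-parameter multipliers (sketch)."""
    lr = learning_rate * lr_mult        # effective learning rate
    decay = wd * wd_mult                # effective weight-decay coefficient
    return weight - lr * (grad + decay * weight)

# lr = 0.1 * 2.0 = 0.2; grad + decay * w = 0.5 + 0.01 * 1.0 = 0.51
# new weight = 1.0 - 0.2 * 0.51, approximately 0.898
new_w = sgd_step(1.0, 0.5, learning_rate=0.1, wd=0.01, lr_mult=2.0)
print(new_w)
```

Setting wd_mult=0.0 switches weight decay off for one parameter (biases are a common choice) while leaving the global wd unchanged.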

cast(dtype)

Cast data and gradient of this Parameter to a new data type.

Parameters: dtype (str or numpy.dtype) – The new data type.
data(ctx=None)

Returns a copy of this parameter on one context. Must have been initialized on this context before.

Parameters: ctx (Context) – Desired context.
Returns: The parameter's data on ctx.
Return type: NDArray
grad(ctx=None)

Returns a gradient buffer for this parameter on one context.

Parameters: ctx (Context) – Desired context.
initialize(init=None, ctx=None, default_init=initializer.Uniform(), force_reinit=False)

Initializes parameter and gradient arrays. Only used for NDArray API.

Parameters:
  • init (Initializer) – The initializer to use. Overrides Parameter.init() and default_init.
  • ctx (Context or list of Context, defaults to context.current_context().) –

    Initialize Parameter on given context. If ctx is a list of Context, a copy will be made for each context.

    Note

    Copies are independent arrays. User is responsible for keeping their values consistent when updating. Normally gluon.Trainer does this for you.

  • default_init (Initializer) – Default initializer is used when both init() and Parameter.init() are None.
  • force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.
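The precedence rules for init, Parameter.init and default_init can be summarized in a few lines. Initializers are represented by plain strings, and resolve_initializer is a hypothetical helper. It encodes the documented order: an explicit init first, then Parameter.init, then default_init.

```python
def resolve_initializer(init=None, param_init=None, default_init='Uniform'):
    """Pick the initializer Parameter.initialize would use (sketch).

    Order: explicit init, then Parameter.init, then default_init.
    """
    for candidate in (init, param_init, default_init):
        if candidate is not None:
            return candidate

print(resolve_initializer(param_init='Xavier'))                 # Xavier
print(resolve_initializer(init='Zero', param_init='Xavier'))    # Zero
print(resolve_initializer())                                     # Uniform
```

As we understand it, Block.initialize and ParameterDict.initialize pass their own init as default_init. This is why, at that level, a Parameter's own init takes precedence over the global initializer.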

Examples

>>> weight = mx.gluon.Parameter('weight', shape=(2, 2))
>>> weight.initialize(ctx=mx.cpu(0))
>>> weight.data()
[[-0.01068833  0.01729892]
 [ 0.02042518 -0.01618656]]

>>> weight.grad()
[[ 0.  0.]
 [ 0.  0.]]

>>> weight.initialize(ctx=[mx.gpu(0), mx.gpu(1)], force_reinit=True)
>>> weight.data(mx.gpu(0))
[[-0.00873779 -0.02834515]
 [ 0.05484822 -0.06206018]]

>>> weight.data(mx.gpu(1))
[[-0.00873779 -0.02834515]
 [ 0.05484822 -0.06206018]]

list_ctx()

Returns a list of contexts this parameter is initialized on.

list_data()

Returns copies of this parameter on all contexts, in the same order as creation.

list_grad()

Returns gradient buffers on all contexts, in the same order as values().

reset_ctx(ctx)

Re-assign Parameter to other contexts.

Parameters: ctx (Context or list of Context, default context.current_context().) – Assign Parameter to given context. If ctx is a list of Context, a copy will be made for each context.
set_data(data)

Sets this parameter’s value on all contexts.

var()

Returns a symbol representing this parameter.

zero_grad()

Sets gradient buffer on all contexts to 0. No action is taken if parameter is uninitialized or doesn’t require gradient.

class mxnet.gluon.ParameterDict(prefix='', shared=None)

A dictionary managing a set of parameters.

Parameters:
  • prefix (str, default '') – The prefix to be prepended to all Parameters’ names created by this dict.
  • shared (ParameterDict or None) – If not None, when this dict's get() method creates a new parameter, it will first try to retrieve it from the shared dict. Usually used for sharing parameters with another Block.
get(name, **kwargs)

Retrieves a Parameter with name self.prefix+name. If not found, get() will first try to retrieve it from the shared dict. If still not found, get() will create a new Parameter with the given keyword arguments and insert it into self.

Parameters:
  • name (str) – Name of the desired Parameter. It will be prepended with this dictionary’s prefix.
  • **kwargs (dict) – The rest of the keyword arguments for the created Parameter.
Returns:

The created or retrieved Parameter.

Return type:

Parameter
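The lookup order of get() (own dict, then the shared dict, then create) is what makes weight sharing work. ToyParameterDict below is a hypothetical stand-in that stores plain dicts instead of Parameters. It omits the checks the real method may perform on keyword arguments for an existing Parameter.

```python
class ToyParameterDict:
    """Hypothetical stand-in for ParameterDict's get() lookup order."""

    def __init__(self, prefix='', shared=None):
        self.prefix = prefix
        self.shared = shared
        self._params = {}

    def get(self, name, **kwargs):
        full_name = self.prefix + name
        # 1. Look in this dict first.
        param = self._params.get(full_name)
        if param is None and self.shared is not None:
            # 2. Fall back to the shared dict.
            param = self.shared._params.get(full_name)
        if param is None:
            # 3. Create a new "Parameter" (a plain dict here) from the kwargs.
            param = dict(name=full_name, **kwargs)
        self._params[full_name] = param
        return param

owner = ToyParameterDict('dense0_')
w = owner.get('weight', shape=(20, 10))
borrower = ToyParameterDict('dense0_', shared=owner)
print(borrower.get('weight') is w)   # True: the same object, so the weight is shared
```

To our understanding, this is how dense1 = nn.Dense(20, params=dense0.collect_params()) ends up reusing dense0's weights. The new Block's dict takes the shared dict's prefix, so its lookups resolve to dense0's Parameters.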

get_constant(name, value=None)

Retrieves a Constant with name self.prefix+name. If not found, get_constant() will first try to retrieve it from the shared dict. If still not found, it will create a new Constant with the given value and insert it into self.

Parameters:
  • name (str) – Name of the desired Constant. It will be prepended with this dictionary’s prefix.
  • value (array-like) – Initial value of constant.
Returns:

The created or retrieved Constant.

Return type:

Constant

initialize(init=initializer.Uniform(), ctx=None, verbose=False, force_reinit=False)

Initializes all Parameters managed by this dictionary to be used for NDArray API. It has no effect when using Symbol API.

Parameters:
  • init (Initializer) – Global default Initializer to be used when Parameter.init() is None. Otherwise, Parameter.init() takes precedence.
  • ctx (Context or list of Context) – Keeps a copy of Parameters on one or many context(s).
  • verbose (bool, default False) – Whether to verbosely print out details on initialization.
  • force_reinit (bool, default False) – Whether to force re-initialization if parameter is already initialized.
load(filename, ctx=None, allow_missing=False, ignore_extra=False, restore_prefix='')

Load parameters from file.

Parameters:
  • filename (str) – Path to parameter file.
  • ctx (Context or list of Context) – Context(s) to initialize loaded parameters on.
  • allow_missing (bool, default False) – Whether to silently skip loading parameters not present in the file.
  • ignore_extra (bool, default False) – Whether to silently ignore parameters from the file that are not present in this ParameterDict.
  • restore_prefix (str, default '') – Prefix to prepend to the names of stored parameters before loading.
prefix

Prefix of this dict. It will be prepended to the names of Parameters created with get().

reset_ctx(ctx)

Re-assign all Parameters to other contexts.

Parameters: ctx (Context or list of Context, default context.current_context().) – Assign Parameter to given context. If ctx is a list of Context, a copy will be made for each context.
save(filename, strip_prefix='')

Save parameters to file.

Parameters:
  • filename (str) – Path to parameter file.
  • strip_prefix (str, default '') – Strip prefix from parameter names before saving.
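The strip_prefix option of save and the restore_prefix option of load are simple name mappings. The plain-Python sketch below uses ints in place of NDArrays. strip_names and restore_names are hypothetical helpers. The sketch also assumes that save rejects a parameter whose name does not start with strip_prefix, and it raises ValueError in that case.

```python
def strip_names(params, strip_prefix=''):
    """Name mapping applied by save(strip_prefix=...) (sketch)."""
    out = {}
    for name, value in params.items():
        if not name.startswith(strip_prefix):
            raise ValueError('%s does not start with %r' % (name, strip_prefix))
        out[name[len(strip_prefix):]] = value
    return out

def restore_names(saved, restore_prefix=''):
    """Name mapping applied by load(restore_prefix=...) (sketch)."""
    return {restore_prefix + name: value for name, value in saved.items()}

params = {'net0_dense0_weight': 1, 'net0_dense0_bias': 2}
saved = strip_names(params, 'net0_')
print(saved)                          # {'dense0_weight': 1, 'dense0_bias': 2}
print(restore_names(saved, 'net1_'))  # {'net1_dense0_weight': 1, 'net1_dense0_bias': 2}
```

You can strip one prefix on save and restore a different one on load to move weights between two structurally identical Blocks whose prefixes differ.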
setattr(name, value)

Set an attribute to a new value for all Parameters.

For example, set grad_req to 'null' if you don't need gradients w.r.t. a model's Parameters:

model.collect_params().setattr('grad_req', 'null')

or change the learning rate multiplier:

model.collect_params().setattr('lr_mult', 0.5)
Parameters:
  • name (str) – Name of the attribute.
  • value (valid type for attribute name) – The new value for the attribute.
update(other)

Copies all Parameters in other to self.

zero_grad()

Sets all Parameters’ gradient buffer to 0.

class mxnet.gluon.SymbolBlock(outputs, inputs, params=None)

Construct block from symbol. This is useful for using pre-trained models as feature extractors. For example, you may want to extract the output from fc2 layer in AlexNet.

Parameters:
  • outputs (Symbol or list of Symbol) – The desired output for SymbolBlock.
  • inputs (Symbol or list of Symbol) – The Variables in output’s argument that should be used as inputs.
  • params (ParameterDict) – Parameter dictionary for arguments and auxiliary states of outputs that are not inputs.

Examples

>>> # To extract the feature from fc1 and fc2 layers of AlexNet:
>>> alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(),
...                                          prefix='model_')
>>> inputs = mx.sym.var('data')
>>> out = alexnet(inputs)
>>> internals = out.get_internals()
>>> print(internals.list_outputs())
['data', ..., 'model_dense0_relu_fwd_output', ..., 'model_dense1_relu_fwd_output', ...]
>>> outputs = [internals['model_dense0_relu_fwd_output'],
...            internals['model_dense1_relu_fwd_output']]
>>> # Create SymbolBlock that shares parameters with alexnet
>>> feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
>>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
>>> print(feat_model(x))
static imports(symbol_file, input_names, param_file=None, ctx=None)

Import model previously saved by HybridBlock.export or Module.save_checkpoint as a SymbolBlock for use in Gluon.

Parameters:
  • symbol_file (str) – Path to symbol file.
  • input_names (list of str) – List of input variable names
  • param_file (str, optional) – Path to parameter file.
  • ctx (Context, default None) – The context to initialize SymbolBlock on.
Returns:

SymbolBlock loaded from symbol and parameter files.

Return type:

SymbolBlock

Examples

>>> net1 = gluon.model_zoo.vision.resnet18_v1(
...     prefix='resnet', pretrained=True)
>>> net1.hybridize()
>>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
>>> out1 = net1(x)
>>> net1.export('net1', epoch=1)
>>>
>>> net2 = gluon.SymbolBlock.imports(
...     'net1-symbol.json', ['data'], 'net1-0001.params')
>>> out2 = net2(x)
class mxnet.gluon.Trainer(params, optimizer, optimizer_params=None, kvstore='device', compression_params=None)

Applies an Optimizer on a set of Parameters. Trainer should be used together with autograd.

Parameters:
  • params (ParameterDict) – The set of parameters to optimize.
  • optimizer (str or Optimizer) – The optimizer to use. See help on Optimizer for a list of available optimizers.
  • optimizer_params (dict) – Keyword arguments to be passed to the optimizer constructor. For example, {'learning_rate': 0.1}. All optimizers accept learning_rate, wd (weight decay), clip_gradient, and lr_scheduler. See each optimizer's constructor for a list of additional supported arguments.
  • kvstore (str or KVStore) – kvstore type for multi-GPU and distributed training. See help on mxnet.kvstore.create for more information.
  • compression_params (dict) – Specifies the type of gradient compression and additional arguments depending on the type of compression being used. For example, 2bit compression requires a threshold. Arguments would then be {'type': '2bit', 'threshold': 0.5}. See the mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
Properties:
  • learning_rate (float) – The current learning rate of the optimizer. Given an Optimizer object optimizer, its learning rate can be accessed as optimizer.learning_rate.
load_states(fname)

Loads trainer states (e.g. optimizer, momentum) from a file.

Parameters: fname (str) – Path to input states file.
save_states(fname)

Saves trainer states (e.g. optimizer, momentum) to a file.

Parameters: fname (str) – Path to output states file.
set_learning_rate(lr)

Sets a new learning rate of the optimizer.

Parameters: lr (float) – The new learning rate of the optimizer.
step(batch_size, ignore_stale_grad=False)

Makes one step of parameter update. Should be called after autograd.backward() and outside of the record() scope.

Parameters:
  • batch_size (int) – Batch size of data processed. Gradient will be normalized by 1/batch_size. Set this to 1 if you normalized loss manually with loss = mean(loss).
  • ignore_stale_grad (bool, optional, default=False) – If True, ignores Parameters with stale gradients (gradients that have not been updated by backward since the last step) and skips their update.
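The batch_size argument only rescales the gradient before the optimizer update. The sketch below shows the arithmetic with NumPy and plain SGD (no momentum, no weight decay). trainer_step is a hypothetical helper. We believe MXNet hands this rescaling to the optimizer rather than dividing the gradient arrays directly, but the effect on the update is the same.

```python
import numpy as np

def trainer_step(weight, summed_grad, batch_size, learning_rate):
    """Plain-SGD sketch of step(batch_size): rescale by 1/batch_size, then update."""
    grad = summed_grad / batch_size       # gradient normalized by 1/batch_size
    return weight - learning_rate * grad

# Four samples each contribute a gradient of 1.0, summed by backward().
summed = np.array([4.0])
w_a = trainer_step(np.array([1.0]), summed, batch_size=4, learning_rate=0.1)
# If the loss was already averaged with mean(loss), pass batch_size=1
# so the gradient is not divided twice.
w_b = trainer_step(np.array([1.0]), summed / 4, batch_size=1, learning_rate=0.1)
print(w_a, w_b)
```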
class mxnet.gluon.nn.Sequential(prefix=None, params=None)

Stacks Blocks sequentially.

Example:

from mxnet.gluon import nn

net = nn.Sequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
    net.add(nn.Dense(10, activation='relu'))
    net.add(nn.Dense(20))
add(*blocks)

Adds blocks on top of the stack.

hybridize(active=True, **kwargs)

Activates or deactivates HybridBlocks recursively. Has no effect on non-hybrid children.

Parameters:
  • active (bool, default True) – Whether to turn hybrid on or off.
  • **kwargs (string) – Additional flags for hybridized operator.
class mxnet.gluon.nn.HybridSequential(prefix=None, params=None)

Stacks HybridBlocks sequentially.

Example:

from mxnet.gluon import nn

net = nn.HybridSequential()
# use net's name_scope to give child Blocks appropriate names.
with net.name_scope():
    net.add(nn.Dense(10, activation='relu'))
    net.add(nn.Dense(20))
net.hybridize()
add(*blocks)

Adds blocks on top of the stack.

Parallelization utilities.

mxnet.gluon.utils.split_data(data, num_slice, batch_axis=0, even_split=True)

Splits an NDArray into num_slice slices along batch_axis. Usually used for data parallelism, where each slice is sent to one device (e.g. a GPU).

Parameters:
  • data (NDArray) – A batch of data.
  • num_slice (int) – Number of desired slices.
  • batch_axis (int, default 0) – The axis along which to slice.
  • even_split (bool, default True) – Whether to force all slices to have the same number of elements. If True, an error will be raised when num_slice does not evenly divide data.shape[batch_axis].
Returns:

Return value is a list even if num_slice is 1.

Return type:

list of NDArray
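The slicing rule can be sketched in NumPy. When even_split=False, the slices have equal size and the last one absorbs any remainder. This is our reading of the behavior, not MXNet's code, and the function works on NumPy arrays rather than NDArrays. Reducing num_slice when there are fewer samples than slices is an assumption of the sketch.

```python
import numpy as np

def split_data(data, num_slice, batch_axis=0, even_split=True):
    """NumPy sketch of the split_data semantics described above."""
    size = data.shape[batch_axis]
    if even_split and size % num_slice != 0:
        raise ValueError('%d samples cannot be split evenly into %d slices'
                         % (size, num_slice))
    if size < num_slice:
        # Assumption: with even_split=False, never return empty slices.
        num_slice = size
    step = size // num_slice
    slices = []
    for i in range(num_slice):
        start = i * step
        # The last slice takes whatever is left over.
        stop = (i + 1) * step if i < num_slice - 1 else size
        index = [slice(None)] * data.ndim
        index[batch_axis] = slice(start, stop)
        slices.append(data[tuple(index)])
    return slices  # always a list, even when num_slice == 1

batch = np.arange(20).reshape(10, 2)
print([s.shape for s in split_data(batch, 3, even_split=False)])  # [(3, 2), (3, 2), (4, 2)]
```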

mxnet.gluon.utils.split_and_load(data, ctx_list, batch_axis=0, even_split=True)

Splits an NDArray into len(ctx_list) slices along batch_axis and loads each slice to one context in ctx_list.

Parameters:
  • data (NDArray) – A batch of data.
  • ctx_list (list of Context) – A list of Contexts.
  • batch_axis (int, default 0) – The axis along which to slice.
  • even_split (bool, default True) – Whether to force all slices to have the same number of elements.
Returns:

Each corresponds to a context in ctx_list.

Return type:

list of NDArray
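In a data-parallel loop, the practical question is which samples land on which device. The plain-Python sketch below computes only the index ranges. It assumes even_split=False and the same rule as split_data (equal slices, with the last one taking the remainder). The device names are strings standing in for Context objects, and split_and_load_plan is a hypothetical helper.

```python
def split_and_load_plan(batch_size, ctx_list):
    """Pair each context with the [start, stop) batch range it would receive (sketch)."""
    num_slice = min(len(ctx_list), batch_size)
    step = batch_size // num_slice
    plan = []
    for i, ctx in enumerate(ctx_list[:num_slice]):
        # The last context takes any remainder.
        stop = (i + 1) * step if i < num_slice - 1 else batch_size
        plan.append((ctx, i * step, stop))
    return plan

print(split_and_load_plan(10, ['gpu(0)', 'gpu(1)', 'gpu(2)', 'gpu(3)']))
# [('gpu(0)', 0, 2), ('gpu(1)', 2, 4), ('gpu(2)', 4, 6), ('gpu(3)', 6, 10)]
```

With the default even_split=True, this batch of 10 would instead be rejected, since 4 does not divide 10.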

mxnet.gluon.utils.clip_global_norm(arrays, max_norm)

Rescales NDArrays so that their global 2-norm (the square root of the sum of their squared 2-norms) is smaller than max_norm.
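The global-norm computation and the rescaling can be sketched in NumPy. The small epsilon in the denominator and returning the pre-clipping norm reflect our understanding of the MXNet function and should be treated as assumptions. The sketch modifies the arrays in place.

```python
import numpy as np

def clip_global_norm(arrays, max_norm):
    """NumPy sketch: rescale arrays in place so their global 2-norm is <= max_norm."""
    # Global norm: square root of the sum of every array's squared 2-norm.
    total_norm = np.sqrt(sum(float(np.sum(a * a)) for a in arrays))
    # The epsilon avoids division by zero; min(..., 1.0) never enlarges the arrays.
    scale = min(max_norm / (total_norm + 1e-8), 1.0)
    for a in arrays:
        a *= scale
    return total_norm

grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]
norm_before = clip_global_norm(grads, max_norm=1.0)
print(norm_before)                                    # 5.0
print(np.sqrt(sum(np.sum(g * g) for g in grads)))    # just under 1.0
```

All arrays share one scale factor, so the relative direction of the combined gradient is preserved. Only its length changes.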

mxnet.gluon.utils.check_sha1(filename, sha1_hash)

Check whether the sha1 hash of the file content matches the expected hash.

Parameters:
  • filename (str) – Path to the file.
  • sha1_hash (str) – Expected sha1 hash in hexadecimal digits.
Returns:

Whether the file content matches the expected hash.

Return type:

bool
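The same check can be written with the standard library's hashlib. The sketch below is not MXNet's implementation. It reads the file in chunks so that large downloads never have to fit in memory, and the 1 MiB chunk size is an arbitrary choice.

```python
import hashlib
import os
import tempfile

def check_sha1(filename, sha1_hash):
    """Return True if the file's SHA-1 hex digest equals sha1_hash (sketch)."""
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read 1 MiB at a time so large files never have to fit in memory.
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha1.update(chunk)
    # hexdigest() is lowercase, so the expected hash should be lowercase too.
    return sha1.hexdigest() == sha1_hash

# Usage: write a small file, then verify it against its known digest.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b'gluon')
print(check_sha1(tmp.name, hashlib.sha1(b'gluon').hexdigest()))  # True
os.remove(tmp.name)
```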

mxnet.gluon.utils.download(url, path=None, overwrite=False, sha1_hash=None)

Download a given URL.

Parameters:
  • url (str) – URL to download
  • path (str, optional) – Destination path to store downloaded file. By default stores to the current directory with same name as in url.
  • overwrite (bool, optional) – Whether to overwrite the destination file if it already exists.
  • sha1_hash (str, optional) – Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified but doesn’t match.
Returns:

The file path of the downloaded file.

Return type:

str