gluon.contrib

This document lists the contrib APIs in Gluon:

mxnet.gluon.contrib

Contrib neural network module.

The Gluon Contrib API, defined in the gluon.contrib package, provides experimental APIs for new features. It gives the community a place to try out new features so that feature contributors can receive feedback.

Warning

This package contains experimental APIs that may change in the near future.

In the rest of this document, we list routines provided by the gluon.contrib package.

Vision Data

data.vision.create_image_augment

Creates an augmenter block.

data.vision.ImageDataLoader

Image data loader with a large number of augmentation choices.

data.vision.ImageBboxDataLoader

Image iterator with a large number of augmentation choices for detection.

data.vision.ImageBboxRandomFlipLeftRight

Randomly flip the input image and bbox left to right with probability p (0.5 by default).
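The coordinate arithmetic behind a joint left-right flip can be sketched in pure Python. This is an illustration, not the MXNet transform: it assumes boxes are (x_min, y_min, x_max, y_max) tuples in pixel coordinates and represents the image as a list of pixel rows, and both function names are hypothetical.

```python
import random


def flip_bboxes_left_right(bboxes, width):
    """Mirror (x_min, y_min, x_max, y_max) boxes horizontally in an image of the given width."""
    # The new left edge is the mirror of the old right edge, and vice versa;
    # y coordinates are unchanged by a horizontal flip.
    return [(width - x_max, y_min, width - x_min, y_max)
            for x_min, y_min, x_max, y_max in bboxes]


def random_flip_left_right(image, bboxes, p=0.5, rng=random):
    """Flip an image (a list of pixel rows) and its boxes together with probability p.

    Returns (image, bboxes, flipped).
    """
    width = len(image[0])
    if rng.random() < p:
        flipped_image = [row[::-1] for row in image]
        return flipped_image, flip_bboxes_left_right(bboxes, width), True
    return image, list(bboxes), False


img, boxes, flipped = random_flip_left_right([[1, 2, 3, 4]], [(0, 0, 1, 1)], p=1.0)
print(img, boxes, flipped)  # [[4, 3, 2, 1]] [(3, 0, 4, 1)] True
```

Flipping the image and its boxes in the same call is the point of a joint transform: flipping only one of them would misalign the labels with the pixels.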

data.vision.ImageBboxCrop

Crops the image src and bbox to the given crop.

data.vision.ImageBboxRandomCropWithConstraints

Crop an image randomly with bounding box constraints.

data.vision.ImageBboxResize

Apply resize to image and bounding boxes.
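When an image is resized, its boxes have to be rescaled by the same per-axis factors. The sketch below shows that arithmetic in pure Python; it assumes (x_min, y_min, x_max, y_max) boxes and (width, height) sizes, and resize_bboxes is a hypothetical name rather than part of the ImageBboxResize API.

```python
def resize_bboxes(bboxes, in_size, out_size):
    """Rescale (x_min, y_min, x_max, y_max) boxes from in_size to out_size, both (width, height)."""
    x_scale = out_size[0] / in_size[0]
    y_scale = out_size[1] / in_size[1]
    # x coordinates follow the width ratio, y coordinates follow the height ratio.
    return [(x0 * x_scale, y0 * y_scale, x1 * x_scale, y1 * y_scale)
            for x0, y0, x1, y1 in bboxes]


# Halve the width and double the height of a 100x200 image.
print(resize_bboxes([(10, 20, 50, 80)], in_size=(100, 200), out_size=(50, 400)))
# [(5.0, 40.0, 25.0, 160.0)]
```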

Estimator

Estimator

Estimator Class for easy model training

Event Handler

StoppingHandler

Stop conditions for training: stop training once the maximum number of batches or epochs is reached.

MetricHandler

Metric handler that updates metric values at batch end.

ValidationHandler

Validation handler that evaluates the model on the validation dataset.

LoggingHandler

Basic Logging Handler that applies to every Gluon estimator by default.

CheckpointHandler

Save the model after a user-defined period.

EarlyStoppingHandler

Stop training early if the monitored value is not improving.

API Reference


Gluon Estimator Module

class BatchProcessor

Bases: object

BatchProcessor class for plug-and-play fit_batch and evaluate_batch.

During training or validation, data are divided into minibatches for processing. This class provides hooks for training or validating on a single minibatch of data. Users can customize the fit_batch() and evaluate_batch() methods by inheriting from this class and overriding them.

A BatchProcessor can be passed to Estimator (through its batch_processor argument) to replace the fit_batch() and evaluate_batch() methods of the base Estimator class.
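The plug-and-play idea can be illustrated without MXNet. In the pure-Python sketch below, ToyBatchProcessor, ClippedLossProcessor, and ToyEstimator are hypothetical stand-ins; only the fit_batch(estimator, train_batch, batch_axis=0) signature and the (data, label, pred, loss) return value follow this document. The estimator delegates each minibatch to whatever processor it was given, so overriding one method changes training behavior without touching the training loop.

```python
class ToyBatchProcessor:
    """Hypothetical default processor: predict with estimator.net, absolute-error loss."""

    def fit_batch(self, estimator, train_batch, batch_axis=0):
        data, label = train_batch
        pred = [estimator.net(x) for x in data]
        loss = [abs(p - y) for p, y in zip(pred, label)]
        return data, label, pred, loss


class ClippedLossProcessor(ToyBatchProcessor):
    """Override only fit_batch: reuse the default step, then clip each loss at 1.0."""

    def fit_batch(self, estimator, train_batch, batch_axis=0):
        data, label, pred, loss = super().fit_batch(estimator, train_batch, batch_axis)
        return data, label, pred, [min(value, 1.0) for value in loss]


class ToyEstimator:
    """Hypothetical estimator whose training loop delegates every batch to a processor."""

    def __init__(self, net, batch_processor=None):
        self.net = net
        self.batch_processor = batch_processor or ToyBatchProcessor()

    def fit(self, train_data):
        losses = []
        for batch in train_data:
            _, _, _, loss = self.batch_processor.fit_batch(self, batch)
            losses.extend(loss)
        return losses


batches = [([1, 2], [1, 1])]  # one batch: data [1, 2], labels [1, 1]
print(ToyEstimator(lambda x: 2 * x).fit(batches))                          # [1, 3]
print(ToyEstimator(lambda x: 2 * x, ClippedLossProcessor()).fit(batches))  # [1, 1.0]
```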

evaluate_batch(estimator, val_batch, batch_axis=0)

Evaluate the estimator model on a batch of validation data.

Parameters
  • estimator (Estimator) – Reference to the estimator

  • val_batch (tuple) – Data and label of a batch from the validation data loader.

  • batch_axis (int, default 0) – Batch axis to split the validation data into devices.

fit_batch(estimator, train_batch, batch_axis=0)

Trains the estimator model on a batch of training data.

Parameters
  • estimator (Estimator) – Reference to the estimator

  • train_batch (tuple) – Data and label of a batch from the training data loader.

  • batch_axis (int, default 0) – Batch axis to split the training data into devices.

Returns

  • data (List of NDArray) – Sharded data from the batch. Data is sharded with gluon.split_and_load.

  • label (List of NDArray) – Sharded label from the batch. Labels are sharded with gluon.split_and_load.

  • pred (List of NDArray) – Prediction on each of the sharded inputs.

  • loss (List of NDArray) – Loss on each of the sharded inputs.
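Sharding a batch across devices amounts to slicing it along batch_axis into one piece per device; gluon.split_and_load additionally copies each slice to its device. The pure-Python sketch below covers only the slicing for batch_axis=0, and split_batch is a hypothetical helper. It assumes that a remainder which does not divide evenly goes to the last shard; the real function's handling of uneven batches is governed by its own arguments (such as even_split), so check its documentation.

```python
def split_batch(batch, num_devices):
    """Split a batch (a list of samples along axis 0) into num_devices contiguous shards.

    Any remainder that does not divide evenly is appended to the last shard.
    """
    if num_devices < 1 or len(batch) < num_devices:
        raise ValueError("need at least one device and one sample per device")
    step = len(batch) // num_devices
    shards = [batch[i * step:(i + 1) * step] for i in range(num_devices - 1)]
    shards.append(batch[(num_devices - 1) * step:])
    return shards


print(split_batch(list(range(10)), 3))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```

Because each shard is processed separately, the data, label, pred, and loss values returned by fit_batch are lists with one entry per shard.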

class CheckpointHandler(model_dir, model_prefix='model', monitor=None, verbose=0, save_best=False, mode='auto', epoch_period=1, batch_period=None, max_checkpoints=5, resume_from_checkpoint=False)

Bases: mxnet.gluon.contrib.estimator.event_handler.TrainBegin, mxnet.gluon.contrib.estimator.event_handler.BatchEnd, mxnet.gluon.contrib.estimator.event_handler.EpochEnd

Save the model after a user-defined period.

CheckpointHandler saves the network architecture after the first batch if the model can be fully hybridized, and saves the model parameters and trainer states after every user-defined period. By default, it saves every epoch.

Parameters
  • model_dir (str) – File directory to save all the model related files including model architecture, model parameters, and trainer states.

  • model_prefix (str, default 'model') – Prefix to add to all checkpoint file names.

  • monitor (EvalMetric, default None) – The metric to monitor to determine whether the model has improved.

  • verbose (int, default 0) – Verbosity mode; 1 means inform the user every time a checkpoint is saved.

  • save_best (bool, default False) – If True, monitor must not be None, and CheckpointHandler will save the model parameters and trainer states with the best monitored value.

  • mode (str, default 'auto') – One of {auto, min, max}. If save_best=True, this is the comparison used to determine whether the monitored value has improved. In 'auto' mode, CheckpointHandler tries to infer min or max from the monitored metric name.

  • epoch_period (int, default 1) – Epoch intervals between saving the network. By default, checkpoints are saved every epoch.

  • batch_period (int, default None) – Batch intervals between saving the network. By default, checkpoints are not saved based on the number of batches.

  • max_checkpoints (int, default 5) – Maximum number of checkpoint files to keep in model_dir; older checkpoints are removed. The best checkpoint file does not count toward this limit.

  • resume_from_checkpoint (bool, default False) – Whether to resume training from a checkpoint in model_dir. If True and checkpoints are found, CheckpointHandler loads the net parameters and trainer states and trains the remaining epochs and batches.
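The max_checkpoints bookkeeping behaves like a first-in, first-out queue. The sketch below models only that rotation in pure Python: CheckpointRotation and its file-name pattern are hypothetical, no files are written, and the separately kept best checkpoint (which does not count toward the limit) is not modeled.

```python
from collections import deque


class CheckpointRotation:
    """Hypothetical model of max_checkpoints: keep only the newest N checkpoint names."""

    def __init__(self, model_prefix="model", max_checkpoints=5):
        self.model_prefix = model_prefix
        self.max_checkpoints = max_checkpoints
        self.saved = deque()

    def save(self, epoch, batch):
        """Record a new checkpoint and return the names that should now be deleted."""
        # Hypothetical naming scheme, used only to make the names distinguishable.
        self.saved.append(f"{self.model_prefix}-epoch{epoch}batch{batch}.params")
        removed = []
        while len(self.saved) > self.max_checkpoints:
            removed.append(self.saved.popleft())  # the oldest checkpoint goes first
        return removed


rotation = CheckpointRotation(max_checkpoints=2)
for epoch in range(3):
    print(rotation.save(epoch, batch=(epoch + 1) * 100))
# []
# []
# ['model-epoch0batch100.params']
```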

class EarlyStoppingHandler(monitor, min_delta=0, patience=0, mode='auto', baseline=None)

Bases: mxnet.gluon.contrib.estimator.event_handler.TrainBegin, mxnet.gluon.contrib.estimator.event_handler.EpochEnd, mxnet.gluon.contrib.estimator.event_handler.TrainEnd

Stop training early if the monitored value is not improving.

Parameters
  • monitor (EvalMetric) – The metric to monitor, and stop training if this metric does not improve.

  • min_delta (float, default 0) – Minimal change in monitored value to be considered as an improvement.

  • patience (int, default 0) – Number of epochs to wait for improvement before terminating training.

  • mode (str, default 'auto') – One of {auto, min, max}: the comparison used to determine whether the monitored value has improved. In 'auto' mode, EarlyStoppingHandler tries to infer min or max from the monitored metric name.

  • baseline (float, default None) – Baseline value to compare the monitored value with.
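The interaction of min_delta, patience, mode, and baseline can be sketched in pure Python. EarlyStopper is a hypothetical class, not EarlyStoppingHandler; in particular, the 'auto' rule used here (names containing 'acc' or 'f1' are maximized, all others minimized) is an assumed heuristic, and the handler's exact rule may differ.

```python
import math


class EarlyStopper:
    """Hypothetical early-stopping bookkeeping, modeled on the parameters listed above."""

    def __init__(self, metric_name, min_delta=0.0, patience=0, mode="auto", baseline=None):
        if mode == "auto":
            # Assumed heuristic: accuracy-like names are maximized, all others minimized.
            name = metric_name.lower()
            mode = "max" if ("acc" in name or "f1" in name) else "min"
        self.sign = 1 if mode == "max" else -1  # compare in "higher is better" space
        self.min_delta = abs(min_delta)
        self.patience = patience
        if baseline is not None:
            self.best = baseline  # the first improvement must beat the baseline
        else:
            self.best = -math.inf if mode == "max" else math.inf
        self.wait = 0  # epochs since the last improvement

    def update(self, value):
        """Record one epoch's monitored value; return True if training should stop."""
        if self.sign * (value - self.best) > self.min_delta:
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopper("val_loss", patience=1)
print([stopper.update(v) for v in [1.0, 0.9, 0.95]])  # [False, False, True]
```

A change smaller than min_delta counts as no improvement, which is why noisy plateaus eventually exhaust the patience counter.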

class Estimator(net, loss, train_metrics=None, val_metrics=None, initializer=None, trainer=None, device=None, val_net=None, val_loss=None, batch_processor=None)

Bases: object

Estimator class for easy model training.

Estimator can be used to facilitate the training and validation process.

Parameters
  • net (gluon.Block) – The model used for training.

  • loss (gluon.loss.Loss) – Loss (objective) function to calculate during training.

  • train_metrics (EvalMetric or list of EvalMetric) – Training metrics for evaluating models on training dataset.

  • val_metrics (EvalMetric or list of EvalMetric) – Validation metrics for evaluating models on validation dataset.

  • initializer (Initializer) – Initializer to initialize the network.

  • trainer (Trainer) – Trainer to apply optimizer on network parameters.

  • device (Device or list of Device) – Device(s) to run the training on.

  • val_net (gluon.Block) –

    The model used for validation. The validation model does not necessarily belong to the same model class as the training model. But the two models typically share the same architecture. Therefore the validation model can reuse parameters of the training model.

    The following example constructs a val_net that shares the network parameters of the training net:

    >>> net = _get_train_network()
    >>> val_net = _get_test_network()
    >>> val_net.share_parameters(net.collect_params())
    >>> net.initialize(device=device)
    >>> est = Estimator(net, loss, val_net=val_net)
    

    Proper namespace matching is required for weight sharing between two networks. Most networks inheriting from Block can share their parameters correctly. An exception is Sequential networks, for which the Block scope must be specified for correct weight sharing. For details on naming in the MXNet Gluon API, refer to the naming tutorial (https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html).

  • val_loss (gluon.loss.Loss) – Loss (objective) function to calculate during validation. If val_loss is None, the same loss function as self.loss is used.

  • batch_processor (BatchProcessor) – BatchProcessor that provides customized fit_batch() and evaluate_batch() methods.

evaluate(val_data, batch_axis=0, event_handlers=None)

Evaluate model on validation data.

This function calls evaluate_batch() on each of the batches from the validation data loader. Thus, for custom use cases, it’s possible to inherit the estimator class and override evaluate_batch().

Parameters
  • val_data (DataLoader) – Validation data loader with data and labels.

  • batch_axis (int, default 0) – Batch axis to split the validation data into devices.

  • event_handlers (EventHandler or list of EventHandler) – List of EventHandlers to apply during validation. Besides event handlers specified here, a default MetricHandler and a LoggingHandler will be added if not specified explicitly.

fit(train_data, val_data=None, epochs=None, event_handlers=None, batches=None, batch_axis=0)

Trains the model with a given DataLoader for a specified number of epochs or batches. The batch size is inferred from the data loader’s batch_size.

This function calls fit_batch() on each of the batches from the training data loader. Thus, for custom use cases, it’s possible to inherit the estimator class and override fit_batch().

Parameters
  • train_data (DataLoader) – Training data loader with data and labels.

  • val_data (DataLoader, default None) – Validation data loader with data and labels.

  • epochs (int, default None) – Number of epochs to iterate on the training data. Specify exactly one type of iteration (epochs or batches).

  • event_handlers (EventHandler or list of EventHandler) – List of EventHandlers to apply during training. Besides the event handlers specified here, a StoppingHandler, LoggingHandler and MetricHandler will be added by default if not yet specified manually. If validation data is provided, a ValidationHandler is also added if not already specified.

  • batches (int, default None) – Number of batches to iterate on the training data. Specify exactly one type of iteration (epochs or batches).

  • batch_axis (int, default 0) – Batch axis to split the training data into devices.
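The rule that exactly one of epochs or batches must be given can be expressed as a small check. In the sketch below, resolve_iteration is hypothetical, and mapping the chosen value to StoppingHandler's max_epoch or max_batch argument is an assumption about how a default StoppingHandler might be configured, not a description of Estimator.fit's internals.

```python
def resolve_iteration(epochs=None, batches=None):
    """Hypothetical check: require exactly one of epochs or batches.

    Returns keyword arguments shaped like StoppingHandler(max_epoch=..., max_batch=...).
    """
    # Both None or both set are the two invalid combinations.
    if (epochs is None) == (batches is None):
        raise ValueError("Specify exactly one of epochs or batches.")
    return {"max_epoch": epochs, "max_batch": batches}


print(resolve_iteration(epochs=10))  # {'max_epoch': 10, 'max_batch': None}
```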

logger = None

logging.Logger object associated with the Estimator.

The logger is used for all logs generated by this estimator and its handlers. A new logging.Logger is created during Estimator construction and configured to write all logs with level logging.INFO or higher to sys.stdout.

You can modify the logging settings using the standard Python logging methods. For example, to save logs to a file in addition to printing them to stdout, you can attach a logging.FileHandler to the logger.

>>> import logging
>>> est = Estimator(net, loss)
>>> est.logger.addHandler(logging.FileHandler('estimator.log'))
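
Because est.logger is an ordinary logging.Logger, any standard-library handler or formatter can be attached to it. The sketch below reproduces the described configuration on a stand-alone logger so that it runs without MXNet; the logger name and the temporary log path are hypothetical choices.

```python
import logging
import os
import sys
import tempfile

# A stand-alone logger configured the way est.logger is described above:
# records at INFO or higher are written to sys.stdout.
logger = logging.getLogger("estimator_demo")
logger.setLevel(logging.INFO)
logger.propagate = False  # don't also pass records to the root logger
logger.addHandler(logging.StreamHandler(sys.stdout))

# Additionally save the same records to a file, with a timestamped format.
log_path = os.path.join(tempfile.mkdtemp(), "estimator.log")
file_handler = logging.FileHandler(log_path)
file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(file_handler)

logger.info("Training begin")  # printed to stdout and appended to the file
logger.debug("Not recorded")   # below INFO, so the logger drops it
```

With an Estimator, attach file_handler to est.logger instead of creating a new logger.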

class GradientUpdateHandler(priority=-2000)

Bases: mxnet.gluon.contrib.estimator.event_handler.BatchEnd

Gradient update handler that applies gradients to network weights.

GradientUpdateHandler takes a priority level. It updates the weight parameters at the end of each batch.

Parameters

priority (scalar, default -2000) – Priority level of the GradientUpdateHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the handler's priority.
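The ordering rule can be checked with a plain ascending sort. The sketch models handlers only as (name, priority) pairs and does not show how the Estimator dispatches events; the first three priorities are the defaults documented on this page, while CustomHandler and its value of 0 are hypothetical.

```python
import math

# (name, priority) pairs; CustomHandler's priority of 0 is an illustrative assumption.
handlers = [
    ("LoggingHandler", math.inf),
    ("CustomHandler", 0),
    ("MetricHandler", -1000),
    ("GradientUpdateHandler", -2000),
]

# Ascending sort: the lowest number has the highest priority and runs first.
ordered = [name for name, priority in sorted(handlers, key=lambda h: h[1])]
print(ordered)
# ['GradientUpdateHandler', 'MetricHandler', 'CustomHandler', 'LoggingHandler']
```

With these defaults, gradients are applied before metrics are updated at batch end, and logging runs last.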

class LoggingHandler(log_interval='epoch', metrics=None, priority=inf)

Bases: mxnet.gluon.contrib.estimator.event_handler.TrainBegin, mxnet.gluon.contrib.estimator.event_handler.TrainEnd, mxnet.gluon.contrib.estimator.event_handler.EpochBegin, mxnet.gluon.contrib.estimator.event_handler.EpochEnd, mxnet.gluon.contrib.estimator.event_handler.BatchBegin, mxnet.gluon.contrib.estimator.event_handler.BatchEnd

Basic Logging Handler that applies to every Gluon estimator by default.

LoggingHandler logs hyper-parameters, training statistics, and other useful information during training.

Parameters
  • log_interval (int or str, default 'epoch') – Logging interval during training. If 'epoch', metrics are displayed every epoch; if an integer k, metrics are displayed every k batches.

  • metrics (list of EvalMetrics) – Metrics to be logged at batch end, epoch end, and train end.

  • priority (scalar, default np.Inf) – Priority level of the LoggingHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the handler's priority.
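A minimal sketch of how the two log_interval modes could decide whether to log at batch end. should_log_batch is hypothetical, and counting batches from 0 (so that log_interval=3 logs after the 3rd, 6th, and 9th batches) is an assumption.

```python
def should_log_batch(log_interval, batch_index):
    """Hypothetical rule: should batch-end metrics be logged after this 0-based batch?"""
    if log_interval == "epoch":
        return False  # metrics are logged at epoch end only
    if isinstance(log_interval, int) and log_interval > 0:
        return (batch_index + 1) % log_interval == 0
    raise ValueError("log_interval must be 'epoch' or a positive integer")


print([b for b in range(10) if should_log_batch(3, b)])  # [2, 5, 8]
print(should_log_batch("epoch", 5))                       # False
```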

class MetricHandler(metrics, priority=-1000)

Bases: mxnet.gluon.contrib.estimator.event_handler.EpochBegin, mxnet.gluon.contrib.estimator.event_handler.BatchEnd

Metric handler that updates metric values at batch end.

MetricHandler takes model predictions and true labels and updates the metrics; it also updates the metric wrapper for loss with the loss values. Validation loss and metrics are handled by ValidationHandler.

Parameters
  • metrics (List of EvalMetrics) – Metrics to be updated at batch end.

  • priority (scalar, default -1000) – Priority level of the MetricHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the handler's priority.

class StoppingHandler(max_epoch=None, max_batch=None)

Bases: mxnet.gluon.contrib.estimator.event_handler.TrainBegin, mxnet.gluon.contrib.estimator.event_handler.BatchEnd, mxnet.gluon.contrib.estimator.event_handler.EpochEnd

Stop conditions for training: stop training once the maximum number of batches or epochs is reached.

Parameters
  • max_epoch (int, default None) – Maximum number of epochs to train.

  • max_batch (int, default None) – Maximum number of batches to train.
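StoppingHandler reacts to train begin, batch end, and epoch end, as its base classes show. The pure-Python sketch below mirrors that with three methods on a hypothetical StopCondition class; it is not the MXNet class, and counting with >= exactly as written here is an assumption.

```python
class StopCondition:
    """Hypothetical sketch: stop once max_epoch or max_batch is reached."""

    def __init__(self, max_epoch=None, max_batch=None):
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.train_begin()

    def train_begin(self):
        # Reset counters at the start of training.
        self.current_epoch = 0
        self.current_batch = 0
        self.stop_training = False

    def batch_end(self):
        self.current_batch += 1
        if self.max_batch is not None and self.current_batch >= self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self):
        self.current_epoch += 1
        if self.max_epoch is not None and self.current_epoch >= self.max_epoch:
            self.stop_training = True
        return self.stop_training


# A limit of 5 batches stops training partway through the third 2-batch epoch.
stop = StopCondition(max_batch=5)
batches_run = 0
for epoch in range(10):
    for _ in range(2):
        batches_run += 1
        if stop.batch_end():
            break
    if stop.stop_training or stop.epoch_end():
        break
print(batches_run, stop.current_epoch)  # 5 2
```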

class ValidationHandler(val_data, eval_fn, epoch_period=1, batch_period=None, priority=-1000, event_handlers=None)

Bases: mxnet.gluon.contrib.estimator.event_handler.TrainBegin, mxnet.gluon.contrib.estimator.event_handler.BatchEnd, mxnet.gluon.contrib.estimator.event_handler.EpochEnd

Validation handler that evaluates the model on the validation dataset.

ValidationHandler takes a validation dataset, an evaluation function, and how often to run validation. You can provide a custom evaluation function or use the one provided by Estimator.

Parameters
  • val_data (DataLoader) – Validation data set to run evaluation on.

  • eval_fn (function) – A function that defines how to run evaluation and calculate loss and metrics.

  • epoch_period (int, default 1) – How often to run validation at epoch end. By default, ValidationHandler validates every epoch.

  • batch_period (int, default None) – How often to run validation at batch end. By default, ValidationHandler does not validate at batch end.

  • priority (scalar, default -1000) – Priority level of the ValidationHandler. Handlers are sorted by priority in ascending order: the lower the number, the higher the handler's priority.

  • event_handlers (EventHandler or list of EventHandlers) – List of EventHandlers to apply during validation. This argument is used by self.eval_fn to process customized event handlers.
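How epoch_period and batch_period combine can be sketched with a hypothetical ToyValidationHandler that calls eval_fn(val_data) on schedule. The real handler works with the estimator's evaluation function and event handlers; here eval_fn is just a callable that records the counters at each call, and the way periods are counted is an assumption.

```python
class ToyValidationHandler:
    """Hypothetical stand-in: calls eval_fn(val_data) on an epoch and/or batch schedule."""

    def __init__(self, val_data, eval_fn, epoch_period=1, batch_period=None):
        self.val_data = val_data
        self.eval_fn = eval_fn
        self.epoch_period = epoch_period
        self.batch_period = batch_period
        self.num_epochs = 0
        self.num_batches = 0

    def batch_end(self):
        self.num_batches += 1
        if self.batch_period is not None and self.num_batches % self.batch_period == 0:
            self.eval_fn(self.val_data)

    def epoch_end(self):
        self.num_epochs += 1
        if self.epoch_period is not None and self.num_epochs % self.epoch_period == 0:
            self.eval_fn(self.val_data)


# Record (epochs, batches) whenever validation runs: 4 epochs of 2 batches each.
calls = []
handler = ToyValidationHandler(
    "val_loader",
    lambda data: calls.append((handler.num_epochs, handler.num_batches)),
    epoch_period=2,
    batch_period=3,
)
for epoch in range(4):
    for _ in range(2):
        handler.batch_end()
    handler.epoch_end()
print(calls)  # [(1, 3), (2, 4), (2, 6), (4, 8)]
```

The two schedules are independent: validation triggered by batch_period does not reset the epoch counter, and vice versa.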