# Model

The model API provides a convenient high-level interface for training and prediction with a network described using the symbolic API.

# MXNet.mx.AbstractModel (Type)

```julia
AbstractModel
```

The abstract supertype of all models in MXNet.jl.

# MXNet.mx.FeedForward (Type)

```julia
FeedForward
```

The feedforward model provides a convenient interface for training and prediction on feedforward architectures such as multi-layer perceptrons (MLPs) and ConvNets. There is no explicit handling of a time index, but it is relatively easy to implement an unrolled RNN / LSTM within this framework (TODO: add example). For models that handle sequential data explicitly, please use TODO...

# MXNet.mx.FeedForward (Method)

```julia
FeedForward(arch :: SymbolicNode, ctx)
```

Arguments:

• arch: the architecture of the network constructed using the symbolic API.
• ctx: the devices on which this model should do computation. It could be a single Context or a list of Context objects. In the latter case, data parallelization will be used for training. If no context is provided, the default context cpu() will be used.
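Accepting either one context or a list of them is a small normalization step. The sketch below illustrates it with a hypothetical `FakeContext` type standing in for `mx.Context` and a hypothetical `fake_cpu()` standing in for `cpu()`; it shows the idea only and is not MXNet.jl's constructor code.

```julia
# Hypothetical stand-in for mx.Context; illustration only.
struct FakeContext
    device_type::Symbol
    device_id::Int
end

# Hypothetical equivalent of the default cpu() context.
fake_cpu(id::Int = 0) = FakeContext(:cpu, id)

# Normalize the `ctx` argument so downstream code always sees a vector.
as_context_list(ctx::FakeContext) = [ctx]
as_context_list(ctxs::AbstractVector{FakeContext}) = collect(ctxs)
as_context_list(::Nothing) = [fake_cpu()]   # no context given: fall back to CPU

# More than one context means mini-batches get split across devices.
uses_data_parallelism(ctx) = length(as_context_list(ctx)) > 1

gpus = [FakeContext(:gpu, 0), FakeContext(:gpu, 1)]
uses_data_parallelism(gpus)        # true
uses_data_parallelism(fake_cpu())  # false
```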

# MXNet.mx.predict (Method)

```julia
predict(self, data; overwrite=true, callback=nothing)
```

Predict using an existing model. The model should already be initialized, trained, or loaded from a checkpoint. An overloaded method takes the callback as its first argument, which makes Julia's `do`-block syntax available:

```julia
predict(model, data) do batch_output
    # consume batch_output or write it to a file
end
```
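The callback-first overload is ordinary Julia multiple dispatch: a `do` block is passed as the first positional argument of the call. The sketch below mimics that pattern with a hypothetical `run_predict` that doubles each batch in place of a real forward pass. It is not MXNet.jl's `predict`; returning the per-batch outputs as a plain vector is a simplification chosen for illustration.

```julia
# Hypothetical keyword-callback method, standing in for predict(self, data; callback=...).
function run_predict(data; callback = nothing)
    outputs = Any[]
    for batch in data
        out = batch .* 2              # stand-in for a forward pass
        if callback === nothing
            push!(outputs, out)       # no callback: collect the outputs
        else
            callback(out)             # callback: hand each batch output over
        end
    end
    return callback === nothing ? outputs : nothing
end

# Callback-first method: this is what `run_predict(data) do out ... end` calls.
run_predict(callback::Function, data) = run_predict(data; callback = callback)

run_predict([[1, 2], [3]]) do out
    println(out)
end
```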


Arguments:

• self::FeedForward: the model.
• data::AbstractDataProvider: the data to perform prediction on.
• overwrite::Bool: an Executor is initialized the first time predict is called. The memory allocation of the Executor depends on the mini-batch size of the test data provider. If you call predict twice with data providers of the same batch size, the executor can potentially be re-used. So, if overwrite is false, the existing executor is re-used, and an error is raised if the batch size has changed. If overwrite is true (the default), a new Executor is created to replace the old one.
• verbosity::Integer: determines the verbosity of the printed messages. Higher numbers lead to more verbose printing. Acceptable values are 0 (do not print anything during prediction) and 1 (print allocation information during prediction).
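The `overwrite` rule amounts to caching the executor keyed on the batch size. The sketch below models only that decision with a hypothetical `PredictorCache` that records the batch size it was built for and how many times it was (re)built; no real executor or memory allocation is involved.

```julia
# Hypothetical cache: records the batch size the executor was built for.
mutable struct PredictorCache
    batch_size::Union{Nothing,Int}
    n_builds::Int
end
PredictorCache() = PredictorCache(nothing, 0)

function get_predictor!(cache::PredictorCache, batch_size::Int; overwrite::Bool = true)
    if cache.batch_size === nothing || overwrite
        # First call, or overwrite requested: (re)build for this batch size.
        cache.batch_size = batch_size
        cache.n_builds += 1
    elseif cache.batch_size != batch_size
        # Re-use requested but the allocated shapes no longer fit.
        error("batch size changed from $(cache.batch_size) to $batch_size; use overwrite=true")
    end
    return cache
end
```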

Note

Prediction is computationally much less costly than training, so the bottleneck is sometimes the I/O cost of copying mini-batches of data. Since convergence is not a concern in prediction, it is better to make the mini-batch size as large as possible (limited by your device memory) if prediction speed matters.

For the same reason, prediction currently uses only the first device, even if multiple devices were provided when constructing the model.

Note

If you perform further training after prediction, the weights are not automatically synchronized when overwrite is set to false and the old predictor is re-used. In this case, setting overwrite to true (the default) re-initializes the predictor the next time you call predict, which synchronizes the weights again.

# MXNet.mx._split_inputs (Method)

Get a split of batch_size into n_split pieces for data parallelization. Returns a vector of length n_split, with each entry a UnitRange{Int} indicating the slice index for that piece.
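One way to produce such a split is to give every piece `div(batch_size, n_split)` samples and hand the remainder to the first pieces, one extra sample each. The sketch below, with the hypothetical name `split_batch`, follows that scheme; it is an illustration of the contract described above, not necessarily the exact internal code.

```julia
# Split 1:batch_size into n_split contiguous UnitRanges whose lengths differ by at most 1.
function split_batch(batch_size::Int, n_split::Int)
    @assert batch_size >= n_split "need at least one sample per piece"
    per_split = div(batch_size, n_split)
    extra = batch_size - per_split * n_split     # samples left over
    counts = fill(per_split, n_split)
    counts[1:extra] .+= 1                        # first `extra` pieces get one more
    stops = cumsum(counts)
    starts = stops .- counts .+ 1
    return [starts[i]:stops[i] for i in 1:n_split]
end

split_batch(10, 3)   # [1:4, 5:7, 8:10]
```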

# MXNet.mx.fit (Method)

```julia
fit(model::FeedForward, optimizer, data; kwargs...)
```

Train the model on data with the optimizer.

• model::FeedForward: the model to be trained.
• optimizer::AbstractOptimizer: the optimization algorithm to use.
• data::AbstractDataProvider: the training data provider.
• n_epoch::Int: keyword argument, default 10. The number of full passes over the training data.
• eval_data::AbstractDataProvider: keyword argument, default nothing. The data provider for the validation set.
• eval_metric::AbstractEvalMetric: keyword argument, default Accuracy(). The metric used to evaluate the training performance. If eval_data is provided, the same metric is also calculated on the validation set.
• kvstore::Union{KVStore,Symbol}: keyword argument, default :local. The key-value store used to synchronize gradients and parameters when multiple devices are used for training; either a KVStore object or a Symbol.
• initializer::AbstractInitializer: keyword argument, default UniformInitializer(0.01).
• force_init::Bool: keyword argument, default false. By default, random initialization with the provided initializer is skipped if the model weights already exist, for example from a previous call to train or an explicit call to init_model or load_checkpoint. When this option is set, random initialization is always performed at the beginning of training.
• callbacks::Vector{AbstractCallback}: keyword argument, default []. Callbacks to be invoked at each epoch or mini-batch, see AbstractCallback.
• verbosity::Int: determines the verbosity of the printed messages. Higher numbers lead to more verbose printing. Acceptable values are 0: do not print anything during training; 1: print starting and final messages; 2: print one-time messages and a message at the start of each epoch; 3: print a summary of the training and validation accuracy for each epoch.
• η_decay::Symbol: keyword argument, either :epoch or :batch. Selects whether the learning rate is decayed once per epoch or once per mini-batch.
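The `η_decay` setting changes how often the decay step advances, which can make a large difference to the final learning rate. The sketch below assumes a hypothetical exponential schedule η_t = η0·γ^t (not necessarily the scheduler MXNet.jl applies) and counts the decay steps taken under each setting; `decayed_η` and `decay_steps` are illustrative names only.

```julia
# Hypothetical exponential schedule: η_t = η0 * γ^t.
decayed_η(η0::Real, γ::Real, t::Integer) = η0 * γ^t

# Count the decay steps taken over a training run for a given η_decay setting.
function decay_steps(η_decay::Symbol, n_epoch::Int, n_batch::Int)
    η_decay in (:epoch, :batch) ||
        throw(ArgumentError("η_decay must be :epoch or :batch, got :$η_decay"))
    t = 0
    for _ in 1:n_epoch
        for _ in 1:n_batch
            # forward pass, backward pass and parameter update would go here
            η_decay == :batch && (t += 1)   # one step per mini-batch
        end
        η_decay == :epoch && (t += 1)       # one step per full pass
    end
    return t
end

# With 2 epochs of 3 mini-batches, :epoch decays twice and :batch six times.
decayed_η(1.0, 0.5, decay_steps(:epoch, 2, 3))   # 0.25
decayed_η(1.0, 0.5, decay_steps(:batch, 2, 3))   # 0.015625
```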

# MXNet.mx.init_model (Method)

```julia
init_model(self, initializer; overwrite=false, input_shapes...)
```

Initialize the weights in the model.

This method is called automatically when training a model, so there is usually no need to call it unless you want to inspect a model with only randomly initialized weights.

Arguments:

• self::FeedForward: the model to be initialized.
• initializer::AbstractInitializer: an initializer describing how the weights should be initialized.
• overwrite::Bool: keyword argument, force initialization even when the weights already exist.
• input_shapes: the shape of all data and label inputs to this model, given as keyword arguments. For example, data=(28,28,1,100), label=(100,).
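Because `input_shapes...` is a keyword splat, any number of named shapes can be passed and they arrive as name => shape pairs. The sketch below, using the hypothetical helper `collect_input_shapes`, gathers them into a `Dict` and checks that all inputs agree on the batch size, assuming (as the `data=(28,28,1,100), label=(100,)` example suggests) that the batch size is the last dimension.

```julia
# Collect `input_shapes...` keyword arguments into a Dict and check that every
# input agrees on the batch size (assumed to be the last dimension).
function collect_input_shapes(; input_shapes...)
    shapes = Dict{Symbol,Tuple}(k => v for (k, v) in input_shapes)
    isempty(shapes) && throw(ArgumentError("at least one input shape is required"))
    batch_sizes = unique(last(s) for s in values(shapes))
    length(batch_sizes) == 1 ||
        throw(ArgumentError("inputs disagree on batch size: $batch_sizes"))
    return shapes, first(batch_sizes)
end

shapes, batch_size = collect_input_shapes(data = (28, 28, 1, 100), label = (100,))
```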

# MXNet.mx.load_checkpoint (Method)

```julia
load_checkpoint(prefix, epoch, ::mx.FeedForward; context)
```

Load an mx.FeedForward model from the checkpoint identified by prefix and epoch, optionally placing it on the given context.
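The prefix and epoch together identify the files on disk. The sketch below assumes the usual MXNet naming convention, one `<prefix>-symbol.json` file for the architecture and one `<prefix>-<epoch padded to 4 digits>.params` file per saved epoch; `checkpoint_files` is a hypothetical helper, so check the actual files your checkpoints produce.

```julia
using Printf

# Assumed MXNet checkpoint naming: "<prefix>-symbol.json" for the architecture
# and "<prefix>-<epoch, zero-padded to 4 digits>.params" for the weights.
function checkpoint_files(prefix::AbstractString, epoch::Integer)
    symbol_file = "$prefix-symbol.json"
    params_file = @sprintf("%s-%04d.params", prefix, epoch)
    return symbol_file, params_file
end

checkpoint_files("mnist", 10)   # ("mnist-symbol.json", "mnist-0010.params")
```

Under this assumption, loading prefix "mnist" at epoch 10 needs both of these files to be present.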

# MXNet.mx.train (Method)

```julia
train(model :: FeedForward, ...)
```

Alias to fit.