Evaluation Metric API

Overview

This document lists the evaluation metrics available for measuring the performance of a learned model.

mxnet.metric – Online evaluation metric module.

API Reference

Online evaluation metric module.

class mxnet.metric.EvalMetric(name, output_names=None, label_names=None, **kwargs)[source]

Base class for all evaluation metrics.

Note

This is a base class that provides common metric interfaces. One should not use this class directly, but instead create new metric classes that extend it.

Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.
get_config()[source]

Saves the configuration of the metric. The metric can be recreated from this configuration with metric.create(**config).
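
A minimal sketch (not part of the original docstring) of the round trip through get_config and create:

>>> acc = mx.metric.Accuracy()
>>> config = acc.get_config()
>>> acc_copy = mx.metric.create(**config)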

update_dict(label, pred)[source]

Updates the internal evaluation result with named labels and predictions.

Parameters:
  • label (OrderedDict of str -> NDArray) – Name-to-array mapping for labels.
  • pred (OrderedDict of str -> NDArray) – Name-to-array mapping for predicted outputs.
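
A minimal sketch of update_dict (not part of the original docstring); the names 'out_output' and 'softmax_label' are hypothetical and must match the output_names and label_names the metric was constructed with:

>>> acc = mx.metric.Accuracy(output_names=['out_output'], label_names=['softmax_label'])
>>> acc.update_dict({'softmax_label': mx.nd.array([0, 1, 1])},
...                 {'out_output': mx.nd.array([[0.3, 0.7], [0., 1.], [0.4, 0.6]])})
>>> print(acc.get())
('accuracy', 0.6666666666666666)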
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
reset()[source]

Resets the internal evaluation result to initial state.

get()[source]

Gets the current evaluation result.

Returns:
  • names (list of str) – Name of the metrics.
  • values (list of float) – Value of the evaluations.
get_name_value()[source]

Returns zipped name and value pairs.

Returns: A list of (name, value) tuples.
Return type: list of tuples
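
For instance, a minimal sketch reusing the Accuracy example data shown further below:

>>> acc = mx.metric.Accuracy()
>>> acc.update(labels=[mx.nd.array([0, 1, 1])],
...            preds=[mx.nd.array([[0.3, 0.7], [0., 1.], [0.4, 0.6]])])
>>> acc.get_name_value()
[('accuracy', 0.6666666666666666)]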
mxnet.metric.create(metric, *args, **kwargs)[source]

Creates evaluation metric from metric names or instances of EvalMetric or a custom metric function.

Parameters:
  • metric (str or callable) –

    Specifies the metric to create. This argument must be one of the following:

    • Name of a metric.
    • An instance of EvalMetric.
    • A list, each element of which is a metric or a metric name.
    • An evaluation function that computes custom metric for a given batch of labels and predictions.
  • *args (list) – Additional arguments to the metric constructor. Only used when metric is str.
  • **kwargs (dict) – Additional arguments to the metric constructor. Only used when metric is str.

Examples

>>> import mxnet as mx
>>> import numpy as np
>>> def custom_metric(label, pred):
...     return np.mean(np.abs(label - pred))
...
>>> metric1 = mx.metric.create('acc')
>>> metric2 = mx.metric.create(custom_metric)
>>> metric3 = mx.metric.create([metric1, metric2, 'rmse'])
class mxnet.metric.CompositeEvalMetric(metrics=None, name='composite', output_names=None, label_names=None)[source]

Manages multiple evaluation metrics.

Parameters:
  • metrics (list of EvalMetric) – List of child metrics.
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0, 1, 1])]
>>> eval_metrics_1 = mx.metric.Accuracy()
>>> eval_metrics_2 = mx.metric.F1()
>>> eval_metrics = mx.metric.CompositeEvalMetric()
>>> for child_metric in [eval_metrics_1, eval_metrics_2]:
...     eval_metrics.add(child_metric)
>>> eval_metrics.update(labels=labels, preds=predicts)
>>> print(eval_metrics.get())
(['accuracy', 'f1'], [0.6666666666666666, 0.8])
add(metric)[source]

Adds a child metric.

Parameters: metric – A metric instance.
get_metric(index)[source]

Returns a child metric.

Parameters: index (int) – Index of child metric in the list of metrics.
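
Continuing the CompositeEvalMetric example above, a short sketch of retrieving a child metric by index:

>>> eval_metrics.get_metric(0).get()
('accuracy', 0.6666666666666666)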
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
reset()[source]

Resets the internal evaluation result to initial state.

get()[source]

Returns the current evaluation result.

Returns:
  • names (list of str) – Name of the metrics.
  • values (list of float) – Value of the evaluations.
class mxnet.metric.Accuracy(axis=1, name='accuracy', output_names=None, label_names=None)[source]

Computes accuracy classification score.

The accuracy score is defined as

\[\text{accuracy}(y, \hat{y}) = \frac{1}{n} \sum_{i=0}^{n-1} \mathbb{1}(\hat{y}_i = y_i)\]
Parameters:
  • axis (int, default=1) – The axis that represents classes.
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0, 1, 1])]
>>> acc = mx.metric.Accuracy()
>>> acc.update(preds=predicts, labels=labels)
>>> print(acc.get())
('accuracy', 0.6666666666666666)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data with class indices as values, one per sample.
  • preds (list of NDArray) – Prediction values for samples. Each prediction value can either be the class index, or a vector of likelihoods for all classes.
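
A minimal sketch (values are illustrative) where predictions are given directly as class indices rather than likelihood vectors:

>>> acc = mx.metric.Accuracy()
>>> acc.update(labels=[mx.nd.array([0, 1, 1])], preds=[mx.nd.array([1, 1, 1])])
>>> print(acc.get())
('accuracy', 0.6666666666666666)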
class mxnet.metric.TopKAccuracy(top_k=1, name='top_k_accuracy', output_names=None, label_names=None)[source]

Computes top k predictions accuracy.

TopKAccuracy differs from Accuracy in that it considers a prediction to be correct as long as the ground truth label is among the top k predicted labels.

If top_k = 1, then TopKAccuracy is identical to Accuracy.

Parameters:
  • top_k (int) – The number of top predictions to consider when checking whether the target label is among them.
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> np.random.seed(999)
>>> top_k = 3
>>> labels = [mx.nd.array([2, 6, 9, 2, 3, 4, 7, 8, 9, 6])]
>>> predicts = [mx.nd.array(np.random.rand(10, 10))]
>>> acc = mx.metric.TopKAccuracy(top_k=top_k)
>>> acc.update(labels, predicts)
>>> print(acc.get())
('top_k_accuracy', 0.3)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.F1(name='f1', output_names=None, label_names=None)[source]

Computes the F1 score of a binary classification problem.

The F1 score is the harmonic mean of precision and recall, where the best value is 1.0 and the worst value is 0.0. The formula for the F1 score is:

F1 = 2 * (precision * recall) / (precision + recall)

The formula for precision and recall is:

precision = true_positives / (true_positives + false_positives)
recall    = true_positives / (true_positives + false_negatives)

Note

This F1 score only supports binary classification.

Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0., 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0., 1., 1.])]
>>> acc = mx.metric.F1()
>>> acc.update(preds=predicts, labels=labels)
>>> print(acc.get())
('f1', 0.8)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.Perplexity(ignore_label, axis=-1, name='perplexity', output_names=None, label_names=None)[source]

Computes perplexity.

Perplexity is a measurement of how well a probability distribution or model predicts a sample. A low perplexity indicates the model is good at predicting the sample.

The perplexity of a model q is defined as

\[b^{\big(-\frac{1}{N} \sum_{i=1}^N \log_b q(x_i) \big)} = \exp \big(-\frac{1}{N} \sum_{i=1}^N \log q(x_i)\big)\]

where we let b = e.

\(q(x_i)\) is the predicted value of its ground truth label on sample \(x_i\).

For example, we have three samples \(x_1, x_2, x_3\) and their labels are \([0, 1, 1]\). Suppose our model predicts \(q(x_1) = p(y_1 = 0 | x_1) = 0.3\) and \(q(x_2) = 1.0\), \(q(x_3) = 0.6\). The perplexity of model q is \(\exp\big(-(\log 0.3 + \log 1.0 + \log 0.6) / 3\big) = 1.77109762852\).

Parameters:
  • ignore_label (int or None) – Index of an invalid label to ignore when counting. If set to None, all entries are counted.
  • axis (int (default -1)) – The axis of the prediction over which the softmax was computed. By default, the last axis is used.
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0, 1, 1])]
>>> perp = mx.metric.Perplexity(ignore_label=None)
>>> perp.update(labels, predicts)
>>> print(perp.get())
('Perplexity', 1.7710976285155853)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
get()[source]

Returns the current evaluation result.

Returns: The name of the metric and the current evaluation result.
Return type: tuple of (str, float)
class mxnet.metric.MAE(name='mae', output_names=None, label_names=None)[source]

Computes Mean Absolute Error (MAE) loss.

The mean absolute error is given by

\[\frac{\sum_{i=1}^{n} |y_i - \hat{y}_i|}{n}\]
Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
>>> labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
>>> mean_absolute_error = mx.metric.MAE()
>>> mean_absolute_error.update(labels=labels, preds=predicts)
>>> print(mean_absolute_error.get())
('mae', 0.5)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.MSE(name='mse', output_names=None, label_names=None)[source]

Computes Mean Squared Error (MSE) loss.

The mean squared error is given by

\[\frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{n}\]
Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
>>> labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
>>> mean_squared_error = mx.metric.MSE()
>>> mean_squared_error.update(labels=labels, preds=predicts)
>>> print(mean_squared_error.get())
('mse', 0.375)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.RMSE(name='rmse', output_names=None, label_names=None)[source]

Computes Root Mean Squared Error (RMSE) loss.

The root mean squared error is given by

\[\sqrt{\frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{n}}\]
Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
>>> labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
>>> root_mean_squared_error = mx.metric.RMSE()
>>> root_mean_squared_error.update(labels=labels, preds=predicts)
>>> print(root_mean_squared_error.get())
('rmse', 0.612372457981)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.CrossEntropy(eps=1e-12, name='cross-entropy', output_names=None, label_names=None)[source]

Computes Cross Entropy loss.

The cross entropy over a batch of sample size \(N\) is given by

\[-\sum_{n=1}^{N}\sum_{k=1}^{K}t_{nk}\log (y_{nk}),\]

where \(t_{nk}=1\) if and only if sample \(n\) belongs to class \(k\). \(y_{nk}\) denotes the probability of sample \(n\) belonging to class \(k\).

Parameters:
  • eps (float) – Cross-entropy loss is undefined when a predicted value is 0 or 1, so this small constant is added to the predicted values.
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0, 1, 1])]
>>> ce = mx.metric.CrossEntropy()
>>> ce.update(labels, predicts)
>>> print(ce.get())
('cross-entropy', 0.57159948348999023)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.NegativeLogLikelihood(eps=1e-12, name='nll-loss', output_names=None, label_names=None)[source]

Computes the negative log-likelihood loss.

The negative log-likelihood loss over a batch of sample size \(N\) is given by

\[-\sum_{n=1}^{N}\sum_{k=1}^{K}t_{nk}\log (y_{nk}),\]

where \(K\) is the number of classes, \(y_{nk}\) is the predicted probability of the \(k\)-th class for the \(n\)-th sample, and \(t_{nk}=1\) if and only if sample \(n\) belongs to class \(k\).

Parameters:
  • eps (float) – Negative log-likelihood loss is undefined when a predicted value is 0, so this small constant is added to the predicted values.
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([0, 1, 1])]
>>> nll_loss = mx.metric.NegativeLogLikelihood()
>>> nll_loss.update(labels, predicts)
>>> print(nll_loss.get())
('nll-loss', 0.57159948348999023)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.PearsonCorrelation(name='pearsonr', output_names=None, label_names=None)[source]

Computes Pearson correlation.

The Pearson correlation is given by

\[\frac{\mathrm{cov}(y, \hat{y})}{\sigma_y \sigma_{\hat{y}}}\]
Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels   = [mx.nd.array([[1, 0], [0, 1], [0, 1]])]
>>> pr = mx.metric.PearsonCorrelation()
>>> pr.update(labels, predicts)
>>> print(pr.get())
('pearsonr', 0.42163704544016178)
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
class mxnet.metric.Loss(name='loss', output_names=None, label_names=None)[source]

Dummy metric for directly printing loss.

Parameters:
  • name (str) – Name of this metric instance for display.
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.
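
No example is given for this metric in the original docs; the sketch below assumes that Loss simply averages the loss values passed in as predictions and ignores the labels:

>>> loss = mx.metric.Loss()
>>> loss.update(None, [mx.nd.array([0.5, 1.5])])
>>> print(loss.get())
('loss', 1.0)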
class mxnet.metric.Torch(name='torch', output_names=None, label_names=None)[source]

Dummy metric for torch criterions.

class mxnet.metric.Caffe(name='caffe', output_names=None, label_names=None)[source]

Dummy metric for caffe criterions.

class mxnet.metric.CustomMetric(feval, name=None, allow_extra_outputs=False, output_names=None, label_names=None)[source]

Computes a customized evaluation metric.

The feval function can return either a (sum_metric, num_inst) tuple or a single numeric value for sum_metric (see the second example below).

Parameters:
  • feval (callable(label, pred)) – Customized evaluation function.
  • name (str) – Name of this metric instance for display (the default is None).
  • allow_extra_outputs (bool, optional) – If true, the prediction outputs are allowed to have extra outputs. This is useful in RNNs, where states are also part of the output and are fed back in the next step. (the default is False).
  • output_names (list of str, or None) – Name of predictions that should be used when updating with update_dict. By default include all predictions.
  • label_names (list of str, or None) – Name of labels that should be used when updating with update_dict. By default include all labels.

Examples

>>> predicts = [mx.nd.array(np.array([3, -0.5, 2, 7]).reshape(4,1))]
>>> labels = [mx.nd.array(np.array([2.5, 0.0, 2, 8]).reshape(4,1))]
>>> feval = lambda x, y : (x + y).mean()
>>> eval_metrics = mx.metric.CustomMetric(feval=feval)
>>> eval_metrics.update(labels, predicts)
>>> print(eval_metrics.get())
('custom(<lambda>)', 6.0)
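
A related sketch (not part of the original docstring) of a feval that returns a (sum_metric, num_inst) tuple; the metric reports sum_metric divided by the accumulated num_inst:

>>> feval_tuple = lambda label, pred: (float(((label - pred) ** 2).sum()), label.size)
>>> mse_like = mx.metric.CustomMetric(feval=feval_tuple, name='mse_like')
>>> mse_like.update(labels, predicts)
>>> print(mse_like.get())
('mse_like', 0.375)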
update(labels, preds)[source]

Updates the internal evaluation result.

Parameters:
  • labels (list of NDArray) – The labels of the data.
  • preds (list of NDArray) – Predicted values.
mxnet.metric.np(numpy_feval, name=None, allow_extra_outputs=False)[source]

Creates a custom evaluation metric that receives its inputs as numpy arrays.

Parameters:
  • numpy_feval (callable(label, pred)) – Custom evaluation function that receives labels and predictions for a minibatch as numpy arrays and returns the corresponding custom metric as a floating point number.
  • name (str, optional) – Name of the custom metric.
  • allow_extra_outputs (bool, optional) – Whether prediction output is allowed to have extra outputs. This is useful in cases like RNN where states are also part of output which can then be fed back to the RNN in the next step. By default, extra outputs are not allowed.
Returns: Custom metric that computes the provided numpy_feval on each batch of labels and predictions.
Return type: CustomMetric

Example

>>> def custom_metric(label, pred):
...     return np.mean(np.abs(label-pred))
...
>>> metric = mx.metric.np(custom_metric)
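
A brief follow-up sketch showing the wrapped metric in use; the displayed name is taken from the wrapped function's name:

>>> labels = [mx.nd.array([2.5, 0.0, 2, 8])]
>>> predicts = [mx.nd.array([3, -0.5, 2, 7])]
>>> metric.update(labels, predicts)
>>> print(metric.get())
('custom_metric', 0.5)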