Running inference on MXNet/Gluon from an ONNX model

Open Neural Network Exchange (ONNX) provides an open source format for AI models. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.

In this tutorial we will:

  • learn how to load a pre-trained .onnx model file into MXNet/Gluon
  • learn how to test this model using the sample input/output
  • learn how to test the model on custom images


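To run the tutorial you will need the following Python modules installed: MXNet, onnx (needed by mxnet.contrib.onnx) and matplotlib. We import them, along with a few standard-library modules: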

import numpy as np
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
from mxnet import gluon, nd
%matplotlib inline
import matplotlib.pyplot as plt
import tarfile, os
import json
import logging

Downloading supporting files

These are the sample images, a visualization script and the ImageNet label mapping:

image_folder = "images"
utils_file = "" # contains utility functions that plot the predictions
image_net_labels_file = "image_net_labels.json"
images = ['apron.jpg', 'hammerheadshark.jpg', 'dog.jpg', 'wrench.jpg', 'dolphin.jpg', 'lotus.jpg']
base_url = "{}?raw=true"

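# Download the sample images, the plotting helpers and the label mapping
for image in images:
    mx.test_utils.download(base_url.format("{}/{}".format(image_folder, image)), fname=image, dirname=image_folder)
mx.test_utils.download(base_url.format(utils_file), fname=utils_file)
mx.test_utils.download(base_url.format(image_net_labels_file), fname=image_net_labels_file)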

from utils import *  # provides plot_predictions, used below
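Each download URL is built by substituting a relative path into base_url. To check the paths before fetching anything, the sketch below lists the (URL, file name, directory) triples that the download calls request. build_download_plan is a hypothetical helper, not part of MXNet or of the utils module, and the example.com base URL in the usage is a made-up stand-in for the real one.

```python
def build_download_plan(base_url, image_folder, images, extra_files):
    """Return (url, fname, dirname) triples mirroring the download calls.

    Hypothetical helper: base_url must contain one "{}" placeholder that
    receives the file's path relative to the remote folder.
    """
    plan = []
    for image in images:
        # Images live in a sub-folder, both remotely and locally.
        relative_path = "{}/{}".format(image_folder, image)
        plan.append((base_url.format(relative_path), image, image_folder))
    for fname in extra_files:
        # Helper files (script, label map) go to the working directory.
        plan.append((base_url.format(fname), fname, None))
    return plan


# Usage with a made-up base URL standing in for the real one.
plan = build_download_plan("https://example.com/data/{}?raw=true",
                           "images", ["apron.jpg", "dog.jpg"],
                           ["image_net_labels.json"])
for url, fname, dirname in plan:
    print(url, fname, dirname)
```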

Downloading a model from the ONNX model zoo

We download a pre-trained model, in our case the GoogleNet model trained on ImageNet, from the ONNX model zoo. The model comes packaged in a .tar.gz archive containing a model.onnx file.

base_url = "" 
current_model = "bvlc_googlenet"
model_folder = "model"
archive = "{}.tar.gz".format(current_model)
archive_file = os.path.join(model_folder, archive)
url = "{}{}".format(base_url, archive)

# Download and extract the pre-trained model
mx.test_utils.download(url, dirname=model_folder)
if not os.path.isdir(os.path.join(model_folder, current_model)):
    print('Extracting model...')
    with tarfile.open(archive_file, "r:gz") as tar:
        tar.extractall(model_folder)
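The isdir check makes extraction happen only once: if the archive's top-level folder (assumed here to be named after the model, as the check implies) already exists, the archive is left alone. The sketch below packages that logic as a hypothetical extract_once helper and exercises it on a throwaway archive holding a dummy model.onnx, so it runs without network access. Where the running Python supports tarfile extraction filters, it uses the safer "data" filter.

```python
import io
import os
import tarfile
import tempfile


def extract_once(archive_file, model_folder, model_name):
    """Extract archive_file into model_folder unless
    model_folder/model_name already exists. Returns True if it extracted."""
    if os.path.isdir(os.path.join(model_folder, model_name)):
        return False  # already extracted on a previous run
    with tarfile.open(archive_file, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Python versions with extraction filters: reject unsafe members
            tar.extractall(model_folder, filter="data")
        else:
            tar.extractall(model_folder)
    return True


# Usage: build a throwaway archive holding a dummy model.onnx, then extract twice.
with tempfile.TemporaryDirectory() as tmp:
    archive_path = os.path.join(tmp, "bvlc_googlenet.tar.gz")
    payload = b"placeholder bytes, not a real ONNX model"
    with tarfile.open(archive_path, "w:gz") as tar:
        info = tarfile.TarInfo("bvlc_googlenet/model.onnx")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    first = extract_once(archive_path, tmp, "bvlc_googlenet")
    second = extract_once(archive_path, tmp, "bvlc_googlenet")
    extracted = os.path.isfile(os.path.join(tmp, "bvlc_googlenet", "model.onnx"))

print(first, second, extracted)  # True False True
```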

The model has been pre-trained on ImageNet, so let’s load the label mapping of its 1000 classes.

with open(image_net_labels_file, 'r') as f:
    categories = json.load(f)
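The label list is what turns a class index predicted by the network into a readable name. Assuming the JSON file holds a list of 1000 strings ordered by class index (its contents are not shown in this tutorial), finding the most likely class is a single argmax. The sketch uses a made-up three-entry label list and made-up scores, named so they do not overwrite categories.

```python
import numpy as np

# Stand-in for the real label list, assumed to be ordered by class index.
example_categories = ["tench", "goldfish", "great white shark"]

# Stand-in for one row of network output (one score per class).
scores = np.array([0.1, 0.2, 0.7], dtype=np.float32)

best = int(np.argmax(scores))
print(best, example_categories[best])  # 2 great white shark
```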

Loading the model into MXNet Gluon

onnx_path = os.path.join(model_folder, current_model, "model.onnx")

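We import the ONNX model, which gives us an MXNet symbol (the computation graph) and two parameter dictionaries, arg_params and aux_params: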

sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)

We pick a context. CPU is fine for inference; switch to mx.gpu() if you want to use your GPU.

ctx = mx.cpu()

We obtain the data names of the inputs to the model by using the model metadata API:

model_metadata = onnx_mxnet.get_model_metadata(onnx_path)
print(model_metadata)
# {'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],
#  'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}
data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
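As the printed output shows, the metadata maps 'input_tensor_data' and 'output_tensor_data' to lists of (name, shape) pairs. The sketch below turns them into name-to-shape dictionaries, so the expected input shape can be checked in code. It works on a literal copy of the printed output (rewritten in Python 3 syntax) rather than a live call, and shapes_by_name is a hypothetical helper.

```python
def shapes_by_name(metadata, key):
    """Map tensor name -> shape for metadata[key], a list of (name, shape)."""
    return {name: tuple(shape) for name, shape in metadata.get(key, [])}


# Literal copy of the metadata printed above, in Python 3 syntax.
example_metadata = {
    'output_tensor_data': [('gpu_0/softmax_1', (1, 1000))],
    'input_tensor_data': [('gpu_0/data_0', (1, 3, 224, 224))],
}

input_shapes = shapes_by_name(example_metadata, 'input_tensor_data')
output_shapes = shapes_by_name(example_metadata, 'output_tensor_data')
print(input_shapes)   # {'gpu_0/data_0': (1, 3, 224, 224)}
print(output_shapes)  # {'gpu_0/softmax_1': (1, 1000)}
```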


We then load the symbol and the parameters into an MXNet Gluon SymbolBlock.

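# Wrap the symbol in a Gluon block, using the input name reported by the metadata
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var(data_names[0]))
net_params = net.collect_params()
# Copy the imported weights (arg_params) and auxiliary states (aux_params)
# into the block parameters that have the same names
for param in arg_params:
    if param in net_params:
        net_params[param]._load_init(arg_params[param], ctx=ctx)
for param in aux_params:
    if param in net_params:
        net_params[param]._load_init(aux_params[param], ctx=ctx)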
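The loading loop copies a value only when a block parameter has the same name as an entry in arg_params or aux_params; any other parameter is silently skipped and stays uninitialized. The plain-Python sketch below reports which names would be loaded and which would be missing, a quick sanity check before running inference. match_parameters is a hypothetical helper and the parameter names are made up; in the tutorial the block names would come from net.collect_params().

```python
def match_parameters(block_param_names, arg_params, aux_params):
    """Split block parameter names into (loaded, missing) lists."""
    loaded, missing = [], []
    for name in block_param_names:
        if name in arg_params or name in aux_params:
            loaded.append(name)
        else:
            missing.append(name)
    return loaded, missing


# Made-up names and placeholder values standing in for the real dictionaries.
block_names = ['conv1_w_0', 'conv1_b_0', 'bn_mean_0', 'fc_w_0']
example_args = {'conv1_w_0': None, 'conv1_b_0': None}
example_aux = {'bn_mean_0': None}

loaded, missing = match_parameters(block_names, example_args, example_aux)
print(loaded)   # ['conv1_w_0', 'conv1_b_0', 'bn_mean_0']
print(missing)  # ['fc_w_0']
```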

We can now cache the computational graph through hybridization to gain some performance:

net.hybridize()

We can visualize the network (this requires graphviz to be installed):

mx.visualization.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"false"})


This helper function runs M batches of data, each of batch size N, through the network and collates the outputs into an array of shape (K, 1000), where K = M × N is the total number of examples run through the network.

def run_batch(net, data):
    results = []
    for batch in data:
        outputs = net(batch)
        results.extend([o for o in outputs.asnumpy()])
    return np.array(results)
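To see the (K, 1000) shape arithmetic without MXNet, here is a NumPy-only stand-in. fake_net is a hypothetical function that, like the real network, maps a batch of shape (N, 3, 224, 224) to one row of 1000 scores per image; plain NumPy arrays replace NDArrays, so no asnumpy() call is needed.

```python
import numpy as np

NUM_CLASSES = 1000


def fake_net(batch):
    """Hypothetical stand-in for the network: one probability row per image."""
    rng = np.random.default_rng(0)
    scores = rng.random((batch.shape[0], NUM_CLASSES))
    return scores / scores.sum(axis=1, keepdims=True)  # each row sums to 1


def run_batch_numpy(net_fn, data):
    """Stack per-batch outputs into one (M * N, NUM_CLASSES) array."""
    return np.concatenate([net_fn(b) for b in data], axis=0)


M, N = 2, 4  # 2 batches of 4 images each
batches = [np.zeros((N, 3, 224, 224), dtype=np.float32) for _ in range(M)]
out = run_batch_numpy(fake_net, batches)
print(out.shape)  # (8, 1000)
```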

Test using real images

TOP_P = 3 # How many top guesses we show in the visualization
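TOP_P is passed to plot_predictions, which comes from the downloaded utils module and is not shown in this tutorial. If you prefer text output, a top-k lookup can be sketched as below. top_k_predictions is a hypothetical helper that assumes each row of the network output holds one score per class, in the same order as categories; once result has been computed, it could be called as top_k_predictions(result[0], categories, TOP_P). The labels and scores in the usage are made up.

```python
import numpy as np


def top_k_predictions(scores, labels, k=3):
    """Return the k (label, score) pairs with the highest scores, best first."""
    order = np.argsort(scores)[::-1][:k]
    return [(labels[i], float(scores[i])) for i in order]


example_labels = ["cat", "dog", "fox", "owl"]
example_scores = np.array([0.05, 0.60, 0.30, 0.05])
for label, score in top_k_predictions(example_scores, example_labels, k=2):
    print("{}: {:.2f}".format(label, score))
# dog: 0.60
# fox: 0.30
```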

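This transform function converts an image from the (height, width, channels) layout returned by plt.imread into the (N, 3, 224, 224) layout the network expects, where N is the batch size. It only reorders the axes and adds a batch axis; it does not resize the image, so the images must already be 224 x 224 pixels.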

def transform(img):
    return np.expand_dims(np.transpose(img, (2,0,1)),axis=0).astype(np.float32)
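Your own images will usually not be 224 x 224, and since transform does not resize, they need to be brought to that size first. One simple option, sketched below as a hypothetical center_crop helper, is to cut out the central 224 x 224 region. It assumes the image is at least that large in both dimensions and does no rescaling, so a proper resize with an image library may give better predictions. The transform body is copied from above so the snippet runs on its own, and the test image is synthetic random data, not a real photo.

```python
import numpy as np


def transform(img):
    # (H, W, C) -> (C, H, W), then add a leading batch axis -> (1, C, H, W)
    return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)


def center_crop(img, size=224):
    """Hypothetical helper: return the central size x size region of an
    (H, W, C) image. Assumes H >= size and W >= size; nothing is rescaled."""
    height, width = img.shape[:2]
    if height < size or width < size:
        raise ValueError("image is smaller than {0}x{0}".format(size))
    top = (height - size) // 2
    left = (width - size) // 2
    return img[top:top + size, left:left + size]


# Synthetic 300x400 RGB "photo" with random pixel values.
rng = np.random.default_rng(0)
photo = rng.integers(0, 256, size=(300, 400, 3), dtype=np.uint8)

x = transform(center_crop(photo))
print(x.shape, x.dtype)  # (1, 3, 224, 224) float32
```

For a 300 x 400 image the crop starts at row 38 and column 88, so pixel (row r, column c, channel ch) of the crop ends up at x[0, ch, r, c].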

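We load two sets of images into memory: three whose categories are ImageNet labels (apron, hammerhead shark, dog) and three from Caltech101 (wrench, dolphin, lotus).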

image_net_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['apron', 'hammerheadshark','dog']]
caltech101_images = [plt.imread('{}/{}.jpg'.format(image_folder, path)) for path in ['wrench', 'dolphin','lotus']]
images = image_net_images + caltech101_images

We then run all six images through the network as a single batch to get the predictions:

batch = nd.array(np.concatenate([transform(img) for img in images], axis=0), ctx=ctx)
result = run_batch(net, [batch])
plot_predictions(image_net_images, result[:3], categories, TOP_P)


Well done! It looks like the network does a pretty good job of classifying pictures whose category is an ImageNet label.

Let’s now see the results on the 3 other images:

plot_predictions(caltech101_images, result[3:], categories, TOP_P)


Hmm, not so good... Even though the predictions are close, they are not accurate. This is because the ImageNet dataset does not contain wrench, dolphin or lotus categories, and our network has been trained on ImageNet.
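Why can the network not simply answer "dolphin"? A classifier with a softmax output layer (the output tensor name in the metadata above suggests one) spreads all of its probability over a fixed label set, so its top guess is always one of those labels, however poorly any of them fits. The sketch below illustrates this with a toy four-class label set and made-up scores, not real GoogleNet outputs.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Toy label set with no "dolphin" entry, and made-up scores for a dolphin photo.
toy_labels = ["killer whale", "grey whale", "great white shark", "tench"]
dolphin_logits = np.array([2.0, 1.5, 1.0, -1.0])

probs = softmax(dolphin_logits)
print(probs.round(3), probs.sum())
print("Top guess:", toy_labels[int(np.argmax(probs))])  # Top guess: killer whale
```

The probabilities always sum to 1 across the known labels, which is why the closest available category wins instead of the true one.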

Luckily, the Caltech101 dataset does contain wrench, dolphin and lotus categories. Let’s see how we can fine-tune our network to classify them correctly.

We show that in our next tutorial: