ONNX model exampleΒΆ

An example of a sequential network that solves a supervised classification task used

as an OpenML flow. Uses MXNet as backend.

import openml.extensions.onnx
import onnx
import os
import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet

ONNX_FILE_PATH_DEFAULT = 'model.onnx'
MXNET_PARAMS_PATH_DEFAULT = './model-0001.params'
MXNET_SYMBOL_PATH_DEFAULT = './model-symbol.json'

Obtain task with training data

task = openml.tasks.get_task(10101)
X, y = task.get_X_and_y()
train_indices, test_indices = task.get_train_test_split_indices(
    repeat=0, fold=0, sample=0)
X_train = X[train_indices]
y_train = y[train_indices]
X_test = X[test_indices]
y_test = y[test_indices]

Compute shapes of input and output

output_length = len(task.class_labels)
input_length = X_train.shape[1]

# Create MXNet Variables
data = mx.sym.var('data')
label = mx.sym.var('softmax_label')

# Create the MXNet Module API model
bnorm = mx.sym.BatchNorm(data=data)
fc1 = mx.sym.FullyConnected(data=bnorm, num_hidden=1024)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
drop1 = mx.sym.Dropout(data=act1, p=0.4)
fc2 = mx.sym.FullyConnected(data=drop1, num_hidden=output_length)
mlp = mx.sym.SoftmaxOutput(data=fc2, name='softmax', label=label)
mlp_model = mx.mod.Module(symbol=mlp, context=mx.cpu())

data_shapes = [('data', X_train.shape)]
label_shapes = [('softmax_label', y_train.shape)]

# Bind and initialize parameters
mlp_model.bind(data_shapes=data_shapes, label_shapes=label_shapes)
mlp_model.init_params(mx.init.Xavier())

# Save the parameters and symbol to files
mlp_model.save_params(MXNET_PARAMS_PATH_DEFAULT)
mlp.save(MXNET_SYMBOL_PATH_DEFAULT)

# Export the ONNX specification of the model, using the parameters and symbol files
onnx_mxnet.export_model(
    sym=MXNET_SYMBOL_PATH_DEFAULT,
    params=MXNET_PARAMS_PATH_DEFAULT,
    input_shape=[(64, input_length)],
    onnx_file_path=ONNX_FILE_PATH_DEFAULT)

Load ONNX file and remove files

model = onnx.load_model(ONNX_FILE_PATH_DEFAULT)
if os.path.exists(MXNET_PARAMS_PATH_DEFAULT):
    os.remove(MXNET_PARAMS_PATH_DEFAULT)
if os.path.exists(MXNET_SYMBOL_PATH_DEFAULT):
    os.remove(MXNET_SYMBOL_PATH_DEFAULT)
if os.path.exists(ONNX_FILE_PATH_DEFAULT):
    os.remove(ONNX_FILE_PATH_DEFAULT)

Run the model on the task (requires an API key).

run = openml.runs.run_model_on_task(model, task, avoid_duplicate_runs=False)
# Publish the experiment on OpenML (optional, requires an API key).
run.publish()

print('URL for run: %s/run/%d' % (openml.config.server, run.run_id))

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery