Note
Click here to download the full example code
ONNX model example
An example of a sequential network that solves a supervised classification task
and is used as an OpenML flow. MXNet is used as the backend.
import openml.extensions.onnx
import onnx
import os
import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet
# File the ONNX export is written to and later loaded back from.
ONNX_FILE_PATH_DEFAULT = 'model.onnx'
# File the MXNet module parameters are saved to before the ONNX export.
MXNET_PARAMS_PATH_DEFAULT = './model-0001.params'
# File the MXNet symbol (network graph) is saved to before the ONNX export.
MXNET_SYMBOL_PATH_DEFAULT = './model-symbol.json'
Obtain task with training data
# Fetch the OpenML task together with its full feature matrix and targets.
task = openml.tasks.get_task(10101)
X, y = task.get_X_and_y()

# Row indices of the split selected by repeat=0, fold=0, sample=0.
train_indices, test_indices = task.get_train_test_split_indices(
    repeat=0,
    fold=0,
    sample=0,
)

# Slice features and targets into the training and test partitions.
X_train, y_train = X[train_indices], y[train_indices]
X_test, y_test = X[test_indices], y[test_indices]
Compute shapes of input and output
# Network input width: the size of the second axis of the training matrix.
input_length = X_train.shape[1]
# Network output width: one unit per class label declared by the task.
output_length = len(task.class_labels)
# Symbolic placeholders for the network input and the training labels.
data = mx.sym.var('data')
label = mx.sym.var('softmax_label')

# Network definition, layer by layer:
# batch norm -> dense(1024) -> ReLU -> dropout(0.4) -> dense(classes) -> softmax
normalized = mx.sym.BatchNorm(data=data)
hidden = mx.sym.FullyConnected(data=normalized, num_hidden=1024)
hidden_relu = mx.sym.Activation(data=hidden, act_type="relu")
hidden_dropped = mx.sym.Dropout(data=hidden_relu, p=0.4)
class_scores = mx.sym.FullyConnected(
    data=hidden_dropped,
    num_hidden=output_length,
)
mlp = mx.sym.SoftmaxOutput(data=class_scores, name='softmax', label=label)

# Wrap the symbol in a Module that runs on the CPU.
mlp_model = mx.mod.Module(symbol=mlp, context=mx.cpu())

# Shapes handed to bind(): the complete training matrix and label vector.
data_shapes = [('data', X_train.shape)]
label_shapes = [('softmax_label', y_train.shape)]

# Bind the module to those shapes, then initialize parameters with Xavier.
mlp_model.bind(
    data_shapes=data_shapes,
    label_shapes=label_shapes,
)
mlp_model.init_params(mx.init.Xavier())
# Persist the module parameters and the network symbol to disk; the ONNX
# exporter below is given these two file paths rather than in-memory objects.
mlp_model.save_params(MXNET_PARAMS_PATH_DEFAULT)
mlp.save(MXNET_SYMBOL_PATH_DEFAULT)

# Input shape declared for the exported graph: 64 rows of `input_length` values.
onnx_input_shapes = [(64, input_length)]

# Write the ONNX specification of the model from the symbol/parameter files.
onnx_mxnet.export_model(
    sym=MXNET_SYMBOL_PATH_DEFAULT,
    params=MXNET_PARAMS_PATH_DEFAULT,
    input_shape=onnx_input_shapes,
    onnx_file_path=ONNX_FILE_PATH_DEFAULT,
)
Load the ONNX file, then remove the intermediate files
# Read the exported ONNX model back into memory.
model = onnx.load_model(ONNX_FILE_PATH_DEFAULT)

# The model object is now held in memory, so delete the files written during
# the export. Removing directly and ignoring a missing file avoids the
# check-then-remove race of testing os.path.exists() first.
for artifact_path in (
    MXNET_PARAMS_PATH_DEFAULT,
    MXNET_SYMBOL_PATH_DEFAULT,
    ONNX_FILE_PATH_DEFAULT,
):
    try:
        os.remove(artifact_path)
    except FileNotFoundError:
        # Already absent: nothing to clean up for this path.
        pass
Run the model on the task (requires an API key).
# Evaluate the ONNX model on the task; duplicate-run detection is disabled.
run = openml.runs.run_model_on_task(
    model,
    task,
    avoid_duplicate_runs=False,
)

# Publish the experiment on OpenML (optional, requires an API key).
run.publish()

# Report where the published run can be viewed on the configured server.
run_url = 'URL for run: %s/run/%d' % (openml.config.server, run.run_id)
print(run_url)
Total running time of the script: ( 0 minutes 0.000 seconds)