|
- #https://mxnet.apache.org/versions/1.9.1/api/python/docs/tutorials/deploy/export/onnx.html
-
- import mxnet as mx
- from mxnet import gluon
- import argparse, os, sys, stat
- import numpy as np
-
# Command-line interface: model name/prefix plus the NCHW input dimensions
# used to build the ONNX export's input shape.
parser = argparse.ArgumentParser(description='MxNet ONNX Example')

# required=True: everything downstream derives file paths from args.model;
# without it, args.model is None and .split('-') raises AttributeError.
parser.add_argument('--model',
                    type=str,
                    required=True,
                    help='path to training/inference dataset folder'
                    )
parser.add_argument('--n',
                    type=int,
                    default=64,
                    help='batch size for input shape type'
                    )
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type'
                    )
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type'
                    )
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type'
                    )
-
-
if __name__ == "__main__":
    args = parser.parse_args()

    # MXNet checkpoints are stored as '<prefix>-symbol.json' plus
    # '<prefix>-0000.params'; derive both paths from the model name.
    model_name = args.model.split('-')[0]
    model_path = os.path.join('/tmp/dataset', model_name)
    output_file = os.path.join('/tmp/output', model_name + '.onnx')
    sym = model_path + '-symbol.json'
    params = model_path + '-0000.params'

    # Make sure the export destination exists before mx.onnx writes to it.
    os.makedirs('/tmp/output', exist_ok=True)

    # NCHW input shape assembled from the CLI flags.
    input_shape = (args.n, args.c, args.h, args.w)

    # List the dataset folder without spawning a shell
    # (was: os.system('ls -a /tmp/dataset/')).
    print('\n'.join(sorted(os.listdir('/tmp/dataset/'))))
    print('find all your model files above.')
    print('---------------------------------------')
    print('input data shape:', input_shape)
    print('---------------------------------------')
    print('loading model:', args.model)

    # Fall back to CPU when no GPU is present; a hard-coded mx.gpu(0)
    # would fail at parameter-load time on CPU-only machines.
    ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
    net = gluon.SymbolBlock.imports(symbol_file=sym,
                                    input_names=['data'],
                                    param_file=params,
                                    ctx=ctx)
    net.hybridize()
    mx.viz.print_summary(net(mx.sym.var('data')), shape={'data': input_shape})

    # Call the model-export API; it returns the path of the converted ONNX model.
    converted_model_path = mx.onnx.export_model(sym, params, [input_shape],
                                                np.float32, output_file)
    print('Your onnx model is exported at:', converted_model_path)
|