# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor
from mmseg.ops import resize

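# fix the random seed so repeated export runs are reproducible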
torch.manual_seed(3)


def _convert_batchnorm(module):
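    """Recursively replace SyncBatchNorm layers with BatchNorm2d, copying
    weights and running statistics, so the model can run on CPU for ONNX
    export."""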
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple):
            input batch dimensions
        num_classes (int):
            number of semantic classes
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def _prepare_input_img(img_path,
                       test_pipeline,
                       shape=None,
                       rescale_shape=None):
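    """Load a real image through the test pipeline and pack it into the
    imgs/img_metas format expected by the segmentor."""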
    # build the data pipeline
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs


def _update_input_img(img_list, img_meta_list, update_ori_shape=False):
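    """Rebuild img_metas so that the recorded shapes and scale factors match
    the current input tensors."""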
    # update img and its meta list
    N, C, H, W = img_list[0].shape
    img_meta = img_meta_list[0][0]
    img_shape = (H, W, C)
    if update_ori_shape:
        ori_shape = img_shape
    else:
        ori_shape = img_meta['ori_shape']
    pad_shape = img_shape
    new_img_meta_list = [[{
        'img_shape':
        img_shape,
        'ori_shape':
        ori_shape,
        'pad_shape':
        pad_shape,
        'filename':
        img_meta['filename'],
        'scale_factor':
        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
        'flip':
        False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list


def pytorch2onnx(model,
                 mm_inputs,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 dynamic_export=False):
    """Export a PyTorch model to ONNX and verify that the outputs are the
    same between PyTorch and ONNX.

    Args:
        model (nn.Module): The PyTorch model to export.
        mm_inputs (dict): Contains the input tensors and img_metas information.
        opset_version (int): The ONNX opset version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (str): The path where the output ONNX model is stored.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs of PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether to export ONNX with dynamic axes.
            Default: False.
    """
    model.cpu().eval()
    test_mode = model.test_cfg.mode

    if isinstance(model.decode_head, nn.ModuleList):
        num_classes = model.decode_head[-1].num_classes
    else:
        num_classes = model.decode_head.num_classes

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]
    # update img_meta
    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=True)
    dynamic_axes = None
    if dynamic_export:
        if test_mode == 'slide':
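            # slide inference runs on fixed-size crops, so only the batch
            # dimension can be dynamic here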
            dynamic_axes = {'input': {0: 'batch'}, 'output': {1: 'batch'}}
        else:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    1: 'batch',
                    2: 'height',
                    3: 'width'
                }
            }

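    # register ONNX symbolic functions for the custom mmcv ops used by the model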
    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=show,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        if dynamic_export and test_mode == 'whole':
            # scale image for dynamic shape test
            img_list = [resize(_, scale_factor=1.5) for _ in img_list]
            # concatenate flipped images for batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

            # update img_meta
            img_list, img_meta_list = _update_input_img(
                img_list, img_meta_list, test_mode == 'whole')

        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            pytorch_result = model(img_list, img_meta_list, return_loss=False)
            pytorch_result = np.stack(pytorch_result, 0)

        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
        # show segmentation results
        if show:
            import os.path as osp

            import cv2
            img = img_meta_list[0][0]['filename']
            if not osp.exists(img):
                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
                img = img.detach().numpy().astype(np.uint8)
                ori_shape = img.shape[:2]
            else:
                ori_shape = LoadImage()({'img': img})['ori_shape']

            # resize onnx_result to ori_shape
            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
                                      (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (onnx_result_, ),
                palette=model.PALETTE,
                block=False,
                title='ONNXRuntime',
                opacity=0.5)

            # resize pytorch_result to ori_shape
            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
                                         (ori_shape[1], ori_shape[0]))
            show_result_pyplot(
                model,
                img, (pytorch_result_, ),
                title='PyTorch',
                palette=model.PALETTE,
                opacity=0.5)
        # compare results
        np.testing.assert_allclose(
            pytorch_result.astype(np.float32) / num_classes,
            onnx_result.astype(np.float32) / num_classes,
            rtol=1e-5,
            atol=1e-5,
            err_msg='The outputs are different between PyTorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument(
        '--input-img', type=str, help='Images for input', default=None)
    parser.add_argument(
        '--show',
        action='store_true',
        help='show onnx graph and segmentation results')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=None,
        help='input image height and width.')
    parser.add_argument(
        '--rescale_shape',
        type=int,
        nargs='+',
        default=None,
        help='output image rescale height and width; only used in slide mode.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export onnx with dynamic axis.')
    args = parser.parse_args()
    return args


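# Example invocation (the config name comes from the MMSegmentation model
# zoo; the checkpoint and output paths are illustrative):
#   python tools/pytorch2onnx.py \
#       configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
#       --checkpoint fcn_r50-d8_512x1024_40k_cityscapes.pth \
#       --output-file fcn.onnx --shape 512 1024 --show --verify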
if __name__ == '__main__':
    args = parse_args()

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None

    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    test_mode = cfg.model.test_cfg.mode

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        checkpoint = load_checkpoint(
            segmentor, args.checkpoint, map_location='cpu')
        segmentor.CLASSES = checkpoint['meta']['CLASSES']
        segmentor.PALETTE = checkpoint['meta']['PALETTE']

    # read input or create dummy input
    if args.input_img is not None:
        preprocess_shape = (input_shape[2], input_shape[3])
        rescale_shape = None
        if args.rescale_shape is not None:
            rescale_shape = [args.rescale_shape[0], args.rescale_shape[1]]
        mm_inputs = _prepare_input_img(
            args.input_img,
            cfg.data.test.pipeline,
            shape=preprocess_shape,
            rescale_shape=rescale_shape)
    else:
        if isinstance(segmentor.decode_head, nn.ModuleList):
            num_classes = segmentor.decode_head[-1].num_classes
        else:
            num_classes = segmentor.decode_head.num_classes
        mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    # convert model to onnx file
    pytorch2onnx(
        segmentor,
        mm_inputs,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        dynamic_export=args.dynamic_export)

    # The following text-style escape sequences are from the colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in the future. '
    msg += blue_text + 'Please use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)