|
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import copy
- import os
- import os.path as osp
- import warnings
- from typing import Iterable, Optional, Union
-
- import matplotlib.pyplot as plt
- import mmcv
- import numpy as np
- import onnxruntime as ort
- import torch
- from mmcv.ops import get_onnxruntime_op_path
- from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
-                            save_trt_engine)
-
- from mmseg.apis.inference import LoadImage
- from mmseg.datasets import DATASETS
- from mmseg.datasets.pipelines import Compose
-
-
- def get_GiB(x: int):
-     """Convert a size given in gibibytes into a number of bytes.
-
-     Args:
-         x (int): Size in GiB.
-
-     Returns:
-         ``x`` multiplied by ``2**30``.
-     """
-     bytes_per_gib = 1024**3
-     return x * bytes_per_gib
-
-
- def _prepare_input_img(img_path: str,
-                        test_pipeline: Iterable[dict],
-                        shape: Optional[Iterable] = None,
-                        rescale_shape: Optional[Iterable] = None) -> dict:
-     """Run a single image through the model's test pipeline.
-
-     The given ``test_pipeline`` is deep-copied before it is modified, so the
-     caller's config object is left untouched.
-
-     Args:
-         img_path (str): Path of the image to load.
-         test_pipeline (list[dict]): Test pipeline from the model config. Its
-             first entry is replaced by ``LoadImage``. When ``shape`` is
-             given, entry 1 is expected to hold ``img_scale`` and a
-             ``transforms`` list whose first item accepts ``keep_ratio``
-             (presumably MultiScaleFlipAug wrapping Resize - confirm against
-             the config).
-         shape (Iterable, optional): Target ``(h, w)``. When given, the image
-             is resized to exactly this size without keeping aspect ratio.
-         rescale_shape (Iterable, optional): ``(h, w)`` written into every
-             meta's ``ori_shape`` with a channel count of 3 appended.
-
-     Returns:
-         dict: ``{'imgs': ..., 'img_metas': [...]}`` taken from the pipeline
-         output.
-     """
-     # Work on a private copy: the indexed assignments below would otherwise
-     # overwrite the pipeline stored in the caller's config.
-     test_pipeline = copy.deepcopy(test_pipeline)
-     if shape is not None:
-         # img_scale is given as (w, h), hence the swap.
-         test_pipeline[1]['img_scale'] = (shape[1], shape[0])
-         test_pipeline[1]['transforms'][0]['keep_ratio'] = False
-     # Replace the config's first step (presumably file loading) with
-     # LoadImage, which is fed ``dict(img=img_path)`` below.
-     pipeline = Compose([LoadImage()] + test_pipeline[1:])
-     data = pipeline(dict(img=img_path))
-     imgs = data['img']
-     img_metas = [i.data for i in data['img_metas']]
-
-     if rescale_shape is not None:
-         for img_meta in img_metas:
-             img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
-
-     return {'imgs': imgs, 'img_metas': img_metas}
-
-
- def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
-     """Build a fresh meta list describing every sample of the first batch.
-
-     Only the first meta of the first group is consulted. Its shapes and
-     filename are replicated once per sample of ``img_list[0]``. The images
-     themselves are returned unchanged.
-
-     Args:
-         img_list (list[torch.Tensor]): Batched images.
-             ``img_list[0].size(0)`` gives the number of samples.
-         img_meta_list (list[list[dict]]): Nested metas. The dict at
-             ``[0][0]`` must provide ``img_shape``, ``ori_shape``,
-             ``pad_shape`` and ``filename``.
-
-     Returns:
-         tuple: ``img_list`` as given, and ``[metas]`` where ``metas`` holds
-         one dict per sample, with ``scale_factor`` equal to
-         ``(w_scale, h_scale, w_scale, h_scale)`` and ``flip`` set to False.
-     """
-     batch_size = img_list[0].size(0)
-     template = img_meta_list[0][0]
-     img_shape = template['img_shape']
-     ori_shape = template['ori_shape']
-     pad_shape = template['pad_shape']
-
-     per_sample = []
-     for _ in range(batch_size):
-         filename = template['filename']
-         w_scale = img_shape[1] / ori_shape[1]
-         h_scale = img_shape[0] / ori_shape[0]
-         per_sample.append(
-             dict(
-                 img_shape=img_shape,
-                 ori_shape=ori_shape,
-                 pad_shape=pad_shape,
-                 filename=filename,
-                 scale_factor=(w_scale, h_scale, w_scale, h_scale),
-                 flip=False))
-     return img_list, [per_sample]
-
-
- def show_result_pyplot(img: Union[str, np.ndarray],
-                        result: np.ndarray,
-                        palette: Optional[Iterable] = None,
-                        fig_size: Iterable[int] = (15, 10),
-                        opacity: float = 0.5,
-                        title: str = '',
-                        block: bool = True):
-     """Overlay a segmentation label map on an image and show it with pyplot.
-
-     Args:
-         img (str | np.ndarray): Image path or BGR image array, read through
-             ``mmcv.imread``.
-         result (Sequence[np.ndarray]): Only ``result[0]`` is used. It is an
-             integer label map that is resized to the image size.
-         palette (Iterable, optional): ``(K, 3)`` RGB colour for each label.
-             When None, a fixed-seed random palette covering every label
-             value present in the map is generated.
-         fig_size (Iterable[int]): Matplotlib figure size.
-         opacity (float): Blend weight of the colour map, in ``(0, 1]``.
-         title (str): Figure title.
-         block (bool): Forwarded to ``plt.show``.
-
-     Raises:
-         AssertionError: If ``palette`` is not of shape ``(K, 3)`` or
-             ``opacity`` is out of range.
-     """
-     img = mmcv.imread(img).copy()
-     seg = result[0]
-     # A label map must be resized with nearest-neighbour interpolation: the
-     # default bilinear mode blends neighbouring class ids at region borders
-     # and produces labels that do not belong to either region.
-     seg = mmcv.imresize(seg, img.shape[:2][::-1], interpolation='nearest')
-     if palette is None:
-         # Deterministic colours, without touching numpy's global RNG state.
-         num_labels = int(seg.max()) + 1
-         palette = np.random.RandomState(42).randint(
-             0, 255, size=(num_labels, 3))
-     palette = np.array(palette)
-     # Check the rank first so that a 1-D palette trips the assert instead of
-     # raising IndexError on ``shape[1]``.
-     assert len(palette.shape) == 2
-     assert palette.shape[1] == 3
-     assert 0 < opacity <= 1.0
-
-     color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
-     for label, color in enumerate(palette):
-         color_seg[seg == label, :] = color
-     # Palette colours are RGB; reverse the channels to match the BGR image.
-     color_seg = color_seg[..., ::-1]
-
-     img = (img * (1 - opacity) + color_seg * opacity).astype(np.uint8)
-
-     plt.figure(figsize=fig_size)
-     plt.imshow(mmcv.bgr2rgb(img))
-     plt.title(title)
-     plt.tight_layout()
-     plt.show(block=block)
-
-
- def onnx2tensorrt(onnx_file: str,
-                   trt_file: str,
-                   config: dict,
-                   input_config: dict,
-                   fp16: bool = False,
-                   verify: bool = False,
-                   show: bool = False,
-                   dataset: str = 'CityscapesDataset',
-                   workspace_size: int = 1,
-                   verbose: bool = False):
-     """Build a TensorRT engine from an ONNX model and optionally verify it.
-
-     Args:
-         onnx_file (str): Path of the input ONNX model.
-         trt_file (str): Output path of the serialized engine. Its directory
-             is created if needed.
-         config (mmcv.Config): Model config. ``config.data.test.pipeline`` is
-             used to preprocess the verification image.
-         input_config (dict): ``min_shape`` and ``max_shape`` of the dynamic
-             input (NCHW lists), plus ``input_path`` of the test image.
-         fp16 (bool): Build the engine in FP16 mode.
-         verify (bool): Compare TensorRT outputs with ONNXRuntime on CPU.
-         show (bool): When verifying, also plot both results.
-         dataset (str): Registered dataset name whose ``PALETTE`` is used for
-             plotting.
-         workspace_size (int): Max TensorRT workspace size in GiB.
-         verbose (bool): Use TensorRT's VERBOSE logger instead of ERROR.
-
-     Raises:
-         AssertionError: If ``dataset`` is not registered (only checked when
-             ``show`` is set), or if the two backends' outputs differ beyond
-             tolerance.
-     """
-     import tensorrt as trt
-     min_shape = input_config['min_shape']
-     max_shape = input_config['max_shape']
-     # [min, opt, max] profile: the optimal shape is pinned to min_shape,
-     # which is also the resolution used for verification below.
-     opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
-     max_workspace_size = get_GiB(workspace_size)
-     trt_engine = onnx2trt(
-         onnx_file,
-         opt_shape_dict,
-         log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
-         fp16_mode=fp16,
-         max_workspace_size=max_workspace_size)
-     save_dir, _ = osp.split(trt_file)
-     if save_dir:
-         os.makedirs(save_dir, exist_ok=True)
-     save_trt_engine(trt_engine, trt_file)
-     print(f'Successfully created TensorRT engine: {trt_file}')
-
-     if not verify:
-         return
-
-     inputs = _prepare_input_img(
-         input_config['input_path'],
-         config.data.test.pipeline,
-         shape=min_shape[2:])
-
-     imgs = inputs['imgs']
-     img_metas = inputs['img_metas']
-     img_list = [img[None, :] for img in imgs]
-     img_meta_list = [[img_meta] for img_meta in img_metas]
-     # update img_meta
-     img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
-
-     if max_shape[0] > 1:
-         # Append a copy flipped along the last (width) axis so that a batch
-         # of two is exercised.
-         flip_img_list = [_.flip(-1) for _ in img_list]
-         img_list = [
-             torch.cat((ori_img, flip_img), 0)
-             for ori_img, flip_img in zip(img_list, flip_img_list)
-         ]
-
-     # Get results from ONNXRuntime on CPU. The providers are passed to the
-     # constructor: since onnxruntime 1.9, builds with several execution
-     # providers raise when a session is created without an explicit list,
-     # so a set_providers() call after construction is never reached there.
-     ort_custom_op_path = get_onnxruntime_op_path()
-     session_options = ort.SessionOptions()
-     if osp.exists(ort_custom_op_path):
-         session_options.register_custom_ops_library(ort_custom_op_path)
-     sess = ort.InferenceSession(
-         onnx_file, session_options, providers=['CPUExecutionProvider'])
-     onnx_output = sess.run(['output'],
-                            {'input': img_list[0].detach().numpy()})[0][0]
-
-     # Get results from TensorRT
-     trt_model = TRTWraper(trt_file, ['input'], ['output'])
-     with torch.no_grad():
-         trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
-     trt_output = trt_outputs['output'][0].cpu().detach().numpy()
-
-     if show:
-         dataset_cls = DATASETS.get(dataset)
-         assert dataset_cls is not None
-         palette = dataset_cls.PALETTE
-
-         show_result_pyplot(
-             input_config['input_path'],
-             (onnx_output[0].astype(np.uint8), ),
-             palette=palette,
-             title='ONNXRuntime',
-             block=False)
-         show_result_pyplot(
-             input_config['input_path'], (trt_output[0].astype(np.uint8), ),
-             palette=palette,
-             title='TensorRT')
-
-     np.testing.assert_allclose(
-         onnx_output, trt_output, rtol=1e-03, atol=1e-05)
-     print('TensorRT and ONNXRuntime output all close.')
-
-
- def parse_args():
-     """Parse command-line options for the ONNX to TensorRT conversion.
-
-     ``--trt-file`` defaults to the ONNX model path with its extension
-     replaced by ``.trt``, so the output path is never ``None``.
-
-     Returns:
-         argparse.Namespace: Parsed arguments.
-     """
-     parser = argparse.ArgumentParser(
-         description='Convert MMSegmentation models from ONNX to TensorRT')
-     parser.add_argument('config', help='Config file of the model')
-     parser.add_argument('model', help='Path to the input ONNX model')
-     parser.add_argument(
-         '--trt-file',
-         type=str,
-         default=None,
-         help='Path to the output TensorRT engine. Defaults to the ONNX '
-         'model path with a .trt extension')
-     parser.add_argument(
-         '--max-shape',
-         type=int,
-         nargs=4,
-         default=[1, 3, 400, 600],
-         help='Maximum shape of model input.')
-     parser.add_argument(
-         '--min-shape',
-         type=int,
-         nargs=4,
-         default=[1, 3, 400, 600],
-         help='Minimum shape of model input.')
-     parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
-     parser.add_argument(
-         '--workspace-size',
-         type=int,
-         default=1,
-         help='Max workspace size in GiB')
-     parser.add_argument(
-         '--input-img', type=str, default='', help='Image for test')
-     parser.add_argument(
-         '--show', action='store_true', help='Whether to show output results')
-     parser.add_argument(
-         '--dataset',
-         type=str,
-         default='CityscapesDataset',
-         help='Dataset name')
-     parser.add_argument(
-         '--verify',
-         action='store_true',
-         help='Verify the outputs of ONNXRuntime and TensorRT')
-     parser.add_argument(
-         '--verbose',
-         action='store_true',
-         help='Whether to print verbose logging messages while creating '
-         'TensorRT engine.')
-     args = parser.parse_args()
-     if not args.trt_file:
-         # Without this, the conversion would fail on osp.split(None).
-         args.trt_file = osp.splitext(args.model)[0] + '.trt'
-     return args
-
-
- if __name__ == '__main__':
-
-     # Following strings of text style are from colorama package
-     bright_style, reset_style = '\x1b[1m', '\x1b[0m'
-     red_text, blue_text = '\x1b[31m', '\x1b[34m'
-     white_background = '\x1b[107m'
-
-     msg = white_background + bright_style + red_text
-     msg += 'DeprecationWarning: This tool will be deprecated in future. '
-     msg += blue_text + 'Welcome to use the unified model deployment toolbox '
-     msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
-     msg += reset_style
-     # Warn up front so the notice is shown even when conversion fails below.
-     warnings.warn(msg)
-
-     args = parse_args()
-     # Checked after argument parsing so that --help works without the plugin.
-     assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
-
-     if not args.input_img:
-         args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
-
-     # check arguments
-     assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
-     assert osp.exists(args.model), \
-         'ONNX model {} not found.'.format(args.model)
-     if args.verify:
-         # The image is only read when verifying; fail early with a clear
-         # message rather than later inside the data pipeline.
-         assert osp.exists(args.input_img), \
-             'Input image {} not found.'.format(args.input_img)
-     assert args.workspace_size >= 0, 'Workspace size less than 0.'
-     assert DATASETS.get(args.dataset) is not None, \
-         'Dataset {} not found.'.format(args.dataset)
-     for max_value, min_value in zip(args.max_shape, args.min_shape):
-         assert max_value >= min_value, \
-             'max_shape should be larger than min shape'
-
-     input_config = {
-         'min_shape': args.min_shape,
-         'max_shape': args.max_shape,
-         'input_path': args.input_img
-     }
-
-     cfg = mmcv.Config.fromfile(args.config)
-     onnx2tensorrt(
-         args.model,
-         args.trt_file,
-         cfg,
-         input_config,
-         fp16=args.fp16,
-         verify=args.verify,
-         show=args.show,
-         dataset=args.dataset,
-         workspace_size=args.workspace_size,
-         verbose=args.verbose)
|