|
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os
- import warnings
- from pathlib import Path
-
- import mmcv
- import numpy as np
- from mmcv import Config, DictAction
-
- from mmseg.datasets.builder import build_dataset
-
-
def parse_args():
    """Parse the command-line arguments of the dataset browser.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Browse a dataset')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--show-origin',
        default=False,
        action='store_true',
        help='if True, omit all augmentation in pipeline,'
        ' show the original image and seg map')
    parser.add_argument(
        '--skip-type',
        type=str,
        nargs='+',
        default=['DefaultFormatBundle', 'Normalize', 'Collect'],
        help='skip some useless pipelines. If `--show-origin` is set, '
        'all pipelines except `Load*` will be skipped')
    parser.add_argument(
        '--output-dir',
        default='./output',
        type=str,
        help='directory where the visualized images are saved (useful '
        'when there is no display interface)')
    parser.add_argument(
        '--show',
        default=False,
        action='store_true',
        help='display each visualized image in a window')
    parser.add_argument(
        '--show-interval',
        type=int,
        default=999,
        help='the interval of show (ms)')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='the opacity of semantic map, in (0, 1]')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args
-
-
def imshow_semantic(img,
                    seg,
                    class_names,
                    palette=None,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    opacity=0.5):
    """Draw the semantic segmentation map `seg` over `img`.

    Args:
        img (str | np.ndarray): The image, or a path to it, to draw on.
        seg (np.ndarray): 2-D label map to draw over `img`; each pixel
            value is a class index into `palette`.
        class_names (list[str]): Names of each class.
        palette (list[list[int]] | np.ndarray | None): RGB palette of the
            segmentation map, one row per class. If None is given, a
            random palette is generated. Default: None.
        win_name (str): The window name. Default: ''.
        show (bool): Whether to display the image in a window.
            Default: False.
        wait_time (int): Value of waitKey param. Default: 0.
        out_file (str | None): The filename to write the image to.
            Default: None.
        opacity (float): Opacity of the painted segmentation map.
            Must be in (0, 1] range. Default: 0.5.

    Returns:
        np.ndarray: The blended image. It is always returned, including
        when it is also shown and/or written to `out_file`.
    """
    # copy so the caller's array is never modified in place
    img = mmcv.imread(img).copy()
    if palette is None:
        # randint's upper bound is exclusive; 256 covers the full 0-255 range
        palette = np.random.randint(0, 256, size=(len(class_names), 3))
    palette = np.array(palette)
    # check the rank first so a 1-D palette fails the assert instead of
    # raising IndexError on shape[1]
    assert palette.ndim == 2
    assert palette.shape[0] == len(class_names)
    assert palette.shape[1] == 3
    assert 0 < opacity <= 1.0

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # palette is RGB; convert to BGR to match the image
    color_seg = color_seg[..., ::-1]

    img = (img * (1 - opacity) + color_seg * opacity).astype(np.uint8)

    # Showing and saving are independent. `main` always passes an out_file
    # (--output-dir defaults to './output'), so letting out_file suppress
    # `show` would make --show a no-op.
    if show:
        mmcv.imshow(img, win_name, wait_time)
    if out_file is not None:
        mmcv.imwrite(img, out_file)

    if not (show or out_file):
        warnings.warn('show==False and out_file is not specified, only '
                      'result image will be returned')
    return img
-
-
- def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
- if show_origin is True:
- # only keep pipeline of Loading data and ann
- _data_cfg['pipeline'] = [
- x for x in _data_cfg.pipeline if 'Load' in x['type']
- ]
- else:
- _data_cfg['pipeline'] = [
- x for x in _data_cfg.pipeline if x['type'] not in skip_type
- ]
-
-
def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
    """Load a config and filter the pipeline(s) of its training data.

    Args:
        config_path (str): Path to the config file.
        skip_type (list[str]): Transform types to remove from the pipeline.
        cfg_options (dict | None): Overrides merged into the loaded config.
        show_origin (bool): If True, keep only the loading transforms.
            Default: False.

    Returns:
        Config: The loaded config; the pipeline(s) under ``data.train``
        have been filtered in place.

    Raises:
        ValueError: If a (unwrapped) train dataset config has no
            ``'pipeline'`` key.
    """
    cfg = Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    train_data_cfg = cfg.data.train
    # treat a single train config like a one-element list so both cases
    # share the same unwrapping and validation
    if isinstance(train_data_cfg, list):
        data_cfgs = train_data_cfg
    else:
        data_cfgs = [train_data_cfg]
    for data_cfg in data_cfgs:
        # descend through wrapper configs that nest the real dataset under
        # 'dataset'; MultiImageMixDataset is deliberately not unwrapped
        # (presumably because it owns its own pipeline)
        while ('dataset' in data_cfg
               and data_cfg['type'] != 'MultiImageMixDataset'):
            data_cfg = data_cfg['dataset']
        if 'pipeline' not in data_cfg:
            raise ValueError(
                'train dataset config of type {!r} has no "pipeline" to '
                'filter'.format(data_cfg.get('type')))
        _retrieve_data_cfg(data_cfg, skip_type, show_origin)
    return cfg
-
-
def main():
    """Visualize every sample of the training dataset described by a config.

    Each sample's image is blended with its ``gt_semantic_seg`` map and,
    depending on the command-line flags, shown in a window and/or written
    to ``--output-dir`` under the image's original file name.
    """
    args = parse_args()
    cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
                            args.show_origin)
    dataset = build_dataset(cfg.data.train)

    palette = dataset.PALETTE
    if palette is None:
        # Draw one palette for the whole run. Passing None would make
        # imshow_semantic generate a new random palette for every image,
        # so the same class would get a different colour in each sample.
        palette = np.random.randint(0, 256, size=(len(dataset.CLASSES), 3))

    progress_bar = mmcv.ProgressBar(len(dataset))
    for item in dataset:
        if args.output_dir is not None:
            filename = os.path.join(args.output_dir,
                                    Path(item['filename']).name)
        else:
            filename = None
        imshow_semantic(
            item['img'],
            item['gt_semantic_seg'],
            dataset.CLASSES,
            palette,
            show=args.show,
            wait_time=args.show_interval,
            out_file=filename,
            opacity=args.opacity,
        )
        progress_bar.update()


if __name__ == '__main__':
    main()
|