|
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os
-
- import mmcv
- import numpy as np
- import torch
- from mmcv.parallel import MMDataParallel
- from mmcv.parallel.scatter_gather import scatter_kwargs
- from mmcv.runner import load_checkpoint, wrap_fp16_model
- from PIL import Image
-
- from mmseg.datasets import build_dataloader, build_dataset
- from mmseg.models import build_segmentor
-
-
@torch.no_grad()
def main(args):
    """Ensemble several segmentors by summing their logits and save masks.

    The test dataset/dataloader is built from the FIRST config only; every
    model then runs on the same batches, the per-model logits are summed,
    and the argmax over axis 1 is written as a uint8 PNG-style image named
    after the sample's ``seg_map`` file into ``args.out``.

    Args:
        args (argparse.Namespace): Parsed CLI options. Reads ``config`` and
            ``checkpoint`` (paired lists, one entry per model), ``gpus``
            (GPU ids, assigned to the models round-robin), ``aug_test``
            (multi-scale + flip testing when true) and ``out`` (output
            directory for the predicted masks).
    """
    gpu_ids = args.gpus
    configs = args.config
    ckpts = args.checkpoint

    # The output directory does not depend on the model: create it once.
    out_dir = args.out
    mmcv.mkdir_or_exist(out_dir)

    cfg = mmcv.Config.fromfile(configs[0])

    # NOTE(review): assumes pipeline[1] is the test-time augmentation
    # wrapper (the step exposing ``img_ratios``/``flip``) — TODO confirm
    # for every config used with this script.
    if args.aug_test:
        cfg.data.test.pipeline[1].img_ratios = [
            0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0
        ]
        cfg.data.test.pipeline[1].flip = True
    else:
        cfg.data.test.pipeline[1].img_ratios = [1.0]
        cfg.data.test.pipeline[1].flip = False
    # test_mode must be set on the config that actually builds the dataset;
    # setting it on the per-model configs below had no effect.
    cfg.data.test.test_mode = True

    torch.backends.cudnn.benchmark = True

    # build the dataloader (one sample per batch, fixed order)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        dist=False,
        shuffle=False,
    )

    models = []
    for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
        model_cfg = mmcv.Config.fromfile(config)
        model_cfg.model.pretrained = None

        model = build_segmentor(
            model_cfg.model, test_cfg=model_cfg.get('test_cfg'))
        if model_cfg.get('fp16', None):
            wrap_fp16_model(model)
        load_checkpoint(model, ckpt, map_location='cpu')
        torch.cuda.empty_cache()
        # Spread the models over the requested GPUs round-robin.
        model = MMDataParallel(
            model, device_ids=[gpu_ids[idx % len(gpu_ids)]])
        model.eval()
        models.append(model)

    prog_bar = mmcv.ProgressBar(len(dataset))
    # With shuffle=False, walking the batch sampler in lockstep with the
    # loader gives the dataset index of each loaded batch.
    for batch_indices, data in zip(data_loader.batch_sampler, data_loader):
        result = []
        for model in models:
            # Move the batch onto this model's device before inference.
            x, _ = scatter_kwargs(
                inputs=data, kwargs=None, target_gpus=model.device_ids)
            if args.aug_test:
                logits = model.module.aug_test_logits(**x[0])
            else:
                logits = model.module.simple_test_logits(**x[0])
            result.append(logits)

        # Ensemble = element-wise sum of the per-model logits.
        # NOTE(review): assumes every model returns same-shaped array logits
        # (``.astype`` is used below) over the same class set — TODO confirm.
        result_logits = sum(result)

        pred = result_logits.argmax(axis=1).squeeze()
        img_info = dataset.img_infos[batch_indices[0]]
        file_name = os.path.join(
            out_dir, os.path.basename(img_info['ann']['seg_map']))
        Image.fromarray(pred.astype(np.uint8)).save(file_name)
        prog_bar.update()
-
-
def parse_args(argv=None):
    """Parse and validate the command-line options of the ensemble script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace: ``config`` and ``checkpoint`` (equal-length
        lists of paths), ``aug_test`` (bool), ``out`` (str, default
        ``'results'``) and ``gpus`` (list of int, default ``[0]``).

    Raises:
        SystemExit: raised by ``parser.error`` (exit status 2, usage printed
            to stderr) when ``--config``/``--checkpoint`` is missing, their
            counts differ, or ``--out`` is empty.
    """
    parser = argparse.ArgumentParser(
        description='Model Ensemble with logits result')
    # Both lists are mandatory: without them the later length check would
    # crash on len(None) instead of printing a usage error.
    parser.add_argument(
        '--config',
        type=str,
        nargs='+',
        required=True,
        help='ensemble config files path')
    parser.add_argument(
        '--checkpoint',
        type=str,
        nargs='+',
        required=True,
        help='ensemble checkpoint files path')
    parser.add_argument(
        '--aug-test',
        action='store_true',
        help='control ensemble aug-result or single-result (default)')
    parser.add_argument(
        '--out', type=str, default='results', help='the dir to save result')
    parser.add_argument(
        '--gpus', type=int, nargs='+', default=[0], help='id of gpu to use')

    args = parser.parse_args(argv)
    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O`` and produce a traceback instead of a usage message.
    if len(args.config) != len(args.checkpoint):
        parser.error(
            'len(config) must equal len(checkpoint), '
            f'but len(config) = {len(args.config)} and '
            f'len(checkpoint) = {len(args.checkpoint)}')
    if not args.out:
        parser.error("ensemble result out-dir can't be None")
    return args
-
-
if __name__ == '__main__':
    # Script entry point: parse the command line, then run the ensemble.
    main(parse_args())
|