|
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os
- import os.path as osp
- import time
- import warnings
-
- import mmcv
- import torch
- from mmcv import Config, DictAction
- from mmcv.cnn import fuse_conv_bn
- from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
- from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
- wrap_fp16_model)
- from mmdet.apis import multi_gpu_test, single_gpu_test
- from mmdet.datasets import build_dataloader, replace_ImageToTensor
-
- from mmrotate.datasets import build_dataset
- from mmrotate.models import build_detector
- from mmrotate.utils import compat_cfg, setup_multi_processes
-
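# Example arguments for a local (Windows) run. Note that the config and
# checkpoint are passed via the --config / --checkpoint options:
#   --config D:/mmrotate/configs/RotatedCSPDarknet/2-RotatedCSPDarknet_darknet.py
#   --checkpoint D:/mmrotate/train_result/epoch_47.pth
#   --eval mAP
#   --work-dir D:/mmrotate/test_result/
#   --out OUT.pkl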
-
def parse_args():
    """Parse command-line arguments for testing (and evaluating) a detector.

    Returns:
        argparse.Namespace: The parsed arguments. ``args.eval`` defaults to
        ``['mAP']`` unless ``--format-only`` is given, in which case it is
        left as ``None`` so results can be formatted without evaluation.

    Side effects:
        Sets the ``LOCAL_RANK`` environment variable from ``--local_rank``
        when it is not already present.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument(
        '--config',
        default='/code/configs/RotatedCSPDarknet/2-darknet_yoloxpan_KLD.py',
        help='test config file path')
    parser.add_argument(
        '--checkpoint',
        default='/pretrainmodel/epoch_48.pth',
        help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        default='/model/',
        help='the directory to save the file containing evaluation metrics')
    # Per the original author's note, the pickled results are used to
    # generate a confusion matrix.
    parser.add_argument(
        '--out',
        default='/model/out.pkl',
        help='output result file in pickle format')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    # Default is applied after parsing (see below) so that --format-only,
    # which is mutually exclusive with --eval, remains usable.
    parser.add_argument(
        '--eval',
        type=str,
        default=None,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC '
        '(default: mAP, unless --format-only is given)')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Distributed-launch parameter; can be ignored for single-GPU runs.
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # Previously '--eval' defaulted to 'mAP' unconditionally, so main()
    # always rejected '--format-only'. Only apply the default when the user
    # asked neither for explicit metrics nor for format-only output.
    if args.eval is None and not args.format_only:
        args.eval = ['mAP']

    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
-
-
def _validate_args(args):
    """Raise ``ValueError`` if the requested CLI operations are inconsistent."""
    # An explicit raise (instead of assert) keeps the check under `python -O`.
    if not (args.out or args.eval or args.format_only or args.show
            or args.show_dir):
        raise ValueError(
            'Please specify at least one operation (save/eval/format/show the '
            'results / save the results) with the argument "--out", "--eval"'
            ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')


def _clear_pretrained(cfg):
    """Drop pretrained-weight references from the model config in place."""
    # All weights are loaded from the checkpoint in main(), so pretrained
    # initialisation of the backbone / RFP backbones is not needed.
    cfg.model.pretrained = None
    neck = cfg.model.get('neck')
    if not neck:
        return
    necks = neck if isinstance(neck, list) else [neck]
    for neck_cfg in necks:
        rfp_backbone = neck_cfg.get('rfp_backbone')
        if rfp_backbone and rfp_backbone.get('pretrained'):
            rfp_backbone.pretrained = None


def _prepare_test_data(cfg, distributed):
    """Put the test dataset config(s) in test mode and build loader kwargs.

    Moves a deprecated ``samples_per_gpu`` from ``cfg.data.test`` into the
    loader settings and, when batching more than one image per GPU, replaces
    ``ImageToTensor`` in the test pipeline(s).

    Args:
        cfg: The loaded config; ``cfg.data.test`` is modified in place.
        distributed (bool): Whether testing runs distributed.

    Returns:
        dict: kwargs for ``build_dataloader``; entries from
        ``cfg.data.test_dataloader`` override the defaults.
    """
    test_dataloader_default_args = dict(
        samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)

    # The test set is either a single dataset config (dict) or several
    # concatenated dataset configs (list).
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        if 'samples_per_gpu' in cfg.data.test:
            warnings.warn('`samples_per_gpu` in `test` field of '
                          'data will be deprecated, you should'
                          ' move it to `test_dataloader` field')
            test_dataloader_default_args['samples_per_gpu'] = \
                cfg.data.test.pop('samples_per_gpu')
        if test_dataloader_default_args['samples_per_gpu'] > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
            if 'samples_per_gpu' in ds_cfg:
                warnings.warn('`samples_per_gpu` in `test` field of '
                              'data will be deprecated, you should'
                              ' move it to `test_dataloader` field')
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        test_dataloader_default_args['samples_per_gpu'] = samples_per_gpu
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    return {
        **test_dataloader_default_args,
        **cfg.data.get('test_dataloader', {})
    }


def main():
    """Test (and optionally evaluate/format) a detector from the command line.

    Steps:
        1. Parse the CLI arguments and load the config.
        2. Build the test dataset/dataloader and the model from the
           checkpoint.
        3. Run single- or multi-GPU inference.

    On rank 0 it then optionally:
        - dumps raw results to ``--out``;
        - formats them (``--format-only``);
        - evaluates them (``--eval``), writing the metrics to
          ``<work_dir>/eval_<timestamp>.json``.

    Raises:
        ValueError: If the combination of CLI options is invalid.
    """
    args = parse_args()
    _validate_args(args)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    cfg = compat_cfg(cfg)

    # set multi-process settings
    setup_multi_processes(cfg)

    # Let cuDNN benchmark several convolution algorithms and pick the fastest.
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    _clear_pretrained(cfg)

    cfg.gpu_ids = args.gpu_ids if args.gpu_ids is not None else range(1)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
                'non-distribute testing time.')
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    test_loader_cfg = _prepare_test_data(cfg, distributed)

    rank, _ = get_dist_info()
    # Only rank 0 writes files, so only it creates the work dir.
    json_file = None
    if args.work_dir is not None and rank == 0:
        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, **test_loader_cfg)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = dataset.CLASSES

    if not distributed:
        # Non-distributed: cfg.gpu_ids was reduced to a single GPU above.
        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                  args.show_score_thr)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False)
        outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                 args.gpu_collect)

    rank, _ = get_dist_info()
    if rank != 0:
        return

    if args.out:
        print(f'\nwriting results to {args.out}')
        # --out may point outside --work-dir; make sure its folder exists.
        mmcv.mkdir_or_exist(osp.dirname(osp.abspath(args.out)))
        mmcv.dump(outputs, args.out)
    kwargs = {} if args.eval_options is None else args.eval_options
    if args.format_only:
        dataset.format_results(outputs, **kwargs)
    if args.eval:
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-code way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule', 'dynamic_intervals'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        metric = dataset.evaluate(outputs, **eval_kwargs)
        print(metric)
        metric_dict = dict(config=args.config, metric=metric)
        # json_file is only set when --work-dir was given (and rank == 0).
        if json_file is not None:
            mmcv.dump(metric_dict, json_file)
-
-
if __name__ == "__main__":
    # Execute the test script only when run directly, not when imported.
    main()
|