|
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import copy
- import os
- import os.path as osp
- import time
- import warnings
-
- import mmcv
- import torch
- import torch.distributed as dist
- from mmcv import Config, DictAction
- from mmcv.runner import get_dist_info, init_dist
- from mmcv.utils import get_git_hash
- from mmdet import __version__
- from mmdet.apis import init_random_seed, set_random_seed
-
- from mmrotate.apis import train_detector
- from mmrotate.datasets import build_dataset
- from mmrotate.models import build_detector
- from mmrotate.utils import collect_env, get_root_logger, setup_multi_processes
-
# Project code lives in /code, datasets in /dataset, the selected pretrained
# model in /pretrainmodel; write debug/training output to /model so it can be
# downloaded afterwards.
# python train.py configs/RotatedCSPDarknet/0-RotatedDarknet_darknet.py --work-dir /model/ --resume-from /pretrainmodel/epoch_12.pth
-
def parse_args():
    """Build and parse the command-line arguments for training.

    Returns:
        argparse.Namespace: the parsed arguments. As a side effect, the
        ``LOCAL_RANK`` environment variable is populated from
        ``--local_rank`` when not already set (needed by the PyTorch
        distributed launcher).
    """
    ap = argparse.ArgumentParser(description='Train a detector')
    ap.add_argument('--config',
                    default='/code/configs/RotatedCSPDarknet/0-p5-small_cspdarknet_KLD_yoloxpan.py',
                    help='train config file path')
    ap.add_argument('--work-dir',
                    default='/model/',
                    help='the dir to save logs and models')
    # Checkpoint to resume from (optional).
    ap.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    ap.add_argument(
        '--auto-resume',
        action='store_true',
        help='resume from the latest checkpoint automatically')
    ap.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # --gpus and --gpu-ids are mutually exclusive ways to select devices.
    gpu_group = ap.add_mutually_exclusive_group()
    gpu_group.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    gpu_group.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    ap.add_argument('--seed', type=int, default=None, help='random seed')
    ap.add_argument(
        '--diff-seed',
        action='store_true',
        help='Whether or not set different seeds for different ranks')
    ap.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    # key=value overrides merged into the loaded config (mmcv DictAction).
    ap.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    ap.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    ap.add_argument('--local_rank', type=int, default=0)
    args = ap.parse_args()
    # Only set LOCAL_RANK when the launcher has not exported it already.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))

    return args
-
-
def main():
    """Train a rotated detector: load the config, set up the environment
    (distributed mode, logging, random seeds), build the model and
    dataset(s), then hand off to ``train_detector``.
    """
    args = parse_args()

    # Load the base config and merge any --cfg-options CLI overrides on top.
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set multi-process settings (OpenMP/MKL thread caps etc., per mmrotate)
    setup_multi_processes(cfg)

    # set cudnn_benchmark — lets cuDNN pick the fastest conv algorithms;
    # only beneficial when input shapes are fixed.
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        # Resume from an explicitly given checkpoint file.
        cfg.resume_from = args.resume_from
    cfg.auto_resume = args.auto_resume
    if args.gpu_ids is not None:
        # Explicit GPU ids take precedence over a plain GPU count.
        cfg.gpu_ids = args.gpu_ids
    else:
        # Default to a single GPU when neither --gpus nor --gpu-ids is given.
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        # No launcher specified: run non-distributed on one GPU.
        distributed = False
        if len(cfg.gpu_ids) > 1:
            warnings.warn(
                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
                'non-distribute training time.')
            cfg.gpu_ids = cfg.gpu_ids[0:1]
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config (a copy of the config is saved as work_dir/<config name>)
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info
    meta['config'] = cfg.pretty_text
    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds; with --diff-seed each rank gets a distinct seed
    # (rank offset), otherwise all ranks share the same one.
    seed = init_random_seed(args.seed)
    seed = seed + dist.get_rank() if args.diff_seed else seed
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    # Build the detector from the config; see mmrotate/apis/train.py for
    # how train_cfg/test_cfg are consumed downstream.
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # Build the training dataset(s).
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # A 2-stage workflow (train + val): validation data reuses the
        # training pipeline so both stages see identically processed input.
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=__version__ + get_git_hash()[:7],
            CLASSES=datasets[0].CLASSES)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    train_detector(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)
-
-
if __name__ == '__main__':
    # Run training only when executed as a script, not on import.
    main()
|