|
- """ Model training pipeline """
- import logging
- import os
- import moxing as mox
- import time
- import json
-
-
- import mindspore as ms
- from mindspore import Tensor, context
- from mindspore.communication import get_group_size, get_rank, init
-
- from mindspore.context import ParallelMode
-
- from mindcv.data import create_dataset, create_loader, create_transforms
- from mindcv.loss import create_loss
- from mindcv.models import create_model
- from mindcv.optim import create_optimizer
- from mindcv.scheduler import create_scheduler
- from mindcv.utils import (
- AllReduceSum,
- StateMonitor,
- create_trainer,
- get_metrics,
- require_customized_train_step,
- set_logger,
- set_seed,
- )
-
- from config import parse_args, save_args # isort: skip
-
- logger = logging.getLogger("mindcv.train")
-
-
def train(args):
    """Main train function.

    Builds the training (and optional validation) pipeline from the parsed
    configuration: dataset -> transforms -> loader -> model -> loss ->
    LR scheduler -> optimizer -> trainer, then runs ``trainer.train`` with a
    ``StateMonitor`` callback that handles logging, checkpointing and
    (optionally) validation during training.

    Args:
        args: parsed configuration namespace produced by ``parse_args()``.
    """

    ms.set_context(mode=args.mode)
    if args.distribute:
        # Initialize the communication backend and configure data-parallel
        # training across all visible devices.
        init()
        device_num = get_group_size()
        rank_id = get_rank()
        ms.set_auto_parallel_context(
            device_num=device_num,
            parallel_mode="data_parallel",
            gradients_mean=True,
            # we should but cannot set parameter_broadcast=True, which will cause error on gpu.
        )
    else:
        # Single-device run: no sharding, no rank.
        device_num = None
        rank_id = None

    set_seed(args.seed)
    set_logger(name="mindcv", output_dir=args.ckpt_save_dir, rank=rank_id, color=False)
    logger.info(
        "We recommend installing `termcolor` via `pip install termcolor` "
        "and setup logger by `set_logger(..., color=True)`"
    )

    # create dataset (sharded across devices when distributed)
    dataset_train = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split=args.train_split,
        shuffle=args.shuffle,
        num_samples=args.num_samples,
        num_shards=device_num,
        shard_id=rank_id,
        num_parallel_workers=args.num_parallel_workers,
        download=args.dataset_download,
        num_aug_repeats=args.aug_repeats,
    )

    # Infer the number of classes from the dataset unless given explicitly.
    if args.num_classes is None:
        num_classes = dataset_train.num_classes()
    else:
        num_classes = args.num_classes

    # create transforms (training augmentation pipeline)
    transform_list = create_transforms(
        dataset_name=args.dataset,
        is_training=True,
        image_resize=args.image_resize,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        interpolation=args.interpolation,
        auto_augment=args.auto_augment,
        mean=args.mean,
        std=args.std,
        re_prob=args.re_prob,
        re_scale=args.re_scale,
        re_ratio=args.re_ratio,
        re_value=args.re_value,
        re_max_attempts=args.re_max_attempts,
    )

    # load dataset (batching plus MixUp/CutMix applied at the loader level)
    loader_train = create_loader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        drop_remainder=args.drop_remainder,
        is_training=True,
        mixup=args.mixup,
        cutmix=args.cutmix,
        cutmix_prob=args.cutmix_prob,
        num_classes=num_classes,
        transform=transform_list,
        num_parallel_workers=args.num_parallel_workers,
    )

    if args.val_while_train:
        # Build a separate eval dataset/loader so validation can run
        # periodically during training.
        dataset_eval = create_dataset(
            name=args.dataset,
            root=args.data_dir,
            split=args.val_split,
            num_shards=device_num,
            shard_id=rank_id,
            num_parallel_workers=args.num_parallel_workers,
            download=args.dataset_download,
        )

        transform_list_eval = create_transforms(
            dataset_name=args.dataset,
            is_training=False,
            image_resize=args.image_resize,
            crop_pct=args.crop_pct,
            interpolation=args.interpolation,
            mean=args.mean,
            std=args.std,
        )

        loader_eval = create_loader(
            dataset=dataset_eval,
            batch_size=args.batch_size,
            drop_remainder=False,
            is_training=False,
            transform=transform_list_eval,
            num_parallel_workers=args.num_parallel_workers,
        )
        # validation dataset count
        eval_count = dataset_eval.get_dataset_size()
        if args.distribute:
            # Each rank only sees its shard; sum the per-rank counts to
            # report the global validation sample count.
            all_reduce = AllReduceSum()
            eval_count = all_reduce(Tensor(eval_count, ms.int32))
    else:
        loader_eval = None
        eval_count = None

    num_batches = loader_train.get_dataset_size()
    # Train dataset count
    train_count = dataset_train.get_dataset_size()
    if args.distribute:
        all_reduce = AllReduceSum()
        train_count = all_reduce(Tensor(train_count, ms.int32))

    # create model
    network = create_model(
        model_name=args.model,
        num_classes=num_classes,
        in_channels=args.in_channels,
        drop_rate=args.drop_rate,
        drop_path_rate=args.drop_path_rate,
        pretrained=args.pretrained,
        checkpoint_path=args.ckpt_path,
        ema=args.ema,
    )

    # Total parameter count, for the configuration summary below.
    num_params = sum([param.size for param in network.get_parameters()])

    # create loss
    loss = create_loss(
        name=args.loss,
        reduction=args.reduction,
        label_smoothing=args.label_smoothing,
        aux_factor=args.aux_factor,
    )

    # create learning rate schedule
    lr_scheduler = create_scheduler(
        num_batches,
        scheduler=args.scheduler,
        lr=args.lr,
        min_lr=args.min_lr,
        warmup_epochs=args.warmup_epochs,
        warmup_factor=args.warmup_factor,
        decay_epochs=args.decay_epochs,
        decay_rate=args.decay_rate,
        milestones=args.multi_step_decay_milestones,
        num_epochs=args.epoch_size,
        num_cycles=args.num_cycles,
        cycle_decay=args.cycle_decay,
        lr_epoch_stair=args.lr_epoch_stair,
    )

    # resume training if ckpt_path is given
    if args.ckpt_path != "" and args.resume_opt:
        opt_ckpt_path = os.path.join(args.ckpt_save_dir, f"optim_{args.model}.ckpt")
    else:
        opt_ckpt_path = ""

    # create optimizer
    # TODO: consistent naming opt, name, dataset_name
    if (
        args.loss_scale_type == "fixed"
        and args.drop_overflow_update is False
        and not require_customized_train_step(args.ema, args.clip_grad, args.gradient_accumulation_steps)
    ):
        # Fixed loss scale without a customized train step: the optimizer
        # itself un-scales the gradients.
        optimizer_loss_scale = args.loss_scale
    else:
        # Otherwise the train step handles scaling; optimizer uses 1.0.
        optimizer_loss_scale = 1.0
    optimizer = create_optimizer(
        network.trainable_params(),
        opt=args.opt,
        lr=lr_scheduler,
        weight_decay=args.weight_decay,
        momentum=args.momentum,
        nesterov=args.use_nesterov,
        filter_bias_and_bn=args.filter_bias_and_bn,
        loss_scale=optimizer_loss_scale,
        checkpoint_path=opt_ckpt_path,
        eps=args.eps,
    )

    # Define eval metrics.
    metrics = get_metrics(num_classes)

    # create trainer
    trainer = create_trainer(
        network,
        loss,
        optimizer,
        metrics,
        amp_level=args.amp_level,
        loss_scale_type=args.loss_scale_type,
        loss_scale=args.loss_scale,
        drop_overflow_update=args.drop_overflow_update,
        ema=args.ema,
        ema_decay=args.ema_decay,
        clip_grad=args.clip_grad,
        clip_value=args.clip_value,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

    # callback
    # save checkpoint, summary training loss
    # record val acc and do model selection if val dataset is available
    begin_step = 0
    begin_epoch = 0
    if args.ckpt_path != "":
        # Resume bookkeeping: step comes from the restored optimizer state;
        # epoch is parsed from the filename, which is assumed to follow the
        # `<model>-<epoch>_<step>.ckpt` convention — TODO confirm.
        begin_step = optimizer.global_step.asnumpy()[0]
        begin_epoch = args.ckpt_path.split("/")[-1].split("-")[1].split("_")[0]
        begin_epoch = int(begin_epoch)

    summary_dir = f"./{args.ckpt_save_dir}/summary"
    # top_k checkpoint selection needs validation metrics to rank by.
    assert (
        args.ckpt_save_policy != "top_k" or args.val_while_train is True
    ), "ckpt_save_policy is top_k, val_while_train must be True."
    state_cb = StateMonitor(
        trainer,
        model_name=args.model,
        model_ema=args.ema,
        last_epoch=begin_epoch,
        dataset_sink_mode=args.dataset_sink_mode,
        dataset_val=loader_eval,
        metric_name=list(metrics.keys()),
        val_interval=args.val_interval,
        ckpt_save_dir=args.ckpt_save_dir,
        ckpt_save_interval=args.ckpt_save_interval,
        ckpt_save_policy=args.ckpt_save_policy,
        ckpt_keep_max=args.keep_checkpoint_max,
        summary_dir=summary_dir,
        log_interval=args.log_interval,
        rank_id=rank_id,
        device_num=device_num,
    )

    callbacks = [state_cb]
    # One-shot summary of the experiment configuration for the log.
    essential_cfg_msg = "\n".join(
        [
            "Essential Experiment Configurations:",
            f"MindSpore mode[GRAPH(0)/PYNATIVE(1)]: {args.mode}",
            f"Distributed mode: {args.distribute}",
            f"Number of devices: {device_num if device_num is not None else 1}",
            f"Number of training samples: {train_count}",
            f"Number of validation samples: {eval_count}",
            f"Number of classes: {num_classes}",
            f"Number of batches: {num_batches}",
            f"Batch size: {args.batch_size}",
            f"Auto augment: {args.auto_augment}",
            f"MixUp: {args.mixup}",
            f"CutMix: {args.cutmix}",
            f"Model: {args.model}",
            f"Model parameters: {num_params}",
            f"Number of epochs: {args.epoch_size}",
            f"Optimizer: {args.opt}",
            f"Learning rate: {args.lr}",
            f"LR Scheduler: {args.scheduler}",
            f"Momentum: {args.momentum}",
            f"Weight decay: {args.weight_decay}",
            f"Auto mixed precision: {args.amp_level}",
            f"Loss scale: {args.loss_scale}({args.loss_scale_type})",
        ]
    )
    logger.info(essential_cfg_msg)
    # Persist the resolved configuration next to the checkpoints.
    save_args(args, os.path.join(args.ckpt_save_dir, f"{args.model}.yaml"), rank_id)

    if args.ckpt_path != "":
        logger.info(f"Resume training from {args.ckpt_path}, last step: {begin_step}, last epoch: {begin_epoch}")
    else:
        logger.info("Start training")

    trainer.train(args.epoch_size, loader_train, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode)
-
-
def C2netMultiObsToEnv(multi_data_url, data_dir):
    """Download and unpack the datasets described by `multi_data_url`.

    Args:
        multi_data_url: JSON string encoding a list of dicts, each with
            "dataset_name" (zip file name) and "dataset_url" (OBS source).
        data_dir: local directory the zip files are copied into; each archive
            is extracted into a sub-directory named after the zip file.

    Side effect: creates the marker file /cache/download_input.txt so that,
    in multi-card training, other ranks can detect the copy has finished
    and avoid copying the dataset multiple times.
    """
    print(multi_data_url)
    multi_data_json = json.loads(multi_data_url)
    for item in multi_data_json:
        zipfile_path = os.path.join(data_dir, item["dataset_name"])
        try:
            mox.file.copy(item["dataset_url"], zipfile_path)
            print("Successfully Download {} to {}".format(item["dataset_url"], zipfile_path))
            # Extract into a directory named after the archive (sans ".zip").
            filename = os.path.splitext(item["dataset_name"])[0]
            filePath = os.path.join(data_dir, filename)
            os.makedirs(filePath, exist_ok=True)
            # Use the zipfile module instead of os.system("unzip ..."):
            # no shell invocation (dataset names can't inject commands) and
            # no dependency on an external unzip binary.
            with zipfile.ZipFile(zipfile_path) as archive:
                archive.extractall(filePath)
        except Exception as e:
            # Best-effort per dataset: report and continue with the next one.
            print('moxing download {} to {} failed: '.format(
                item["dataset_url"], zipfile_path) + str(e))
    # Set a cache file to determine whether the data has been copied to obs.
    # If this file exists during multi-card training, there is no need to
    # copy the dataset multiple times.
    with open("/cache/download_input.txt", 'w'):
        pass
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    return
-
-
def DownloadFromQizhi(multi_data_url, data_dir):
    """Download the datasets to `data_dir`, for single- or multi-card runs.

    In multi-card training only the first card of each 8-card node copies
    the data; the remaining ranks wait for the marker file written by
    C2netMultiObsToEnv before proceeding.

    NOTE(review): reads the module-global `args` (set in __main__) for
    `device_target` — confirm this function is only called from this script.
    """
    # Default RANK_SIZE to 1 so a local run without the env var does not
    # crash with TypeError from int(None).
    device_num = int(os.getenv('RANK_SIZE', '1'))
    if device_num == 1:
        C2netMultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(
            mode=context.GRAPH_MODE,
            device_target=args.device_target,
            device_id=int(os.getenv('ASCEND_DEVICE_ID', '0')),
        )
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            device_num=device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            parameter_broadcast=True,
        )
        init()
        # Copying obs data does not need to be executed multiple times,
        # just let the 0th card of each node copy the data.
        local_rank = int(os.getenv('RANK_ID', '0'))
        if local_rank % 8 == 0:
            C2netMultiObsToEnv(multi_data_url, data_dir)
        # If the cache file does not exist, the data copy has not completed:
        # wait for the 0th card to finish copying.
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
-
-
def EnvToObs(train_dir, obs_train_url):
    """Upload the local directory `train_dir` to `obs_train_url` via moxing.

    Best-effort: a failed upload is reported on stdout, never raised.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir,
        obs_train_url) + str(e))
    else:
        print("Successfully Upload {} to {}".format(train_dir,
        obs_train_url))
    return
-
-
def UploadToQizhi(train_dir, obs_train_url):
    """Upload training output to OBS.

    In multi-card runs only the first card of each 8-card node uploads,
    to avoid duplicate transfers of the same output directory.
    """
    # Default the env vars so int() does not raise TypeError on a local run
    # where RANK_SIZE / RANK_ID are not set.
    device_num = int(os.getenv('RANK_SIZE', '1'))
    local_rank = int(os.getenv('RANK_ID', '0'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    if device_num > 1 and local_rank % 8 == 0:
        EnvToObs(train_dir, obs_train_url)
    return
-
-
-
if __name__ == '__main__':
    args = parse_args()

    # ModelArts/Qizhi fixed cache paths: dataset input, training output,
    # and the (unused here) checkpoint location.
    data_dir = '/cache/dataset'
    train_dir = '/cache/output'
    ckpt_url = '/cache/checkpoint.ckpt'
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)

    print(args.multi_data_url)
    DownloadFromQizhi(args.multi_data_url, data_dir)

    train(args)

    UploadToQizhi(train_dir, args.train_url)
-
|