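"""Standalone MindCV training script with a hand-written (functional) train step.

Pipeline: build the dataset/transforms/loader, create the model, loss, LR scheduler
and optimizer, wrap a static-shape train step with AMP and loss scaling, then run
the epoch loop with checkpointing and a plain-text result log. Data-parallel
training (--distribute) and ModelArts data/output syncing are optional.

Example launch (illustrative only; script and flag names depend on your setup and
on what config.parse_args actually defines):

    # single device
    python train.py --model resnet50 --data_dir /path/to/imagenet

    # 8 devices, data parallel (launcher depends on the environment)
    mpirun -n 8 python train.py --model resnet50 --data_dir /path/to/imagenet --distribute True
"""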
import os
import random
import time
import numpy as np

import mindspore as ms
from mindspore.context import ParallelMode
from mindspore import context, nn, ops, Tensor
from mindspore.communication.management import init, get_rank, get_group_size
import mindspore.dataset.transforms as transforms

from config import parse_args

from mindcv.models import create_model
from mindcv.data import create_dataset, create_transforms, create_loader
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler
from mindcv.utils.grad import value_and_grad


def set_seed(seed=2):
    np.random.seed(seed)
    random.seed(seed)
    ms.set_seed(seed)


def train(args):
    set_seed()

    if args.enable_modelarts:
        from mindcv.utils.modelarts import sync_data
        args.data_dir = "/cache/data/"
        os.makedirs(args.data_dir, exist_ok=True)
        sync_data(args.data_url, args.data_dir)
        print(f"Data dir on modelarts: {os.listdir(args.data_dir)}")
        # update data_dir path for create_dataset
        if "imagenet" in os.listdir(args.data_dir):
            data_dir = os.path.join(args.data_dir, "imagenet")
            args.data_dir = data_dir

    rank, rank_size = args.rank, args.rank_size
    is_main_process = args.rank in [None, 0]
    # print('Rank id: ', args.rank, 'main? ', is_main_process)

    # create dataset and loader
    # TODO: ImageFolder interface test
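    # In distributed mode each rank loads its own shard of the training split:
    # num_shards/shard_id below come from rank_size/rank assigned in __main__.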
    dataset_train = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split=args.train_split,
        shuffle=args.shuffle,
        num_samples=args.num_samples,
        num_shards=args.rank_size,
        shard_id=args.rank,
        num_parallel_workers=args.num_parallel_workers,
        download=args.dataset_download,
        num_aug_repeats=args.aug_repeats)

    if args.num_classes is None:
        num_classes = dataset_train.num_classes()
    else:
        num_classes = args.num_classes
    num_train = dataset_train.get_dataset_size()

    transform_list = create_transforms(
        dataset_name=args.dataset,
        is_training=True,
        image_resize=args.image_resize,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        interpolation=args.interpolation,
        auto_augment=args.auto_augment,
        mean=args.mean,
        std=args.std,
        re_prob=args.re_prob,
        re_scale=args.re_scale,
        re_ratio=args.re_ratio,
        re_value=args.re_value,
        re_max_attempts=args.re_max_attempts
    )

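    # BCE loss expects multi-hot targets, so integer labels are one-hot encoded
    # here; cross-entropy style losses consume the raw class index directly.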
    target_transform = transforms.OneHot(num_classes) if args.loss == 'BCE' else None

    # dataset loader
    loader_train = create_loader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        drop_remainder=args.drop_remainder,
        is_training=True,
        mixup=args.mixup,
        cutmix=args.cutmix,
        cutmix_prob=args.cutmix_prob,
        num_classes=num_classes,
        transform=transform_list,
        target_transform=target_transform,
        num_parallel_workers=args.num_parallel_workers,
    )
    num_batches = loader_train.get_dataset_size()

    # create model
    network = create_model(model_name=args.model,
                           num_classes=num_classes,
                           in_channels=args.in_channels,
                           drop_rate=args.drop_rate,
                           drop_path_rate=args.drop_path_rate,
                           pretrained=args.pretrained,
                           checkpoint_path=args.ckpt_path,
                           use_ema=args.use_ema)

    # create loss (named loss_fn to avoid shadowing the per-step loss value below)
    loss_fn = create_loss(name=args.loss,
                          reduction=args.reduction,
                          label_smoothing=args.label_smoothing,
                          aux_factor=args.aux_factor)

    # create learning rate schedule
    lr_scheduler = create_scheduler(num_batches,
                                    scheduler=args.scheduler,
                                    lr=args.lr,
                                    min_lr=args.min_lr,
                                    warmup_epochs=args.warmup_epochs,
                                    warmup_factor=args.warmup_factor,
                                    decay_epochs=args.decay_epochs,
                                    decay_rate=args.decay_rate,
                                    milestones=args.multi_step_decay_milestones,
                                    num_epochs=args.epoch_size,
                                    lr_epoch_stair=args.lr_epoch_stair,
                                    num_cycles=args.num_cycles,
                                    cycle_decay=args.cycle_decay)

    # resume training if ckpt_path is given
    if args.ckpt_path != '' and args.resume_opt:
        opt_ckpt_path = os.path.join(args.ckpt_save_dir, f'optim_{args.model}.ckpt')
    else:
        opt_ckpt_path = ''

    # create optimizer
    optimizer = create_optimizer(network.trainable_params(),
                                 opt=args.opt,
                                 lr=lr_scheduler,
                                 weight_decay=args.weight_decay,
                                 momentum=args.momentum,
                                 nesterov=args.use_nesterov,
                                 filter_bias_and_bn=args.filter_bias_and_bn,
                                 loss_scale=args.loss_scale,
                                 checkpoint_path=opt_ckpt_path,
                                 eps=args.eps)

    # amp
    ms.amp.auto_mixed_precision(network, amp_level=args.amp_level)
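    # Gradient reduction: in data-parallel mode DistributedGradReducer all-reduces
    # (and, since gradients_mean=True is set in __main__, averages) gradients across
    # devices; in standalone mode an identity op keeps the train step code unchanged.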
    if args.distribute:
        mean = context.get_auto_parallel_context("gradients_mean")
        degree = context.get_auto_parallel_context("device_num")
        grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
    else:
        grad_reducer = ops.functional.identity

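    # Loss scaling keeps fp16 gradients from underflowing under AMP:
    #   dynamic: start at 2**12, halve on overflow, double after 1000 clean steps
    #   static:  use the fixed scale args.loss_scale
    #   none:    no scaling (loss_scaler is None)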
    if args.loss_scaler == "dynamic":
        from mindspore.amp import DynamicLossScaler
        loss_scaler = DynamicLossScaler(2**12, 2, 1000)
    elif args.loss_scaler == "static":
        from mindspore.amp import StaticLossScaler
        loss_scaler = StaticLossScaler(args.loss_scale)
    elif args.loss_scaler == "none":
        loss_scaler = None
    else:
        raise NotImplementedError(f"Unsupported loss_scaler: {args.loss_scaler}")

    # Wrap Train Step
    overflow_still_update = False
    train_step = create_train_static_shape_fn_gradoperation(network, loss_fn, optimizer, loss_scaler, grad_reducer,
                                                             amp_level=args.amp_level,
                                                             overflow_still_update=overflow_still_update,
                                                             sens=args.ms_grad_sens)
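    # train_step fuses forward, backward and the (conditional) parameter update into
    # one compiled graph and returns (loss, grads, grads_finite); grads_finite drives
    # the dynamic loss-scale adjustment in the training loop below.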

    # Log
    if is_main_process:
        print("-" * 40)
        print(f"Num devices: {rank_size if rank_size is not None else 1} \n"
              f"Distributed mode: {args.distribute} \n")
        print(f"Num classes: {num_classes} \n"
              f"Batch size: {args.batch_size} \n"
              f"Auto augment: {args.auto_augment} \n"
              f"Model: {args.model} \n"
              f"Num epochs: {args.epoch_size} \n"
              f"Optimizer: {args.opt} \n"
              f"LR: {args.lr} \n"
              f"LR Scheduler: {args.scheduler}")
        print("-" * 40)

        log_txt_fp = os.path.join(args.ckpt_save_dir, 'result.log')
        with open(log_txt_fp, 'w', encoding="utf-8") as fp:
            header = 'Epoch\tTrainLoss\tLR\tTime\tTop1\tTop5\n'
            fp.write(header)
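        # Top1/Top5 columns are left as '-' since per-epoch evaluation is not run
        # in this script (see the TODO in the training loop).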

    # Training
    for epoch in range(args.epoch_size):
        network.set_train(True)
        optimizer.set_train(True)
        start = time.time()
        for step, (data, label) in enumerate(loader_train.create_tuple_iterator()):
            # functional train step: forward, backward and optimizer update in one call
            loss, _, grads_finite = train_step(data, label)

            if loss_scaler:
                loss_scaler.adjust(grads_finite)
                if not grads_finite:
                    print("overflow, loss scale adjust to ", loss_scaler.scale_value.asnumpy())

            # log train loss in the last step
            if step == num_batches - 1:
                if optimizer.dynamic_lr:
                    cur_lr = optimizer.learning_rate(Tensor(step)).asnumpy()
                else:
                    cur_lr = optimizer.learning_rate.asnumpy()
                train_time = time.time() - start
                print(f"Epoch:[{epoch+1}/{args.epoch_size}], "
                      f"batch:[{step+1}/{num_batches}], "
                      f"loss:{loss.asnumpy():.6f}, lr: {cur_lr:.7f}, time:{train_time:.6f}s")

        # save checkpoint & log
        if is_main_process:
            ckpt_path = os.path.join(args.ckpt_save_dir, f"{args.model}-{epoch + 1}_{num_batches}.ckpt")
            ms.save_checkpoint(network, ckpt_path)
            if args.enable_modelarts:
                sync_data(ckpt_path, args.train_url + "/weights/" + ckpt_path.split("/")[-1])
                # if ema:
                #     sync_data(ema_ckpt_path, opt.train_url + "/weights/" + ema_ckpt_path.split("/")[-1])

        # TODO: eval every epoch
        if is_main_process:
            with open(log_txt_fp, 'a', encoding="utf-8") as fp:
                values = f'{epoch}\t{loss.asnumpy():.6f}\t{cur_lr:.7f}\t{time.time() - start:.5f}\t-\t-\n'
                fp.write(values)
            if args.enable_modelarts:
                sync_data(log_txt_fp, args.train_url + '/' + log_txt_fp.split("/")[-1])
    return 0


def create_train_static_shape_fn_gradoperation(model, loss_fn, optimizer, loss_scaler, grad_reducer=None, rank_size=8,
                                               amp_level="O0", overflow_still_update=False, sens=1.0):
    # from mindspore.amp import all_finite  # buggy before MindSpore 1.9.0
    from mindcv.utils.all_finite import all_finite
    # TODO: merge it above
    if loss_scaler is None:
        from mindspore.amp import StaticLossScaler
        loss_scaler = StaticLossScaler(sens)
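        # with no scaler provided, `sens` acts as a fixed loss-scale factor;
        # the scale/unscale calls below multiply and divide by this same constant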

    # Def train func
    # TODO: need to set loss cell to amp?
    ms.amp.auto_mixed_precision(loss_fn, amp_level=amp_level)

    if grad_reducer is None:
        grad_reducer = ops.functional.identity

    def forward_func(x, label):
        logits = model(x)
        loss = loss_fn(logits, label)
        # TODO: scale not used because of sens?
        loss = loss_scaler.scale(loss)
        return loss, logits
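    # forward_func returns the *scaled* loss, so value_and_grad produces scaled
    # gradients; train_step unscales both grads and loss before the overflow check
    # and the optimizer update.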

    # TODO: why not use value_and_grad? Use rewritten value_and_grad for now.
    # grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)(forward_func, optimizer.parameters)
    # sens_value = sens

    grad_fn = value_and_grad(forward_func, None, optimizer.parameters)

    @ms.ms_function
    def train_step(x, label):
        (loss, logits), grads = grad_fn(x, label)
        grads = grad_reducer(grads)
        grads = loss_scaler.unscale(grads)
        loss = loss_scaler.unscale(loss)
        grads_finite = all_finite(grads)

        if grads_finite:
            loss = ops.depend(loss, optimizer(grads))
        else:
            if overflow_still_update:
                loss = ops.depend(loss, optimizer(grads))
                print("grad overflow, still update.")
            else:
                print("grad overflow, drop the step.")

        return loss, grads, grads_finite

    return train_step


if __name__ == '__main__':
    args = parse_args()

    ms_mode = context.GRAPH_MODE if args.mode == 0 else context.PYNATIVE_MODE
    context.set_context(mode=ms_mode, device_target=args.device_target)
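    # args.mode == 0 selects static graph compilation (GRAPH_MODE); any other value
    # runs in PYNATIVE_MODE. The train step is compiled with @ms.ms_function either way.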

    if args.device_target == "Ascend":
        device_id = int(os.getenv('DEVICE_ID', 0))
        context.set_context(device_id=device_id)
    else:
        # raise NotImplementedError
        device_id = None

    # Distribute Train
    rank, rank_size, parallel_mode = 0, 1, ParallelMode.STAND_ALONE
    if args.distribute:
        init()
        rank, rank_size, parallel_mode = get_rank(), get_group_size(), ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=rank_size)

    args.rank_size = rank_size
    args.rank = rank

    # NOTE: forces the ModelArts data-sync path regardless of the parsed config
    args.enable_modelarts = True

    # Train
    train(args)