|
- """Use the PYNATIVE mode to train the network"""
- import os
- import logging
- import time
- import numpy as np
- from tqdm import tqdm
-
- import mindspore as ms
- from mindspore import nn, Tensor, ops, SummaryRecord
- #from mindspore.communication import init, get_rank, get_group_size
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
- #from mindcv.utils.amp import NoLossScaler, DynamicLossScaler, StaticLossScaler, auto_mixed_precision, all_finite
-
- from mindcv.models import create_model
- from mindcv.data import create_dataset, create_transforms, create_loader
- from mindcv.loss import create_loss
- from mindcv.optim import create_optimizer
- from mindcv.scheduler import create_scheduler
- from mindcv.utils import CheckpointManager, Allreduce
- from config import parse_args
-
- #modelarts
- import moxing as mox
- import json
-
-
- # Copy single dataset from obs to training image###
# Copy single dataset from obs to training image###
def ObsToEnv(obs_data_url, data_dir):
    """Copy a dataset from OBS to the local training image.

    After the copy attempt (successful or not), touch the marker file
    /cache/download_input.txt so the other ranks waiting in
    DownloadFromQizhi() know the transfer phase is over.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
    # Set a cache file to determine whether the data has been copied to obs.
    # If this file exists during multi-card training, there is no need to copy the dataset multiple times.
    # Use a context manager so the handle is always closed.
    with open("/cache/download_input.txt", 'w'):
        pass
    # os.path.exists() never raises here, so report based on its result
    # (the original try/except could never print the failure message).
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    else:
        print("download_input failed")
    return
-
-
# Copy the output to obs###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local training output directory to OBS.

    Failures are reported but not raised: the upload is best-effort.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as err:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(err))
    else:
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    return
-
-
def DownloadFromQizhi(obs_data_url, data_dir):
    """Download the dataset from OBS and set up the MindSpore context.

    Single card: copy directly. Multi card: initialize the data-parallel
    context, let only one rank per node (rank % 8 == 0) copy, and make the
    other ranks wait for the /cache/download_input.txt marker file.

    NOTE: reads the module-global ``args`` (for device_target), so
    ``args = parse_args()`` must run before this is called.

    Returns:
        (device_num, local_rank) as read from RANK_SIZE / RANK_ID.
    """
    # Default to a single card when RANK_SIZE is unset; int(None) would
    # otherwise raise TypeError.
    device_num = int(os.getenv('RANK_SIZE', '1'))
    local_rank = 0
    if device_num == 1:
        ObsToEnv(obs_data_url, data_dir)
        ms.set_context(mode=ms.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        ms.set_context(mode=ms.GRAPH_MODE, device_target=args.device_target,
                       device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        ms.reset_auto_parallel_context()
        ms.set_auto_parallel_context(device_num=device_num, parallel_mode='data_parallel', gradients_mean=True,
                                     parameter_broadcast=True)
        init()
        # Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            ObsToEnv(obs_data_url, data_dir)
        # If the cache file does not exist, it means that the copy data has not been completed,
        # and wait for 0th card to finish copying data
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return device_num, local_rank
-
def UploadToQizhi(train_dir, obs_train_url):
    """Upload training results to OBS.

    Single card uploads directly; in multi-card runs only one rank per node
    (rank % 8 == 0) performs the upload to avoid duplicate transfers.
    """
    device_num = int(os.getenv('RANK_SIZE'))
    local_rank = int(os.getenv('RANK_ID'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    elif device_num > 1 and local_rank % 8 == 0:
        EnvToObs(train_dir, obs_train_url)
    return
-
def train(args):
    """Train the classification network described by ``args``.

    Sets up the (optionally distributed) execution context, builds the data
    pipeline, model, loss, LR scheduler, optimizer and loss scaler, then runs
    a hand-written training loop with optional validation and checkpointing.
    Only rank 0 (or single-card runs) logs, validates and saves checkpoints.

    Fixes vs. the previous version:
      * non-distributed train_step no longer applies the optimizer twice
        (the first call used scaled gradients, before the overflow check);
      * the global step used for dynamic-LR logging is computed from the
        current epoch instead of args.epoch_size;
      * test_acc is kept as a plain float so the epoch log line no longer
        crashes with AttributeError after a validation pass;
      * log_path is defined for every rank, and only rank 0 writes to it.
    """
    SEED = 1
    ms.set_seed(SEED)
    np.random.seed(SEED)

    ms.set_context(mode=args.mode)

    if args.enable_modearts:
        # On ModelArts/Qizhi the parallel context and init() were already done
        # in DownloadFromQizhi(); reuse the values recorded on args.
        device_num = args.device_num
        rank_id = args.rank_id
    else:
        if args.distribute:
            init()
            device_num = get_group_size()
            rank_id = get_rank()
            ms.reset_auto_parallel_context()
            ms.set_auto_parallel_context(device_num=device_num,
                                         parallel_mode='data_parallel',
                                         gradients_mean=True)
        else:
            device_num = None
            rank_id = None

    # create dataset
    dataset_train = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split=args.train_split,
        shuffle=args.shuffle,
        num_samples=args.num_samples,
        num_shards=device_num,
        shard_id=rank_id,
        num_parallel_workers=args.num_parallel_workers,
        download=args.dataset_download,
        num_aug_repeats=args.aug_repeats)

    if args.num_classes is None:
        num_classes = dataset_train.num_classes()
    else:
        num_classes = args.num_classes

    # create transforms
    transform_list = create_transforms(
        dataset_name=args.dataset,
        is_training=True,
        image_resize=args.image_resize,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        interpolation=args.interpolation,
        auto_augment=args.auto_augment,
        mean=args.mean,
        std=args.std,
        re_prob=args.re_prob,
        re_scale=args.re_scale,
        re_ratio=args.re_ratio,
        re_value=args.re_value,
        re_max_attempts=args.re_max_attempts
    )

    # load dataset
    loader_train = create_loader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        drop_remainder=False,
        is_training=True,
        mixup=args.mixup,
        cutmix=args.cutmix,
        cutmix_prob=args.cutmix_prob,
        num_classes=num_classes,
        transform=transform_list,
        num_parallel_workers=args.num_parallel_workers,
    )

    # the validation pipeline is only built on rank 0 (or single-card runs)
    if args.val_while_train and rank_id in [None, 0]:
        dataset_eval = create_dataset(
            name=args.dataset,
            root=args.data_dir,
            split=args.val_split,
            num_parallel_workers=args.num_parallel_workers,
            download=args.dataset_download)

        transform_list_eval = create_transforms(
            dataset_name=args.dataset,
            is_training=False,
            image_resize=args.image_resize,
            crop_pct=args.crop_pct,
            interpolation=args.interpolation,
            mean=args.mean,
            std=args.std
        )

        loader_eval = create_loader(
            dataset=dataset_eval,
            batch_size=args.batch_size,
            drop_remainder=False,
            is_training=False,
            transform=transform_list_eval,
            num_parallel_workers=args.num_parallel_workers,
        )

    num_batches = loader_train.get_dataset_size()
    # Train dataset count (summed across shards when distributed)
    train_count = dataset_train.get_dataset_size()
    if args.distribute:
        all_reduce = Allreduce()
        train_count = all_reduce(Tensor(train_count, ms.int32))

    # create model
    network = create_model(model_name=args.model,
                           num_classes=num_classes,
                           in_channels=args.in_channels,
                           drop_rate=args.drop_rate,
                           drop_path_rate=args.drop_path_rate,
                           pretrained=args.pretrained,
                           checkpoint_path=args.ckpt_path)

    num_params = sum([param.size for param in network.get_parameters()])

    # mixed precision
    ms.amp.auto_mixed_precision(network, amp_level=args.amp_level)
    # TODO: auto_mixed_precision is changed in MS 1.9.1. used customed auto_mixed_precision to support customized blacklist

    # create loss
    loss_fn = create_loss(name=args.loss,
                          reduction=args.reduction,
                          label_smoothing=args.label_smoothing,
                          aux_factor=args.aux_factor)

    # create learning rate schedule
    lr_scheduler = create_scheduler(num_batches,
                                    scheduler=args.scheduler,
                                    lr=args.lr,
                                    min_lr=args.min_lr,
                                    warmup_epochs=args.warmup_epochs,
                                    warmup_factor=args.warmup_factor,
                                    decay_epochs=args.decay_epochs,
                                    decay_rate=args.decay_rate,
                                    milestones=args.multi_step_decay_milestones,
                                    num_epochs=args.epoch_size)

    # resume optimizer state if a checkpoint path is given
    if args.ckpt_path != '' and args.resume_opt:
        opt_ckpt_path = os.path.join(args.ckpt_save_dir, f'optim_{args.model}.ckpt')
    else:
        opt_ckpt_path = ''

    # create optimizer
    optimizer = create_optimizer(network.trainable_params(),
                                 opt=args.opt,
                                 lr=lr_scheduler,
                                 weight_decay=args.weight_decay,
                                 momentum=args.momentum,
                                 nesterov=args.use_nesterov,
                                 filter_bias_and_bn=args.filter_bias_and_bn,
                                 loss_scale=args.loss_scale,
                                 checkpoint_path=opt_ckpt_path)

    # set loss scale for mixed precision training
    if args.amp_level != 'O0':
        if args.dynamic_loss_scale:
            from mindcv.utils.amp import DynamicLossScaler
            loss_scaler = DynamicLossScaler(args.loss_scale, 2, 1000)
        else:
            # Fixes bugs in MS 1.8.1 (missing adjust)
            from mindcv.utils.amp import StaticLossScaler
            loss_scaler = StaticLossScaler(args.loss_scale)
    else:
        from mindcv.utils.amp import NoLossScaler
        loss_scaler = NoLossScaler()

    # resume counters: step from the optimizer state, epoch from the
    # checkpoint file name "<model>-<epoch>_<batches>.ckpt"
    begin_step = 0
    begin_epoch = 0
    if args.ckpt_path != '':
        begin_step = optimizer.global_step.asnumpy()[0]
        begin_epoch = args.ckpt_path.split('/')[-1].split('_')[0].split('-')[-1]
        begin_epoch = int(begin_epoch)

    # log path must exist on every rank (the epoch loop references it), but
    # only rank 0 creates/writes the file
    log_path = os.path.join(args.ckpt_save_dir, 'result.log')

    # log
    if rank_id in [None, 0]:

        print("-" * 40)
        print(f"Num devices: {device_num if device_num is not None else 1} \n"
              f"Distributed mode: {args.distribute} \n"
              f"Num training samples: {train_count}")
        print(f"Num classes: {num_classes} \n"
              f"Num batches: {num_batches} \n"
              f"Batch size: {args.batch_size} \n"
              f"Auto augment: {args.auto_augment} \n"
              f"Model: {args.model} \n"
              f"Model param: {num_params} \n"
              f"Num epochs: {args.epoch_size} \n"
              f"Optimizer: {args.opt} \n"
              f"LR: {args.lr} \n"
              f"LR Scheduler: {args.scheduler}")
        print("-" * 40)

        if args.ckpt_path != '':
            print(f"Resume training from {args.ckpt_path}, last step: {begin_step}, last epoch: {begin_epoch}")
        else:
            print('Start training')

        if not os.path.exists(args.ckpt_save_dir):
            os.makedirs(args.ckpt_save_dir)

        if not (os.path.exists(log_path) and args.ckpt_path != ''):  # if not resume training
            with open(log_path, 'w') as fp:
                fp.write('Epoch\tTrainLoss\tValAcc\tTime\n')

    best_acc = 0

    # Training
    need_flush_from_cache = True
    # top_k checkpoint policy needs a validation metric to rank checkpoints by
    # (raise instead of assert: asserts are stripped under -O)
    if args.ckpt_save_policy == 'top_k' and not args.val_while_train:
        raise ValueError("ckpt_save_policy is top_k, val_while_train must be True.")
    manager = CheckpointManager(ckpt_save_policy=args.ckpt_save_policy)

    if args.log_interval is None:
        log_interval = num_batches
    else:
        log_interval = args.log_interval

    # build train step

    from mindcv.utils.all_finite import all_finite

    # Define forward function (returns the scaled loss for value_and_grad)
    def forward_fn(data, label):
        logits = network(data)
        loss = loss_fn(logits, label)
        loss = loss_scaler.scale(loss)
        return loss, logits

    # Get gradient function
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    if args.distribute:
        mean = _get_gradients_mean()
        degree = _get_device_num()
        grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

        # Define function of one-step training (data-parallel)
        @ms.ms_function
        def train_step_parallel(data, label):
            (loss, logits), grads = grad_fn(data, label)
            grads = grad_reducer(grads)
            status = all_finite(grads)
            # apply the update only when all gradients are finite (no overflow)
            if status:
                loss = loss_scaler.unscale(loss)
                grads = loss_scaler.unscale(grads)
                loss = ops.depend(loss, optimizer(grads))

            # TODO: adjust is well supported in MS 1.9.1, used in dynamic loss scaler.
            loss = ops.depend(loss, loss_scaler.adjust(status))

            return loss, logits

    @ms.ms_function
    def train_step(data, label):
        (loss, logits), grads = grad_fn(data, label)
        # BUGFIX: the optimizer was previously also invoked here with the
        # *scaled* gradients, before the overflow check -- a double update.
        status = all_finite(grads)
        if status:
            loss = loss_scaler.unscale(loss)
            grads = loss_scaler.unscale(grads)
            loss = ops.depend(loss, optimizer(grads))
        loss = ops.depend(loss, loss_scaler.adjust(status))

        return loss, logits

    for t in range(begin_epoch, args.epoch_size):
        epoch = t
        epoch_start = time.time()
        network.set_train()

        total, correct = 0, 0

        start = time.time()

        for batch, (data, label) in enumerate(loader_train.create_tuple_iterator()):
            if args.distribute:
                loss, logits = train_step_parallel(data, label)
            else:
                loss, logits = train_step(data, label)

            if len(label.shape) == 1:
                correct += (logits.argmax(1) == label).asnumpy().sum()
            else:  # one-hot or soft label
                correct += (logits.argmax(1) == label.argmax(1)).asnumpy().sum()
            total += len(data)

            if (batch + 1) % log_interval == 0 or (batch + 1) >= num_batches or batch == 0:
                # BUGFIX: global step of the *current* epoch (previously
                # args.epoch_size * num_batches + batch, which queried the
                # dynamic LR far beyond the end of the schedule)
                step = epoch * num_batches + batch
                if optimizer.dynamic_lr:
                    cur_lr = optimizer.learning_rate(Tensor(step)).asnumpy()
                else:
                    cur_lr = optimizer.learning_rate.asnumpy()
                print(f"Epoch:[{epoch+1}/{args.epoch_size}], "
                      f"batch:[{batch+1}/{num_batches}], "
                      f"loss:{loss.asnumpy():.6f}, lr: {cur_lr:.7f}, time:{time.time() - start:.6f}s")
                start = time.time()

        # val while train
        # kept as a plain float: 100*correct/total below is a NumPy scalar,
        # not a Tensor, so no .asnumpy() later
        test_acc = -1.0
        if args.val_while_train and rank_id in [0, None]:
            if ((t + 1) % args.val_interval == 0) or (t + 1 == args.epoch_size):
                print('Validating...')
                val_start = time.time()

                network.set_train(False)  # TODO: check freeze

                correct, total = 0, 0
                for data, label in loader_eval.create_tuple_iterator():
                    pred = network(data)
                    total += len(data)
                    if len(label.shape) == 1:
                        correct += (pred.argmax(1) == label).asnumpy().sum()
                    else:  # one-hot or soft label
                        correct += (pred.argmax(1) == label.argmax(1)).asnumpy().sum()

                test_acc = 100 * correct / total
                val_time = time.time() - val_start
                print(f"Val time: {val_time:.2f} \t Val acc: {test_acc:0.3f}")
                if test_acc > best_acc:
                    best_acc = test_acc
                    save_best_path = os.path.join(args.ckpt_save_dir, f"{args.model}-best.ckpt")
                    ms.save_checkpoint(network, save_best_path, async_save=True)
                    print(f"=> New best val acc: {test_acc:0.3f}")

        # Save checkpoint
        if rank_id in [0, None]:
            if ((t + 1) % args.ckpt_save_interval == 0) or (t + 1 == args.epoch_size):
                if need_flush_from_cache:
                    need_flush_from_cache = flush_from_cache(network)

                ms.save_checkpoint(optimizer, os.path.join(args.ckpt_save_dir, f'{args.model}_optim.ckpt'),
                                   async_save=True)
                save_path = os.path.join(args.ckpt_save_dir, f"{args.model}-{t + 1}_{num_batches}.ckpt")
                ckpoint_filelist = manager.save_ckpoint(network, num_ckpt=args.keep_checkpoint_max,
                                                        metric=test_acc, save_path=save_path)
                if args.ckpt_save_policy == 'top_k':
                    checkpoints_str = "Top K accuracy checkpoints: \n"
                    for ch in ckpoint_filelist:
                        checkpoints_str += '{}\n'.format(ch)
                    print(checkpoints_str)
                else:
                    print(f"Saving model to {save_path}")

        epoch_time = time.time() - epoch_start
        print(f'Epoch {t + 1} time:{epoch_time:.3f}s')
        # BUGFIX: only rank 0 owns result.log; previously every rank appended
        # and non-zero ranks hit a NameError on log_path
        if rank_id in [None, 0]:
            with open(log_path, 'a') as fp:
                fp.write(f'{t+1}\t{loss.asnumpy():.7f}\t{test_acc:.3f}\t{epoch_time:.2f}\n')

    print("Done!")
-
def flush_from_cache(network):
    """Flush cache-enabled parameters of *network* back to host memory.

    Returns True when at least one parameter had ``cache_enable`` set (the
    caller keeps flushing on later checkpoints), otherwise False.
    """
    found_cached = False
    for param in network.get_parameters():
        if param.cache_enable:
            found_cached = True
            Tensor(param).flush_from_cache()
    return found_cached
-
-
if __name__ == '__main__':
    # NOTE(review): the module-global `args` is also read by
    # DownloadFromQizhi(), so it must be assigned before that call.
    args = parse_args()

    # NOTE(review): hard-codes the ModelArts/Qizhi path, overriding whatever
    # parse_args() produced -- confirm this is intended for all runs.
    args.enable_modearts = True

    if args.enable_modearts:
        # Fixed cache locations inside the training image.
        data_dir = '/cache/data'
        train_dir = '/cache/output'
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, exist_ok=True)
        if not os.path.exists(train_dir):
            os.makedirs(train_dir, exist_ok=True)

        # Initialize and copy data to training image
        # DownloadFromQizhi is much slower than sync_data but sync_data usually abort;
        device_num, local_rank = DownloadFromQizhi(args.data_url, data_dir)
        # data_url = args.data_url
        # local_data_path = '/cache/dataset'
        # os.makedirs(local_data_path, exist_ok=True)
        # sync_data(data_url, local_data_path, threads=256)
        print(f"data_dir:{os.listdir(data_dir)}")
        # If the OBS bucket contained an "imagenet" folder, descend into it.
        if "imagenet" in os.listdir(data_dir):
            data_dir = os.path.join(data_dir, "imagenet")
        args.data_dir = data_dir

        # NOTE(review): re-reads RANK_SIZE even though DownloadFromQizhi just
        # returned device_num; int(None) raises TypeError when the env var is
        # unset -- confirm RANK_SIZE is always set in this environment.
        device_num = int(os.getenv('RANK_SIZE'))
        if device_num == 1:
            args.ckpt_save_dir = train_dir + "/"
        if device_num > 1:
            # Each rank writes checkpoints into its own subdirectory.
            #args.ckpt_save_dir = train_dir + "/" + str(get_rank()) + "/"
            args.ckpt_save_dir = train_dir + "/" + os.getenv('RANK_ID') + "/"
        # profiler = ms.Profiler(output_path='./1.5x_profiler_data')

        # Hand the device count / rank discovered here to train().
        args.device_num = device_num
        args.rank_id = local_rank

    train(args)

    if args.enable_modearts:
        # Upload results (checkpoints, logs) back to OBS when training ends.
        # profiler.analyse()
        # UploadToQizhi('./1.5x_profiler_data', args.train_url)
        UploadToQizhi(args.ckpt_save_dir, args.train_url)
|