import argparse
import logging
import os
import time

from easydict import EasyDict as edict

from mindspore import context, nn
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model

from config import Config, ActionDict
from upload import UploadOutput
from swintransformerv2.datasets.build_dataset import build_dataset
from swintransformerv2.loss import get_loss
from swintransformerv2.lr_schedule import build_lr
from swintransformerv2.models import build_eval_engine
from swintransformerv2.models.swin_transformer_v2 import build_swin_v2
from swintransformerv2.monitor import build_finetune_callback_2
from swintransformerv2.optimizer import build_optim
from swintransformerv2.parallel_config import build_parallel_config
from swintransformerv2.trainer import build_wrapper

def str2bool(b):
    """Convert a command-line string to a bool for argparse ``type=``.

    Accepts "true"/"false" in any letter case; anything else raises so a
    typo on the command line fails loudly instead of defaulting silently.
    """
    value = b.lower()
    if value == "true":
        return True
    if value == "false":
        return False
    raise Exception("Invalid Bool Value")
-
def main(args):
    """Fine-tune SwinTransformerV2: build dataset, net, optimizer, loss and
    callbacks, then run ``Model.train``.

    Args:
        args: parsed Config object. Augmented in place with ``logger`` and
            ``data_size``. Relies on the module-level ``train_dir`` set by
            the ``__main__`` block for checkpoint output.
    """
    # Checkpoint cadence: save every 1251 steps, keep at most 600 files.
    cfg = edict({
        'save_checkpoint_steps': 1251,
        'keep_checkpoint_max': 600
    })

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)

    # Fix: default to 1 when RANK_SIZE is unset -- int(None) used to raise
    # TypeError, and any value other than 1 or >1 left the checkpoint
    # directory variable unbound.
    device_num = int(os.getenv('RANK_SIZE', '1'))

    # In multi-card runs give every rank its own checkpoint directory so
    # ranks do not overwrite each other's files.
    if device_num > 1:
        output_directory = train_dir + "/" + str(get_rank()) + "/"
    else:
        output_directory = train_dir + "/"
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_swinv2",
                                 directory=output_directory,
                                 config=config_ck)

    upload_output = UploadOutput(train_dir, args.train_url)

    args.logger = logging.getLogger('train')

    # train dataset
    args.logger.info(".........Build Training Dataset..........")
    train_dataset = build_dataset(args, is_pretrain=False)
    data_size = train_dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=data_size)
    new_epochs = args.train_config.epoch
    if args.train_config.per_epoch_size and args.train_config.sink_mode:
        # In sink mode one "epoch" is per_epoch_size steps, so rescale the
        # epoch count to keep the total step budget unchanged.
        new_epochs = int((data_size / args.train_config.per_epoch_size) * new_epochs)
    else:
        args.train_config.per_epoch_size = data_size
    args.data_size = data_size
    args.logger.info("Will be Training epochs:{}, sink_size:{}".format(
        new_epochs, args.train_config.per_epoch_size))
    args.logger.info("Create training dataset finish, data size:{}".format(data_size))

    # evaluation dataset
    args.logger.info(".........Build Eval Dataset..........")
    eval_dataset = build_dataset(args, is_pretrain=False, is_train=False)

    # build context config
    args.logger.info(".........Build context config..........")
    build_parallel_config(args)
    args.logger.info("context config is:{}".format(args.parallel_config))
    args.logger.info("moe config is:{}".format(args.moe_config))

    # build net and the engine that evaluates it during training
    args.logger.info(".........Build Net..........")
    net = build_swin_v2(args)
    eval_engine = build_eval_engine(net, eval_dataset, args)

    # build lr
    args.logger.info(".........Build LR Schedule..........")
    lr_schedule = build_lr(args)
    args.logger.info("LR Schedule is: {}".format(args.lr_schedule))

    # define optimizer (layer-wise lr decay is handled inside build_optim)
    args.logger.info(".........Build Optimizer..........")
    optimizer = build_optim(args, net, lr_schedule, args.logger, is_pretrain=False)

    # define loss and wrap net + loss + optimizer into the train network
    args.logger.info(".........Build Loss function..........")
    finetune_loss = get_loss(args)
    net_with_loss = nn.WithLossCell(net, finetune_loss)
    net_with_train = build_wrapper(args, net_with_loss, optimizer, log=args.logger)

    # define Model and begin training
    args.logger.info(".........Starting Init Train Model..........")
    model = Model(net_with_train, metrics=eval_engine.metric,
                  eval_network=eval_engine.eval_network)

    args.logger.info(".........Starting Init Eval Model..........")
    eval_engine.set_model(model)
    # equal to model._init(dataset, sink_size=per_step_size)
    eval_engine.compile(sink_size=args.train_config.per_epoch_size)

    loss_monitor = LossMonitor()
    callback = build_finetune_callback_2(args, eval_engine, time_cb,
                                         ckpoint_cb, loss_monitor, upload_output)
    # if args.profile:
    #     callback.append(profile_cb)

    args.logger.info(".........Starting Training Model..........")
    model.train(new_epochs, train_dataset, callbacks=callback,
                dataset_sink_mode=args.train_config.sink_mode,
                sink_size=args.train_config.per_epoch_size)
-
-
if __name__ == "__main__":
    work_path = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_url',
                        help='path to training/inference dataset folder',
                        default='/data_url')
    parser.add_argument('--train_url',
                        help='output folder to save/load',
                        default='/output_url')
    parser.add_argument(
        '--device_target',
        type=str,
        default="Ascend",
        choices=['Ascend', 'CPU'],
        help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
    parser.add_argument('--device_num', default=None, type=int, help='device num')
    parser.add_argument(
        '--config',
        default=os.path.join(work_path, "config path"),
        help='YAML config files')
    parser.add_argument('--device_id', default=None, type=int, help='device id')
    parser.add_argument('--seed', default=None, type=int, help='random seed')
    parser.add_argument('--use_parallel', default=None, type=str2bool, help='whether use parallel mode')
    parser.add_argument('--profile', default=None, type=str2bool, help='whether use profile analysis')
    parser.add_argument('--finetune_path', default=None, type=str, help='checkpoint path for finetune')
    parser.add_argument(
        '--options',
        nargs='+',
        action=ActionDict,
        help='override some settings in the used config, the key-value pair'
             'in xxx=yyy format will be merged into config file')

    args_ = parser.parse_args()
    config = Config(os.path.join(work_path, args_.config))

    # Local cache paths used on the Qizhi/ModelArts platform.
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    config.train_url = args_.train_url

    # Command-line values (when given) override the YAML config.
    if args_.device_num is not None:
        config.device_num = args_.device_num
    if args_.device_id is not None:
        config.context.device_id = args_.device_id
    if args_.seed is not None:
        config.seed = args_.seed
    if args_.use_parallel is not None:
        config.use_parallel = args_.use_parallel
    if args_.profile is not None:
        config.profile = args_.profile
    if args_.finetune_path is not None:
        config.train_config.resume_ckpt = args_.finetune_path
    if args_.options is not None:
        config.merge_from_dict(args_.options)
    if config.finetune_dataset.eval_offset < 0:
        # Derive the offset so evaluation lands on the final epoch.
        config.finetune_dataset.eval_offset = config.train_config.epoch % config.finetune_dataset.eval_interval

    if config.enable_modelarts:
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, exist_ok=True)
            print(f'successfully os.makedirs {data_dir}')
        if not os.path.exists(train_dir):
            os.makedirs(train_dir, exist_ok=True)
            print(f'successfully os.makedirs {train_dir}')
        import moxing as mox

        def download_dataset():
            """Copy the imagenet dataset from OBS to the local cache, then
            drop a flag file so other ranks know the copy finished."""
            try:
                mox.file.copy_parallel(src_url=os.path.join(args_.data_url, "imagenet"), dst_url=data_dir)
                print("Successfully Download {} to {}".format(args_.data_url, data_dir))
            except Exception as e:
                print('moxing download {} to {} failed: '.format(args_.data_url, data_dir) + str(e))
            # Set a cache file to determine whether the data has been copied to obs.
            # If this file exists during multi-card training, there is no need
            # to copy the dataset multiple times.
            f = open("/cache/download_input.txt", 'w')
            f.close()
            try:
                if os.path.exists("/cache/download_input.txt"):
                    print("download_input succeed")
            except Exception:
                print("download_input failed")

        # Fix: default to 1 when RANK_SIZE is unset (int(None) raised TypeError).
        device_num = int(os.getenv('RANK_SIZE', '1'))
        print(f'DownloadFromQizhi device_num = {device_num}')

        if device_num == 1:
            download_dataset()

        context.set_context(mode=context.GRAPH_MODE, device_target=args_.device_target)
        if device_num > 1:
            # Set device_id and init for multi-card training.
            context.set_context(mode=context.GRAPH_MODE,
                                device_target=args_.device_target,
                                device_id=int(os.getenv('ASCEND_DEVICE_ID')))
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True,
                                              parameter_broadcast=True)
            init()
            # Copying obs data does not need to be executed multiple times;
            # only the first rank of each node (rank % 8 == 0) copies the data.
            local_rank = int(os.getenv('RANK_ID', '0'))
            if local_rank % 8 == 0:
                print(f'local_rank = {local_rank}, start download.')
                download_dataset()
            # If the flag file does not exist, the copy has not completed yet:
            # wait for the downloading rank to finish.
            while not os.path.exists("/cache/download_input.txt"):
                time.sleep(1)

    main(config)