- # standard library
- import argparse
- import datetime
- import glob
- import os
- import re
- import time
- from pathlib import Path
-
- # third-party packages
- import numpy as np
-
- from conformer import Conformer
- from conformer_overflow import ConformerOverflow
-
- import mindspore
- import mindspore.communication as Comm
- import mindspore.nn as nn
- from mindspore import Model, context
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.communication.management import init
- from mindspore.context import ParallelMode
- from mindspore.train.loss_scale_manager import DynamicLossScaleManager
- from mindspore.profiler import Profiler
- from misc.copy_checkpoint import CopyCheckpoint
- from misc.eval_summary import EvalCallBack, MeanAcc
- from misc.accuracy import Accuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
- from misc.eval_model import WithEvalCell
- from misc.learning_rate import LearningRate
- from misc.loss_callback import LossMonitor
- from misc.optimizer import FP32StateAdamWeightDecay
- from misc.CrossEntropySmooth import CrossEntropySmooth
- from misc.utils_copy import download_data
- from misc.clip_grad import TrainOneStepWithLossScaleCellAndClip, NetWithLossCell, DynamicLossScaleUpdateCell
-
- try:
- import moxing as mox
-     # OBS (Object Storage Service) access setup
-     os.environ.pop('CREDENTIAL_PROFILES_FILE', None)
-     os.environ.pop('AWS_SHARED_CREDENTIALS_FILE', None)
-     # NOTE: read the access/secret keys from environment variables (the names here
-     # are placeholders) rather than hardcoding credentials in source
-     mox.file.set_auth(ak=os.environ.get('OBS_AK'), sk=os.environ.get('OBS_SK'),
-                       server='https://obs.cn-south-222.ai.pcl.cn')
- from moxing.framework.file import file_io
- print('[ModelArts] Fixing Download')
-
- file_io._create_or_get_obs_client()
- file_io._LARGE_FILE_METHOD = 1
-
-     import math
-     from moxing.framework.file.file_io import _do_download_part
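-     # Monkey patch for moxing's large-file download path: split the object into
-     # 10 MB parts and fetch them concurrently with a process pool, reusing
-     # moxing's own _do_download_part for each byte-range request.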
- def _download_obs_with_large_file(bucket_name, object_key, file_path, object_size):
- print('[ModelArts] Monkey Patching')
- part_size = 10 * 1024 * 1024
-         part_count = int(math.ceil(object_size / float(part_size)))
-         # print('download file with part-size=%s, part-count=%s' % (part_size, part_count))
-
-         # pre-create (truncate) the target file so each worker can write its own range
-         with open(file_path, 'wb'):
-             pass
-
- import concurrent.futures
- futures = []
- completed_blocks = [0] * part_count
- with concurrent.futures.ProcessPoolExecutor(16) as executor:
- for i in range(part_count):
- start_pos = i * part_size
- end_pos = object_size - 1 if (i + 1) == part_count else ((i + 1) * part_size - 1)
- # print(i, start_pos, end_pos)
- future = executor.submit(_do_download_part, completed_blocks, bucket_name, object_key, file_path,
- start_pos, end_pos, i)
- futures.append(future)
-
- for future in concurrent.futures.as_completed(futures):
- future.result()
-
- setattr(file_io, '_download_obs_with_large_file', _download_obs_with_large_file)
-
- training_cloud = True
- except Exception:
-     # moxing is unavailable (or OBS access failed): assume a local, non-ModelArts run
-     training_cloud = False
-
- if training_cloud:
- from datasets import build_dataset, classification_dataset
- from misc.loss_summary import LossSummaryCallback
- else:
- from outdated.datasets_local import build_dataset, classification_dataset
-
- mindspore.set_seed(0)
- np.random.seed(0)
-
- def get_args_parser():
- parser = argparse.ArgumentParser('Training and evaluation script', add_help=False)
- parser.add_argument('--batch-size', default=64, type=int)
- parser.add_argument('--epochs', default=300, type=int)
- parser.add_argument('--sync-bn', action='store_true', default=False, help='Enable sync batchnorm')
-
- # Model parameters
- parser.add_argument('--model', default='Conformer_base_patch16', type=str, metavar='MODEL',
- help='Name of model to train')
- parser.add_argument('--input-size', default=224, type=int, help='images input size')
-
- parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
- help='Dropout rate (default: 0.)')
- parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
- help='Drop path rate (default: 0.1)')
-     parser.add_argument('--drop-block', type=float, default=0.0, metavar='PCT',
-                         help='Drop block rate (default: 0.)')
-     parser.add_argument('--norm-type', default='DEFAULT', type=str, help='normalization preset used by the dataset transforms')
-
- parser.add_argument('--model-ema', action='store_true')
- parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
- parser.set_defaults(model_ema=True)
-     parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='decay factor for model EMA (default: 0.99996)')
-     parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='keep the EMA weights on CPU')
-
- # Optimizer parameters
-     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
-                         help='Optimizer (default: "adamw")')
- parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
- help='Optimizer Epsilon (default: 1e-8)')
- parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
- help='Optimizer Betas (default: None, use opt default)')
-     parser.add_argument('--clip-grad', type=float, default=1.0, metavar='NORM',
-                         help='Clip gradient norm (default: 1.0, 0 disables clipping)')
- parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
- help='SGD momentum (default: 0.9)')
- parser.add_argument('--weight-decay', type=float, default=0.05,
- help='weight decay (default: 0.05)')
- # Learning rate schedule parameters
-     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
-                         help='LR scheduler (default: "cosine")')
- parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
- help='learning rate (default: 5e-4)')
- parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
- help='learning rate noise on/off epoch percentages')
- parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
- help='learning rate noise limit percent (default: 0.67)')
- parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
- help='learning rate noise std-dev (default: 1.0)')
- parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
- help='warmup learning rate (default: 1e-6)')
-     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
-                         help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
-
- parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
- help='epoch interval to decay LR')
- parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
- help='epochs to warmup LR, if scheduler supports')
- parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
- help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
-     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
-                         help='patience epochs for Plateau LR scheduler (default: 10)')
- parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
- help='LR decay rate (default: 0.1)')
-
- # Augmentation parameters
- parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
- help='Color jitter factor (default: 0.4)')
-     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
-                         help='Use AutoAugment policy, e.g. "v0" or "original" '
-                              '(default: rand-m9-mstd0.5-inc1)')
- parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
-     parser.add_argument('--train-interpolation', type=str, default='bicubic',
-                         help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
-
- parser.add_argument('--repeated-aug', action='store_true')
- parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
- parser.set_defaults(repeated_aug=True)
-
- # * Random Erase params
- parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
- help='Random erase prob (default: 0.25)')
- parser.add_argument('--remode', type=str, default='pixel',
- help='Random erase mode (default: "pixel")')
- parser.add_argument('--recount', type=int, default=1,
- help='Random erase count (default: 1)')
- parser.add_argument('--resplit', action='store_true', default=False,
- help='Do not random erase first (clean) augmentation split')
-
- # * Mixup params
- parser.add_argument('--mixup', type=float, default=0.8,
- help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
- parser.add_argument('--cutmix', type=float, default=1.0,
- help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
- parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
- help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
- parser.add_argument('--mixup-prob', type=float, default=1.0,
- help='Probability of performing mixup or cutmix when either/both is enabled')
- parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
- help='Probability of switching to cutmix when both mixup and cutmix enabled')
- parser.add_argument('--mixup-mode', type=str, default='batch',
- help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
-
- # Dataset parameters
- parser.add_argument('--data_url', default='s3://yuhy/imagenet/', type=str,
- # obs://yuhy/imagenet/ s3://cvmodel/imagenet/
- help='dataset obs url')
- parser.add_argument('--data-path', default='/cache/data/', type=str,
- help='dataset path')
-     parser.add_argument('--data-set', default='IMNET', type=str, help='dataset name (IMNET, IMNET21k or IMNET21k_ext)')
- parser.add_argument('--data_level', default='1kw', type=str, help='data quality level')
- parser.add_argument('--data_source_file', default='s3://cvmodel/mindrecord_dataset/imagenet1k/',
- type=str, help='dataset obs data_source_file')
- parser.add_argument('--data_file', default='/cache/data/', type=str,
- help='dataset local data_file')
- parser.add_argument('--inat-category', default='name',
- choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
- type=str, help='semantic granularity')
- # * Finetuning params
- parser.add_argument('--finetune', default='', help='finetune from checkpoint')
-     parser.add_argument('--evaluate-freq', type=int, default=2, help='evaluation frequency in epochs (default: 2)')
- parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
- parser.add_argument('--output_dir', default='',
- help='path where to save, empty for no saving')
-     parser.add_argument('--device', default='Ascend',
-                         help='device to use for training / testing, choose from CPU, GPU and Ascend')
- parser.add_argument('--seed', default=0, type=int)
- parser.add_argument('--resume', default='', help='resume from checkpoint')
- parser.add_argument('--checkpoint-conformer', default='', help='checkpoint of conformer')
- parser.add_argument('--checkpoint-masktrans', default='', help='checkpoint of masktrans')
- parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
- help='start epoch')
- parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
- parser.add_argument('--num_workers', default=10, type=int)
- parser.add_argument('--pin-mem', action='store_true',
- help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
-     parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
-                         help='do not pin CPU memory in DataLoader')
- parser.set_defaults(pin_mem=True)
-
- # save checkpoint
- parser.add_argument('--save-checkpoint', action='store_true', default=False,
- help='save checkpoint for model or not')
-     parser.add_argument('--save-checkpoint-epochs', type=int, default=50,
-                         help='interval in epochs between checkpoint saves')
-     parser.add_argument('--keep-checkpoint-max', type=int, default=3,
-                         help='maximum number of checkpoints to keep')
- parser.add_argument('--checkpoint-path', type=str, default='/cache/checkpoint/',
- help='where to save checkpoints')
- parser.add_argument('--run-distribute', action='store_true', default=False)
-
- # transform the dataset
- parser.add_argument('--transform', action='store_true', default=False,
- help='transform the dataset with augment, mixup and cutmix or not')
-
- # profile the training performance
- parser.add_argument('--profile-path', type=str, default='',
- help='dir for saving profiling files')
-
- # copy checkpoint to obs
- parser.add_argument('--copy-ckpt', action='store_true', help='copy checkpoint to obs')
-
- # training url
- parser.add_argument('--train_url', type=str, default='s3://youhui/output/',
- help='train url')
-
- # set seq length for image size
-     parser.add_argument('--seq-length', type=int, default=196,
-                         help='number of patch tokens, (input_size/16)**2: 224 -> 196, 384 -> 576')
-
- # restart training from breakpoint
-     parser.add_argument('--breakpoint', default='',
-                         help='obs path of a breakpoint checkpoint to resume from (empty to disable)')
-
- return parser
-
- def set_save_ckpt_dir(args):
- ckpt_save_dir = args.checkpoint_path
- if not os.path.exists(ckpt_save_dir):
- os.makedirs(ckpt_save_dir, exist_ok=True)
- if args.run_distribute:
- ckpt_save_dir = ckpt_save_dir + 'ckpt_' + str(Comm.get_rank()) + "/"
- return ckpt_save_dir
-
- def copy_ckpt_apply(eval_param, rank_id=None, src_url=None, dst_url=None):
- if training_cloud:
- mox.file.copy_parallel(src_url=src_url, dst_url=dst_url)
- return rank_id
-
- def copy_checkpoint(args, device_num, rank_id, model, src_url=None, dst_url=None):
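-     """
-     Build a CopyCheckpoint callback that uploads newly saved checkpoints from the
-     local src_url directory to the dst_url OBS bucket (the eval-dataset plumbing
-     is left commented out below).
-     """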
- # eval_dataset, nb_classes = build_dataset(is_train=False, args=args, device_num=device_num, rank_id=rank_id)
- # eval_dataset = classification_dataset(eval_dataset, [args.input_size]*2, args.batch_size, mode='eval')
-
- # mean_acc = MeanAcc(device_num, args.device)
- mean_acc = None
- eval_param_dict = {"model" : model, "metrics_name" : ["acc", "top1_acc", "top5_acc"]}
- eval_cb = CopyCheckpoint(copy_ckpt_apply, eval_param_dict, save_best_ckpt=False, mean_acc=mean_acc,
- rank_id=rank_id, src_url=src_url, dst_url=dst_url, bucket=dst_url+'acc_summary/')
- return eval_cb
-
- def apply_eval(eval_param, rank_id=None, src_url=None, dst_url=None):
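-     """
-     Evaluate the model on the eval dataset; on the first device of each node,
-     stage rank-prefixed copies of the local log files, upload them to OBS, and
-     clean up the staged copies.
-     """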
- eval_model = eval_param["model"]
- eval_ds = eval_param["dataset"]
- res = eval_model.eval(eval_ds)
-     if rank_id % 8 == 0 and training_cloud and src_url and dst_url:
-         files = os.listdir(src_url)
-         for file in files:
-             os.system("cp {}{} {}{}".format(src_url, file, src_url, str(rank_id // 8) + file))
-         mox.file.copy_parallel(src_url=src_url, dst_url=dst_url)
-         for file in files:
-             os.system("rm -f {}{}".format(src_url, str(rank_id // 8) + file))
-     return res
-
- def run_eval(args, device_num, rank_id, model, src_url=None, dst_url=None):
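-     """
-     Build the evaluation dataset and wrap it, together with the model and its
-     accuracy metrics, into an EvalCallBack that runs during training.
-     """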
- eval_dataset, nb_classes = build_dataset(is_train=False, args=args, device_num=device_num, rank_id=rank_id)
- eval_dataset = classification_dataset(eval_dataset, [args.input_size]*2, args.batch_size, args=args, mode='eval', num_classes=nb_classes)
-
- mean_acc = MeanAcc(device_num, args.device)
- eval_param_dict = {"model" : model, "dataset" : eval_dataset, "metrics_name" : ["acc", "top1_acc", "top5_acc"]}
- eval_cb = EvalCallBack(apply_eval, eval_param_dict, save_best_ckpt=False, mean_acc=mean_acc,
- rank_id=rank_id, src_url=src_url, dst_url=dst_url, bucket=dst_url+'acc_summary/')
- return eval_cb
-
- def load_pre_trained_checkpoint(args, rank=0):
- """
-     Download the pre-trained checkpoint from OBS and load it for finetuning.
- """
- # download checkpoint from obs finetune to local cache/checkpoint
- EXEC_PATH = '/tmp'
- if rank % 8 == 0:
- print("begin download checkpoint", flush=True)
-
- if not os.path.exists(args.checkpoint_path):
- os.makedirs(args.checkpoint_path, exist_ok=True)
- mox.file.copy_parallel(src_url=args.finetune,
- dst_url=args.checkpoint_path)
- print("checkpoint download succeed!", flush=True)
-
- f = open("%s/download_ckpt.txt" % (EXEC_PATH), 'w')
- f.close()
- # stop
- while not os.path.exists("%s/download_ckpt.txt" % (EXEC_PATH)):
- time.sleep(1)
-
- while True:
- if not os.path.exists(args.checkpoint_path):
- time.sleep(10)
- else:
- print(print(os.listdir(args.checkpoint_path)))
- break
-
- param_dict = None
- if os.path.isdir(args.checkpoint_path):
- ckpt_pattern = os.path.join(args.checkpoint_path, "*.ckpt")
- ckpt_files = glob.glob(ckpt_pattern)
- if not ckpt_files:
- print(f"There is no ckpt file in {args.checkpoint_path}, "
- f"pre_trained is unsupported.")
- else:
-             files_dict = {}
-             for ckpt_file in ckpt_files:
-                 searchobj = re.search(r'.*-(\d+)_', ckpt_file)
-                 if searchobj:
-                     files_dict[int(searchobj.group(1))] = ckpt_file
-             # resume from the checkpoint with the largest epoch number in its name
-             loading_ckpt_file = files_dict[max(files_dict)]
- time_stamp = datetime.datetime.now()
- print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')}"
- f" pre trained ckpt model {loading_ckpt_file} loading",
- flush=True)
- param_dict = load_checkpoint(loading_ckpt_file)
- elif os.path.isfile(args.checkpoint_path):
- param_dict = load_checkpoint(args.checkpoint_path)
- else:
- print(f"Invalid pre_trained {args.checkpoint_path} parameter.", rank)
-     # drop the classification heads so they are re-initialized for the new label set
-     if param_dict is not None:
-         for param in ['trans_cls_head.weight', 'trans_cls_head.bias', 'conv_cls_head.weight', 'conv_cls_head.bias']:
-             param_keys = list(param_dict.keys())
-             for name in param_keys:
-                 if param in name:
-                     param_dict.pop(name)
-                     print("delete the parameter:", name, name in param_dict, rank)
-     return param_dict
-
- def load_breakpoint_checkpoint(args, rank=0):
- """
-     Load the checkpoint saved at the breakpoint path to resume interrupted training.
- """
- # download checkpoint from obs finetune to local cache/checkpoint
- EXEC_PATH = '/tmp'
- if rank % 8 == 0:
- print("begin download checkpoint", flush=True)
-
- if not os.path.exists(args.checkpoint_path):
- os.makedirs(args.checkpoint_path, exist_ok=True)
- mox.file.copy_parallel(src_url=args.breakpoint,
- dst_url=args.checkpoint_path)
- print("checkpoint download succeed!", flush=True)
-
- f = open("%s/download_ckpt.txt" % (EXEC_PATH), 'w')
- f.close()
- # stop
- while not os.path.exists("%s/download_ckpt.txt" % (EXEC_PATH)):
- time.sleep(1)
-
- while True:
- if not os.path.exists(args.checkpoint_path):
- time.sleep(10)
- else:
- print(print(os.listdir(args.checkpoint_path)))
- break
-
- param_dict = None
- if os.path.isdir(args.checkpoint_path):
- ckpt_pattern = os.path.join(args.checkpoint_path, "*breakpoint.ckpt")
- ckpt_files = glob.glob(ckpt_pattern)
-         # if several files match, the last one returned by glob wins
-         for ckpt_file in ckpt_files:
-             print("loading checkpoint from:", ckpt_file, flush=True)
-             param_dict = load_checkpoint(ckpt_file)
- elif os.path.isfile(args.checkpoint_path):
- param_dict = load_checkpoint(args.checkpoint_path)
- else:
- print(f"Invalid pre_trained {args.checkpoint_path} parameter.", rank)
- return param_dict
-
-
- def convert_sync_batchnorm_old(network):
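-     """
-     Recursively replace every nn.BatchNorm2d with a global nn.SyncBatchNorm,
-     copying gamma/beta and the moving statistics (legacy variant, no process groups).
-     """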
- cells = network.name_cells()
- change = False
- for name in cells:
- subcell = cells[name]
-         if subcell is network:
- continue
- elif isinstance(subcell, nn.BatchNorm2d):
- new_subcell = nn.SyncBatchNorm(subcell.num_features,
- subcell.eps,
- subcell.momentum)
- new_subcell.gamma = subcell.gamma
- new_subcell.beta = subcell.beta
- new_subcell.moving_mean = subcell.moving_mean
- new_subcell.moving_variance = subcell.moving_variance
- network._cells[name] = new_subcell
- change = True
- else:
- convert_sync_batchnorm_old(subcell)
- if isinstance(network, nn.SequentialCell) and change:
- network.cell_list = list(network.cells())
-
-
- def convert_sync_batchnorm(network, groups):
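-     """
-     Recursively replace every nn.BatchNorm2d with nn.SyncBatchNorm so that batch
-     statistics are synchronized within each process group in `groups`.
-     """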
- cells = network.name_cells()
- change = False
- for name in cells:
- subcell = cells[name]
-         if subcell is network:
- continue
- elif isinstance(subcell, nn.BatchNorm2d):
- new_subcell = nn.SyncBatchNorm(subcell.num_features,
- subcell.eps,
- subcell.momentum,
- process_groups=groups)
- new_subcell.gamma = subcell.gamma
- new_subcell.beta = subcell.beta
- new_subcell.moving_mean = subcell.moving_mean
- new_subcell.moving_variance = subcell.moving_variance
- network._cells[name] = new_subcell
- change = True
- else:
- convert_sync_batchnorm(subcell, groups)
- if isinstance(network, nn.SequentialCell) and change:
- network.cell_list = list(network.cells())
-
- def set_allreduce_fusion(network, count):
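-     """
-     Depth-first walk of the cell tree: each leaf cell gets gradient allreduce
-     fusion index count // 60, so consecutive leaves are fused into groups of
-     roughly 60; returns the updated running count.
-     """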
- cells = network.name_cells()
- if len(cells) == 0:
- network.set_comm_fusion(count//60)
- count += 1
- return count
- for name in cells:
- subcell = cells[name]
- if subcell == network:
- continue
- else:
- count = set_allreduce_fusion(subcell, count)
- return count
-
- def build_groups(rank_size):
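-     """
-     Partition the devices into 8-card (per-node) groups for SyncBatchNorm,
-     e.g. build_groups(16) -> [[0, 1, ..., 7], [8, 9, ..., 15]].
-     """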
- row = rank_size // 8
-     groups = [[i * 8 + j for j in range(8)] for i in range(row)]
-     return groups
-
- def main(args):
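-     # NOTE: the assignments below override the parsed CLI arguments with hardcoded
-     # per-experiment presets (model size, dataset, schedule); adjust them here
-     # rather than on the command line.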
- args.transform = True
-
- args.model_size = 'x3'
- args.save_checkpoint_epochs = 5
- args.keep_checkpoint_max = 10
-
- ########## settings of log, data and checkpoint
- # for log
- src_url = '/tmp/log/'
- dst_url = f's3://yuhy/log/{args.model_size}/'
-
- if not os.path.exists(src_url):
- os.makedirs(src_url, exist_ok=True)
-
- # pretrain
- ########## imagenet 21k
- # args.data_set = 'IMNET21k'
- # args.data_level = '5kw'
- args.checkpoint_store_path = f's3://yuhy/checkpoint/{args.model_size}/21k5kw/' # dir to save pretrained checkpoint
-
- # args.resume = args.checkpoint_store_path
- # finetune
- args.data_set = 'IMNET'
- args.finetune = args.checkpoint_store_path # dir of pretrained checkpoint
-
- ## parameter settings for pretrain and finetune ##
- # for debug
- in_debug = False
- if in_debug:
- args.device = 'CPU'
- args.data_path = './Conformer_3_mindspore/data/imagenet_sub20'
-
- if not training_cloud:
- args.batch_size = 32
- args.lr = 0.000046875
- args.input_size = 224
- args.data_set = 'IMNET'
- args.eval = True
- args.transform = True
- args.clip_grad = 1.0
- else:
-         if args.data_set == 'IMNET' and args.finetune == '':  # train on ImageNet-1k from scratch
- args.batch_size = 64
- args.lr = 4e-3
- args.opt = 'adamw'
- args.sched = 'cosine'
- args.warmup_epochs = 5
- args.weight_decay = 0.05
- args.clip_grad = 1.0
- args.epochs = 300
- args.drop_path = 0.1
- args.reprob = 0.25
- args.mixup = 0.8
- args.cutmix = 1.0
- args.smoothing = 0.1
- args.eval = True
- args.data_file = '/mnt/sfs_turbo/mindrecord_dataset/imagenet1k/'
- elif args.data_set == 'IMNET' and not (args.finetune == ''): # args.finetune={checkpoint_path}
- # basic: batch size 1024, learning rate 5e-5, no data augmentation
- args.batch_size = 4
- args.input_size = 448 # 384
- args.seq_length = 784 # 576 # (384/16)**2
- args.lr = 4e-5 # 5e-5
- args.opt = 'adamw'
- args.sched = 'cosine'
- args.warmup_epochs = 5
- args.weight_decay = 1e-8
- args.model_ema = True # 0.99996
- args.clip_grad = 1.0
- args.epochs = 30
- args.sync_bn = True
- args.drop_path = 0.1
- args.aa = None
-             args.transform = False  # disable autoaugment, mixup & cutmix in the data pipeline
- args.repeated_aug = False
- args.reprob = 0.0
- args.mixup = 0.0
- args.cutmix = 0.0
- args.smoothing = 0.1
- # args.smoothing = 1e-4 # revised 2022/3/29
- args.dist_eval = False
- args.eval = True
- # args.lr = 2e-5
- # args.drop_path = 0.0
-
- elif args.data_set == 'IMNET21k' and args.data_level == '5kw':
- # basic: batch size 4096, learning rate 1e-3, data augmentation
- args.batch_size = 16
- args.lr = 1e-3
- args.opt = 'adamw'
- args.sched = 'linear'
- args.warmup_epochs = 3
- # args.weight_decay = 0.05
- args.weight_decay = 1e-2
- args.model_ema = False
- args.clip_grad = 1.0
- args.epochs = 30
- args.sync_bn = False
- args.drop_path = 0.1
- args.aa = 'rand-m9-mstd0.5-inc1' # autoaugment default True
- args.repeated_aug = False
- args.reprob = 0.25
- # args.mixup = 0.5
- # args.cutmix = 0.8
- args.mixup = 0.8
- args.cutmix = 1.0
- args.smoothing = 1e-4
- args.dist_eval = False
- args.eval = False
- args.copy_ckpt = True
- args.save_checkpoint = True
- elif args.data_set == 'IMNET21k' and args.data_level == '10T':
- # basic: batch size 4096, learning rate 1e-3, data augmentation
- args.batch_size = 16
- args.lr = 1e-3
- args.opt = 'adamw'
- args.sched = 'linear'
- args.warmup_epochs = 2
- args.weight_decay = 0.03
- args.model_ema = False
- args.clip_grad = 1.0
- args.epochs = 30
- args.sync_bn = True
- args.drop_path = 0.1
- args.aa = 'rand-m9-mstd0.5-inc1' # autoaugment default True
- args.repeated_aug = False
- args.reprob = 0.25
- args.mixup = 0.5
- args.cutmix = 0.8
- args.smoothing = 1e-4
- args.dist_eval = False
- args.eval = False
- args.copy_ckpt = True
- args.save_checkpoint = True
- elif args.data_set == 'IMNET21k':
- # basic: batch size 4096, learning rate 1e-3, data augmentation
- args.batch_size = 16
- args.lr = 1e-3
- args.opt = 'adamw'
- args.sched = 'linear'
- args.warmup_epochs = 5
- args.weight_decay = 1e-2
- args.model_ema = False
- args.clip_grad = 1.0
- args.epochs = 90
- args.sync_bn = True
- args.drop_path = 0.1
- args.aa = 'rand-m9-mstd0.5-inc1' # autoaugment default True: args.transform=True
- args.repeated_aug = False
- args.reprob = 0.25
- args.mixup = 0.8
- args.cutmix = 1.0
- args.smoothing = 1e-4
- args.dist_eval = False
- args.eval = False
- args.copy_ckpt = True
- args.save_checkpoint = True
- elif args.data_set == 'IMNET21k_ext':
- # basic: batch size 4096, learning rate 1e-3, data augmentation
- args.batch_size = 4
- args.lr = 1e-4
- args.opt = 'adamw'
- args.sched = 'linear'
- args.warmup_epochs = 5
- args.weight_decay = 1e-2
- args.model_ema = False
- args.clip_grad = 1.0
- args.epochs = 90
- args.sync_bn = False
- args.drop_path = 0.1
- args.aa = 'rand-m9-mstd0.5-inc1' # autoaugment default True
- args.repeated_aug = False
- args.reprob = 0.25
- args.mixup = 0.8
- args.cutmix = 1.0
- args.smoothing = 1e-4
- args.dist_eval = False
- args.eval = False
- args.data_file = '/mnt/sfs_turbo/mindrecord_dataset/'
- args.save_checkpoint = True
-
- print(args)
-
- # launch the profiling
- if args.profile_path:
- profiler = Profiler(output_path=args.profile_path)
-
- # set the context environment
-     if args.device in ('CPU', 'GPU'):
- context.set_context(mode=context.GRAPH_MODE, device_target=args.device)
- else:
- context.set_context(mode=context.GRAPH_MODE, device_target=args.device, device_id=int(os.environ["DEVICE_ID"]))
- context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
- if in_debug:
- rank_size = 1
- rank_id = 0
- else:
- if args.device == 'GPU':
- init("nccl")
- else:
- init()
- try:
- if args.device == 'GPU':
- rank_size = Comm.get_group_size()
- rank_id = Comm.get_rank()
- else:
- rank_size = int(os.environ["RANK_SIZE"])
- rank_id = int(os.environ["RANK_ID"])
-         except (RuntimeError, KeyError):  # fall back to single-device defaults
- rank_size = 1
- rank_id = 0
-
- # for data in dataset_train.create_dict_iterator():
- # print("processing dataset:", time.time(), flush=True)
- if args.data_set == 'IMNET':
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/imagenet1k/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id)
- elif args.data_set == 'IMNET21k' and args.data_level == '5kw':
- # args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/img21k_ps_5kw/'
- # download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id)
-
- # args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/imagenet21k/'
- # download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='21k')
- # args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/21K_ext3/dot5-1/'
- # download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='dot5-1')
- # args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/21K_ext3/dot5-2/'
- # download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='dot5-2')
-
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/imagenet21k/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='21k')
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/21K_ext3/key00w/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='key00w')
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/21K_ext3/key01w/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='key01w')
- elif args.data_set == 'IMNET21k' and args.data_level == '10T':
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/21K_ext3/dot5/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='10T')
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/imagenet21k/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id, dataset='21k')
- elif args.data_set == 'IMNET21k':
- args.data_source_file = 's3://cvmodel/mindrecord_dataset_new/imagenet21k/'
- download_data(src_data_url=args.data_source_file, tgt_data_path=args.data_path, rank=rank_id)
-
- # print(os.listdir(args.data_path))
-
-     # NOTE: seeds are fixed at module import time (mindspore.set_seed(0) and
-     # np.random.seed(0) at the top of this file); the --seed argument is
-     # currently unused
-
- dataset_train, args.nb_classes = build_dataset(is_train=True, args=args,
- device_num=rank_size, rank_id=rank_id)
- dataset_train = classification_dataset(dataset_train, [args.input_size]*2, args.batch_size,
- args=args, num_classes=args.nb_classes, smoothing_rate=args.smoothing, transform=args.transform)
-
-
- # # small model
- # net = Conformer(patch_size=16, channel_ratio=4, embed_dim=384, depth=12,
- # num_heads=6, mlp_ratio=4, qkv_bias=False, cls_token=True,
- # num_classes=args.nb_classes, drop_rate=args.drop,
- # drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- # batch_size=args.batch_size)
-
- if args.model_size == 'xs':
- # small model - 38M
- net = ConformerOverflow(patch_size=16, channel_ratio=4, embed_dim=384, stage_point=[1, 4, 8, 12],
- num_heads=6, mlp_ratio=4, qkv_bias=False, cls_token=True,
- num_classes=args.nb_classes, drop_rate=args.drop,
- drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- batch_size=args.batch_size,
- weighted_fusion=True, seq_length=args.seq_length)
- elif args.model_size == 'x0':
- # base model - 87M
- net = ConformerOverflow(patch_size=16, channel_ratio=6, embed_dim=576, stage_point=[1, 4, 8, 12],
- num_heads=9, mlp_ratio=4, qkv_bias=False, cls_token=True,
- num_classes=args.nb_classes, drop_rate=args.drop,
- drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- batch_size=args.batch_size,
- weighted_fusion=True, seq_length=args.seq_length)
- elif args.model_size == 'x1':
- # x1 model - 150M
- net = ConformerOverflow(patch_size=16, channel_ratio=8, embed_dim=768, stage_point=[1, 4, 8, 12],
- num_heads=12, mlp_ratio=4, qkv_bias=False, cls_token=True,
- num_classes=args.nb_classes, drop_rate=args.drop,
- drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- batch_size=args.batch_size,
- weighted_fusion=True, seq_length=args.seq_length)
- elif args.model_size == 'x2':
- # x2 model - 260M
- net = ConformerOverflow(patch_size=16, channel_ratio=10, embed_dim=896, stage_point=[1, 4, 12, 16],
- num_heads=14, mlp_ratio=4, qkv_bias=False, cls_token=True,
- num_classes=args.nb_classes, drop_rate=args.drop,
- drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- batch_size=args.batch_size,
- weighted_fusion=True, seq_length=args.seq_length)
- elif args.model_size == 'x3':
- # x3 model - 600M
- net = ConformerOverflow(patch_size=16, channel_ratio=12, embed_dim=1024, stage_point=[1, 4, 12, 20],
- num_heads=16, mlp_ratio=4, qkv_bias=False, cls_token=True,
- num_classes=args.nb_classes, drop_rate=args.drop,
- drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- batch_size=args.batch_size,
- weighted_fusion=True, seq_length=args.seq_length)
- elif args.model_size == 'x4':
- # x4 model - 1000M/10e
- net = ConformerOverflow(patch_size=16, channel_ratio=16, embed_dim=1408, stage_point=[1, 4, 12, 20],
- num_heads=16, mlp_ratio=6144/1408, qkv_bias=False, cls_token=True,
- num_classes=args.nb_classes, drop_rate=args.drop,
- drop_path_rate=args.drop_path, attn_drop_rate=args.drop_block,
- batch_size=args.batch_size,
- weighted_fusion=True)
-
- print('Batch Size:', args.batch_size)
- print('Learning rate:', args.lr)
-
- params_count_cnn = 0
- params_count_trans = 0
- for name, params in net.parameters_and_names():
- if "cnn_block" in name or "fusion_block" in name:
- params_count_cnn += np.prod(params.shape)
- if "trans_block" in name:
- params_count_trans += np.prod(params.shape)
- # print("*** name: ", name, " *** params: ", params)
-     print('total cnn_block parameters: ', params_count_cnn)
-     print('total trans_block parameters: ', params_count_trans)
-
- total_params = 0
- for param in net.trainable_params():
- total_params += np.prod(param.shape)
-     print('total trainable parameters: ', total_params)
-
-     batch_num = dataset_train.get_dataset_size()  # number of batches per epoch
-
- # load parameters from checkpoint
-     if args.finetune != '':
- ckpt_param_dict = load_pre_trained_checkpoint(args, rank=rank_id)
- load_param_into_net(net, ckpt_param_dict)
-
- # using sync batchnorm
- if args.sync_bn:
- groups = build_groups(rank_size)
- convert_sync_batchnorm(net, groups)
- # convert_sync_batchnorm_old(net)
-
- # set the allreduce fusion configuration
- set_allreduce_fusion(net, 120)
-
-     # create loss function
-     ls = CrossEntropySmooth(reduction="mean")
-
-     # create network with loss function
-     net_with_loss = NetWithLossCell(net, ls)
-
-     # create learning rate scheduler
- if args.sched == 'cosine':
- lr_scheduler = LearningRate(args.lr,
- args.min_lr,
- args.warmup_lr,
- args.epochs,
- batch_num,
- warmup_epochs=args.warmup_epochs)
- else: # linear
- lr_scheduler = LearningRate(args.lr,
- args.min_lr,
- args.warmup_lr,
- args.epochs,
- batch_num,
- warmup_epochs=args.warmup_epochs,
- use_cosine=False)
- lr = lr_scheduler.get_lr()
-
-     # create optimizer that keeps its internal state in fp32
- opt = FP32StateAdamWeightDecay(net.trainable_params(),
- learning_rate=lr,
- eps=args.opt_eps,
- # beta2=0.995,
- weight_decay=args.weight_decay)
-
- # create dynamic loss scale and gradients clipping cell
- if args.clip_grad > 0.0: # default is enabled
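-         # dynamic loss scaling: start at 2**24, halve the scale on overflow and
-         # double it after scale_window (2000) consecutive overflow-free steps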
- scale_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**24,
- scale_factor=2,
- scale_window=2000)
- net_with_grads = TrainOneStepWithLossScaleCellAndClip(network=net_with_loss,
- optimizer=opt,
- scale_update_cell=scale_cell,
- clip_value=args.clip_grad)
- else:
- net_with_grads = net_with_loss
- loss_scale = DynamicLossScaleManager()
-
-     # load parameters from breakpoint
-     if args.breakpoint != '':
- ckpt_param_dict = load_breakpoint_checkpoint(args, rank=rank_id)
- load_param_into_net(net_with_grads, ckpt_param_dict)
-
- # create the net for evaluation when training
- eval_net = WithEvalCell(net, ls)
-
- # evaluating metrics
- metrics = {'acc': Accuracy(), 'top1_acc': Top1CategoricalAccuracy(), 'top5_acc': Top5CategoricalAccuracy()}
-
- # create the model
- if args.clip_grad > 0.0:
- model = Model(net_with_grads, metrics=metrics, eval_network=eval_net, eval_indexes=[0, 1, 2])
- else:
- model = Model(net_with_grads, optimizer=opt, metrics=metrics, eval_network=eval_net, eval_indexes=[0, 1, 2], loss_scale_manager=loss_scale)
-
- # create callbacks
- call_back = list()
-     if rank_id % 8 == 0:  # only the first device of each node attaches logging callbacks
-         loss_cb = LossMonitor()
-         call_back.append(loss_cb)
-         if training_cloud:
-             loss_summary = LossSummaryCallback(bucket=dst_url+'loss_summary/')
-             call_back.append(loss_summary)
-         time_cb = TimeMonitor(data_size=batch_num)
-         call_back.append(time_cb)
- if args.save_checkpoint and rank_id == 0:
- ckpt_save_dir = set_save_ckpt_dir(args)
- config_ck = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_epochs * batch_num,
- keep_checkpoint_max=args.keep_checkpoint_max,
- exception_save=True)
- ckpt_cb = ModelCheckpoint(prefix='conformer', directory=ckpt_save_dir, config=config_ck)
- call_back.append(ckpt_cb)
-
- if args.eval:
- eval_cb = run_eval(args, rank_size, rank_id, model, src_url=src_url, dst_url=dst_url)
- call_back.append(eval_cb)
-
-     if args.copy_ckpt and rank_id == 0:
-         if not args.save_checkpoint:
-             # guard: ckpt_save_dir is otherwise only defined when --save-checkpoint is set
-             ckpt_save_dir = set_save_ckpt_dir(args)
-         cp_ckpt_cb = copy_checkpoint(args, rank_size, rank_id, model, src_url=ckpt_save_dir, dst_url=args.checkpoint_store_path)
-         call_back.append(cp_ckpt_cb)
-
- # begin to train the model
- try:
- print("training...")
- # sink_size = 2000
- # actual_epoch_num = int(args.epochs * dataset_train.get_dataset_size() / sink_size)
- # model.train(actual_epoch_num, dataset_train, callbacks=call_back, dataset_sink_mode=True, sink_size=sink_size)
-
- model.train(args.epochs, dataset_train, callbacks=call_back, dataset_sink_mode=False)
- #
- # model.train(args.epochs, dataset_train, callbacks=call_back, dataset_sink_mode=True, sink_size=10)
-     except Exception as e:
-         # logs and checkpoints are uploaded by the finally block below, so only
-         # report the failure here (avoids double-prefixing the renamed log files)
-         print("run exception: {}".format(e))
- finally:
- if training_cloud:
- if rank_id % 8 == 0:
- files = os.listdir(src_url)
- for file in files:
- os.system("mv {}{} {}{}".format(src_url, file, src_url, str(rank_id//8)+file))
- mox.file.copy_parallel(src_url=src_url, dst_url=dst_url)
- if args.save_checkpoint and rank_id == 0:
- mox.file.copy_parallel(src_url=ckpt_save_dir, dst_url=args.checkpoint_store_path)
- # print(os.listdir(args.checkpoint_path))
- if args.profile_path:
- profiler.analyse()
- if training_cloud and rank_id == 0:
-             mox.file.copy_parallel(src_url=args.profile_path, dst_url=dst_url + 'profile/')
- print("End training")
-
- if __name__ == '__main__':
- parser = argparse.ArgumentParser('Training and evaluation script', parents=[get_args_parser()])
- args = parser.parse_args()
- if args.output_dir:
- Path(args.output_dir).mkdir(parents=True, exist_ok=True)
- main(args)