# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train ArcFace on the MS1MV2 dataset with MindSpore.

Usage: python train.py
"""
import argparse
import ast
import os

import numpy as np

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, Tensor
from mindspore import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.dataset import create_dataset
from src.iresnet import iresnet100
from src.mobilefacenet import get_mbf  # alternative lightweight backbone (unused by default)
from src.loss import PartialFC

mindspore.common.set_seed(2022)
parser = argparse.ArgumentParser(description='Training')

# Datasets
parser.add_argument('--train_url', default='obs://duss/code/Arcface-mindspore/r100_deploy_ms1mv2/', type=str,
                    help='output path')
parser.add_argument('--data_url', default='obs://duss/code/arcface/ms1mv2/', type=str,
                    help='dataset path')
# Optimization options
parser.add_argument('--epochs', default=25, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--num_classes', default=85742, type=int, metavar='N',
                    help='number of identities (85742 for MS1MV2; 10572 for CASIA-WebFace)')
parser.add_argument('--batch_size', default=128, type=int, metavar='N',
                    help='per-device train batch size (default: 128)')
parser.add_argument('--lr', '--learning-rate', default=0.02, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--schedule', type=int, nargs='+', default=[10, 16, 21],
                    help='decrease learning rate at these epochs')
parser.add_argument('--gamma', type=float, default=0.02,
                    help='LR is multiplied by gamma on schedule')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
# Device options
parser.add_argument('--device_target', type=str,
                    default='Ascend', choices=['GPU', 'Ascend'])
parser.add_argument('--device_num', type=int, default=4)
parser.add_argument('--device_id', type=int, default=0)
# `action="store_true"` combined with default=True could never be switched off
# from the command line, so the flag is parsed as a boolean literal instead.
parser.add_argument('--modelarts', type=ast.literal_eval, default=True,
                    help='run on ModelArts and stage data via moxing (pass False for local runs)')

args = parser.parse_args()
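# Example launches (paths and the launcher are illustrative, not part of this repo):
#   single device : python train.py --device_num 1 --modelarts False --data_url /data/ms1mv2/
#   multi-device  : one process per device, with DEVICE_ID / RANK_ID exported
#                   per process by the launcher (e.g. an Ascend rank-table script)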

def get_rank_id():
    """Return the global rank id from the RANK_ID environment variable (0 if unset)."""
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)

def lr_generator(lr_init, total_epochs, steps_per_epoch):
    """Build a per-step learning-rate Tensor with step decay at the epochs in args.schedule."""
    lr_each_step = []
    for i in range(total_epochs):
        if i in args.schedule:
            lr_init *= args.gamma
        for _ in range(steps_per_epoch):
            lr_each_step.append(lr_init)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return Tensor(lr_each_step)
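# With the defaults (lr=0.02, schedule=[10, 16, 21], gamma=0.02) the schedule
# works out to (derived from the defaults above, shown for illustration):
#   epochs  0-9  -> 0.02
#   epochs 10-15 -> 0.02 * 0.02 = 4e-4
#   epochs 16-20 -> 4e-4 * 0.02 = 8e-6
#   epochs 21-24 -> 8e-6 * 0.02 = 1.6e-7
# nn.SGD accepts this per-step Tensor directly as a dynamic learning rate.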


class MyNetWithLoss(nn.Cell):
    """Wrap the backbone with the PartialFC margin-softmax loss."""

    def __init__(self, backbone, cfg):
        super(MyNetWithLoss, self).__init__(auto_prefix=False)
        # The backbone runs in float32; backbone.to_float(mstype.float16) can be
        # re-enabled here for mixed-precision training.
        self._backbone = backbone
        self._loss_fn = PartialFC(num_classes=cfg.num_classes,
                                  world_size=cfg.device_num).to_float(mstype.float32)

    def construct(self, data, label):
        out = self._backbone(data)
        loss = self._loss_fn(out, label)
        return loss
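# Shape sketch (assuming the usual ArcFace setup of 112x112 aligned faces and a
# 512-d embedding; these sizes come from the ArcFace recipe, not this file):
#   data  (N, 3, 112, 112) float32 --backbone--> (N, 512) embeddings
#   label (N,) int32               --PartialFC--> scalar loss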


# Gradient clipping: type 0 clips each value, type 1 clips by L2 norm.
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0

clip_grad = C.MultitypeFuncGraph("clip_grad")

@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """Clip one gradient tensor by value (clip_type 0) or by L2 norm (clip_type 1)."""
    if clip_type not in (0, 1):
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
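# Effect of the two clip modes on a gradient g with clip_value = 1.0
# (illustrative values, not taken from a real run):
#   clip_type 0 (by value): [-3.0, 0.5, 2.0] -> [-1.0, 0.5, 1.0]
#   clip_type 1 (by norm) : the whole tensor is rescaled so that
#                           ||g||_2 <= 1.0, preserving its direction.
# Any other clip_type returns the gradient unchanged.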

class TrainingWrapper(nn.Cell):
    """Train-one-step cell: backward pass, gradient clipping, optional all-reduce, update."""

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = mindspore.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        class_list = [mindspore.context.ParallelMode.DATA_PARALLEL,
                      mindspore.context.ParallelMode.HYBRID_PARALLEL]
        if self.parallel_mode in class_list:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
        self.hyper_map = ops.HyperMap()

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        # Broadcast the loss-scale sensitivity to the loss shape, then backprop.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
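    # One step runs: forward -> backward (scaled by `sens`) -> per-device clip
    # -> mean all-reduce (data-parallel only) -> optimizer update. Note that
    # clipping happens before the all-reduce, so each device clips its local
    # gradients rather than the globally averaged ones.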


if __name__ == "__main__":
    if get_rank_id() == 0:
        print(args)
    train_epoch = args.epochs
    target = args.device_target
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=target, save_graphs=False)
    device_id = args.device_id
    if args.device_num > 1:
        if target == 'Ascend':
            device_id = int(os.getenv('DEVICE_ID'))
            context.set_context(device_id=device_id)
            context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            cost_model_context.set_cost_model_context(device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
                                                      costmodel_gamma=0.001,
                                                      costmodel_beta=280.0)
            set_algo_parameters(elementwise_op_strategy_follow=True)
            init()
        elif target == 'GPU':
            init()
            context.set_auto_parallel_context(device_num=args.device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True,
                                              auto_parallel_search_mode="recursive_programming")
    else:
        device_id = int(os.getenv('DEVICE_ID'))
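    # DEVICE_ID and RANK_ID are expected to be exported per process by the
    # launch environment (for example an Ascend rank-table script; an
    # assumption about the surrounding setup). int(os.getenv('DEVICE_ID'))
    # raises a TypeError when the variable is unset.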

    if args.modelarts:
        import moxing as mox

        # Stage the OBS dataset into the local cache, one copy per device.
        local_data_path = '/cache/data_path_' + os.getenv('DEVICE_ID')
        mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_path)
        zip_command = "unzip -o -q " + local_data_path + "/ms1mv2.zip -d " + local_data_path
        os.system(zip_command)
        train_dataset = create_dataset(dataset_path=local_data_path + '/faces_emore_train/',
                                       do_train=True,
                                       repeat_num=1, batch_size=args.batch_size, target=target)
    else:
        train_dataset = create_dataset(dataset_path=args.data_url, do_train=True,
                                       repeat_num=1, batch_size=args.batch_size, target=target)
    step = train_dataset.get_dataset_size()
    lr = lr_generator(args.lr, train_epoch, steps_per_epoch=step)
    net = iresnet100()  # iresnet100 backbone; get_mbf would supply a MobileFaceNet instead
    train_net = MyNetWithLoss(net, args)
    optimizer = nn.SGD(params=train_net.trainable_params(), learning_rate=lr,
                       momentum=args.momentum, weight_decay=args.weight_decay)

    train_net = TrainingWrapper(train_net, optimizer)

    model = Model(train_net)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=60, keep_checkpoint_max=20)
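    # A checkpoint file is written every 60 steps; CheckpointConfig keeps only
    # the 20 most recent files and removes older ones automatically.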
    if args.modelarts:
        ckpt_cb = ModelCheckpoint(prefix="ArcFace-r100-ms1mv2", config=config_ck,
                                  directory='/cache/train_output/')
    else:
        ckpt_cb = ModelCheckpoint(prefix="ArcFace-", config=config_ck,
                                  directory=args.train_url)
    time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
    loss_cb = LossMonitor()
    cb = [ckpt_cb, time_cb, loss_cb]
    if args.device_num == 1:
        model.train(train_epoch, train_dataset,
                    callbacks=cb, dataset_sink_mode=False)
    elif args.device_num > 1 and get_rank() % 8 == 0:
        # Only the first rank on each 8-device node keeps the checkpoint and
        # monitor callbacks; the remaining ranks train silently below.
        model.train(train_epoch, train_dataset,
                    callbacks=cb, dataset_sink_mode=False)
    else:
        model.train(train_epoch, train_dataset, dataset_sink_mode=False)
    if args.modelarts:
        mox.file.copy_parallel(
            src_url='/cache/train_output', dst_url=args.train_url)