|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- python train.py
- """
- import argparse
- import os
- import json
- import moxing as mox
- import time
- import numpy as np
- import mindspore
- import mindspore.nn as nn
- from mindspore import context, Tensor
- import mindspore.ops as ops
- from mindspore.train.model import Model, ParallelMode
- from mindspore import dtype as mstype
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.communication.management import init, get_rank
- from mindspore.parallel import _cost_model_context as cost_model_context
- from mindspore.parallel import set_algo_parameters
-
- from src.dataset import create_dataset
- from src.iresnet import iresnet100, iresnet50
- from src.loss import PartialFC
# Fix the global RNG seed so training runs are reproducible.
mindspore.common.set_seed(1024)

parser = argparse.ArgumentParser(description='Training')

# Optimization options
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--num_classes', default=10572, type=int, metavar='N',
                    help='num of classes')
parser.add_argument('--batch_size', default=16, type=int, metavar='N',
                    help='train batchsize (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.08, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--schedule', type=int, nargs='+', default=[40, 64, 84],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1,
                    help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# Device options
parser.add_argument('--device_target', type=str,
                    default='Ascend', choices=['GPU', 'Ascend'])
parser.add_argument('--device_num', type=int, default=1)
parser.add_argument('--device_id', type=int, default=0)
parser.add_argument('--modelarts', action="store_true", help="using modelarts")

parser.add_argument('--ckpt_url', type=str,
                    help='checkpoint path')
parser.add_argument('--train_url', default='/cache/output/', type=str,
                    help='output path')
# Restored: the --modelarts branch in __main__ reads args.data_url; without
# this option that branch raises AttributeError. Default taken from the
# previously commented-out definition.
parser.add_argument('--data_url',
                    default='obs://duss/code/arcface/faces_webface_112x112_train/',
                    type=str, help='obs path of the training dataset')
parser.add_argument('--multi_data_url',
                    help='path to multi dataset',
                    default='/cache/data/')
args = parser.parse_args()
-
############################################################
# Qizhi / C2Net platform helper functions
############################################################
# Copy multi-dataset from obs to training image and unzip
def C2netMultiObsToEnv(multi_data_url, data_dir):
    """Download the first dataset listed in *multi_data_url* from OBS and unzip it.

    Args:
        multi_data_url: JSON string — a list of dicts with at least
            ``dataset_name`` and ``dataset_url`` keys; only entry [0] is used.
        data_dir: local directory the zip is copied to and extracted into.

    Side effect: creates /cache/download_input.txt as a completion marker so
    other ranks know the copy has finished (polled by DownloadFromQizhi).
    """
    multi_data_json = json.loads(multi_data_url)
    zipfile_path = data_dir + "/" + multi_data_json[0]["dataset_name"]
    print("#############################################################")
    mox.file.copy(multi_data_json[0]["dataset_url"], zipfile_path)
    print("Successfully Download {} to {}".format(multi_data_url, zipfile_path))
    print("#############################################################")
    filePath = data_dir + "/"
    print("#####################zipfile_path 和 file_path", zipfile_path, filePath)
    print('################### start zip ####################')
    # NOTE(review): paths come from platform config, but quoting them with
    # shlex.quote would be safer if dataset names can contain spaces.
    os.system("unzip -o -q {} -d {}".format(zipfile_path, filePath))
    print('################### end zip ####################')

    # Marker file: multi-card training waits on this instead of re-copying.
    with open("/cache/download_input.txt", 'w'):
        pass
    if os.path.exists("/cache/download_input.txt"):
        print("download_input succeed")
    return
-
# Copy the output to obs
def EnvToObs(train_dir, obs_train_url):
    """Best-effort upload of the local output directory to OBS; logs the outcome."""
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as e:
        failure = 'moxing upload {} to {} failed: '.format(train_dir, obs_train_url)
        print(failure + str(e))
        return
    print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
-
# Download the input from Qizhi And Init
def DownloadFromQizhi(multi_data_url, data_dir):
    """Fetch the dataset from OBS and configure MindSpore for 1- or N-card runs.

    Single card: download directly, then set graph-mode context. Multi card:
    set graph mode + data-parallel context, init collectives, let downloading
    ranks copy the data, and make every rank wait for the
    /cache/download_input.txt marker written by C2netMultiObsToEnv.
    """
    device_num = int(os.getenv('RANK_SIZE'))  # assumes RANK_SIZE is set by the platform — TODO confirm
    if args.device_num == 1:
        C2netMultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    if device_num > 1:
        # set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, parameter_broadcast=True)
        init()
        local_rank = int(os.getenv('RANK_ID'))
        # The original test (rank%8==0 or rank%2==0 or rank%4==0) reduces to
        # rank%2==0: every even rank downloads a copy. Simplified, behavior kept.
        # NOTE(review): the original comment says only the 0th card should copy;
        # if that is the real intent this should be `local_rank == 0`.
        if local_rank % 2 == 0:
            C2netMultiObsToEnv(multi_data_url, data_dir)
        # If the marker file does not exist the copy has not finished yet:
        # wait for a downloading rank to create it.
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
    return
-
# Upload the output to Qizhi
def UploadToQizhi(train_dir, obs_train_url):
    """Upload results to OBS; on multi-card runs only ranks divisible by 8 upload."""
    device_num = int(os.getenv('RANK_SIZE'))
    local_rank = int(os.getenv('RANK_ID'))
    if device_num == 1:
        EnvToObs(train_dir, obs_train_url)
    elif device_num > 1 and local_rank % 8 == 0:
        EnvToObs(train_dir, obs_train_url)
    return
- ##################################################################
-
def get_rank_id():
    """Return the global rank as an int; defaults to 0 when RANK_ID is unset."""
    return int(os.environ.get('RANK_ID', '0'))
-
def lr_generator(lr_init, total_epochs, steps_per_epoch):
    """Build the per-step learning-rate schedule as a 1-D float32 Tensor.

    The rate starts at lr_init and is multiplied by args.gamma at every epoch
    listed in args.schedule; each epoch contributes steps_per_epoch entries.
    """
    rate = lr_init
    per_step = []
    for epoch in range(total_epochs):
        if epoch in args.schedule:
            rate *= args.gamma
        per_step.extend([rate] * steps_per_epoch)
    return Tensor(np.array(per_step, dtype=np.float32))
-
-
class MyNetWithLoss(nn.Cell):
    """
    WithLossCell: backbone (cast to float16) wrapped with a PartialFC loss
    head (kept in float32).

    Args:
        backbone: feature-extraction network producing embeddings.
        cfg: config object providing num_classes and device_num (the parsed
            args namespace is passed in __main__).
    """
    def __init__(self, backbone, cfg):
        super(MyNetWithLoss, self).__init__(auto_prefix=False)
        # Backbone runs in fp16; loss stays in fp32.
        self._backbone = backbone.to_float(mstype.float16)
        self._loss_fn = PartialFC(num_classes=cfg.num_classes,
                                  world_size=cfg.device_num).to_float(mstype.float32)
        # NOTE(review): self.L2Norm is never used in construct(); presumably
        # PartialFC normalizes internally — confirm before removing.
        self.L2Norm = ops.L2Normalize(axis=1)

    def construct(self, data, label):
        # Forward pass: embeddings from the backbone, scalar loss from PartialFC.
        out = self._backbone(data)
        loss = self._loss_fn(out, label)
        return loss
-
-
- if __name__ == "__main__":
- if get_rank_id() == 0:
- print(args)
- data_dir = '/cache/data'
- train_dir = '/cache/output'
- if not os.path.exists(data_dir):
- os.mkdir(data_dir)
- if not os.path.exists(train_dir):
- os.mkdir(train_dir)
- ###Initialize and copy data to training image
- DownloadFromQizhi(args.multi_data_url, data_dir)
-
- train_epoch = args.epochs
- target = args.device_target
- # context.set_context(mode=context.GRAPH_MODE,
- # device_target=target, save_graphs=False)
- device_id = args.device_id
- # if args.device_num > 1:
- # if target == 'Ascend':
- # device_id = int(os.getenv('DEVICE_ID'))
- # context.set_context(device_id=device_id)
- # context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
- # gradients_mean=True,
- # )
- # cost_model_context.set_cost_model_context(device_memory_capacity=32.0 * 1024.0 * 1024.0 * 1024.0,
- # costmodel_gamma=0.001,
- # costmodel_beta=280.0)
- # set_algo_parameters(elementwise_op_strategy_follow=True)
- # init()
- # elif target == 'GPU':
- # init()
- # context.set_auto_parallel_context(device_num=args.device_num,
- # parallel_mode=ParallelMode.DATA_PARALLEL,
- # gradients_mean=True,
- # auto_parallel_search_mode="recursive_programming")
- # else:
- # device_id = int(os.getenv('DEVICE_ID'))
-
- if args.modelarts:
- import moxing as mox
-
- mox.file.copy_parallel(
- src_url=args.data_url, dst_url='/cache/data_path_' + os.getenv('DEVICE_ID'))
- zip_command = "unzip -o -q /cache/data_path_" + os.getenv('DEVICE_ID') \
- + "/faces_webface_112x112_train.zip -d /cache/data_path_" + \
- os.getenv('DEVICE_ID')
- os.system(zip_command)
- train_dataset = create_dataset(dataset_path='/cache/data_path_' + os.getenv('DEVICE_ID') + '/faces_webface_112x112_train/',
- do_train=True,
- repeat_num=1, batch_size=args.batch_size, target=target)
- else:
- train_dataset = create_dataset(dataset_path=data_dir + '/faces_webface_112x112_train/', do_train=True,
- repeat_num=1, batch_size=args.batch_size, target=target)
- step = train_dataset.get_dataset_size()
- lr = lr_generator(args.lr, train_epoch, steps_per_epoch=step)
- net = iresnet50()
- train_net = MyNetWithLoss(net, args)
- optimizer = nn.SGD(params=train_net.trainable_params(), learning_rate=lr / 512 * args.batch_size * args.device_num,
- momentum=args.momentum, weight_decay=args.weight_decay)
-
- model = Model(train_net, optimizer=optimizer)
-
- config_ck = CheckpointConfig(
- save_checkpoint_steps=60, keep_checkpoint_max=20)
- if args.device_num == 1:
- outputDirectory = train_dir
- if args.device_num > 1:
- outputDirectory = train_dir + "/" + str(get_rank()) + "/"
- if args.modelarts:
- ckpt_cb = ModelCheckpoint(prefix="ArcFace-r100-webface-origin-init", config=config_ck,
- directory='/cache/train_output/')
- else:
- ckpt_cb = ModelCheckpoint(prefix="ArcFace-r100-webface-init-ep100", config=config_ck,
- directory=outputDirectory)
- time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
- loss_cb = LossMonitor()
- cb = [ckpt_cb, time_cb, loss_cb]
- if args.device_num == 1:
- model.train(train_epoch, train_dataset,
- callbacks=cb, dataset_sink_mode=True)
- elif args.device_num > 1 and get_rank() % 8 == 0:
- model.train(train_epoch, train_dataset,
- callbacks=cb, dataset_sink_mode=True)
- else:
- model.train(train_epoch, train_dataset, dataset_sink_mode=True)
- if args.modelarts:
- mox.file.copy_parallel(
- src_url='/cache/train_output', dst_url=args.train_url)
- UploadToQizhi(train_dir,args.train_url)
|