|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """train"""
- import os
-
- from mindspore import Model
- import moxing as mox
- from mindspore import context
- from mindspore import nn
- from mindspore.common import set_seed
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.communication.management import init, get_rank, get_group_size
-
- from src.args import args
- from src.tools.callback import EvaluateCallBack
- from src.tools.cell import cast_amp
- from src.tools.criterion import get_criterion, NetWithLoss
- from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
- from src.tools.optimizer import get_optimizer
- from mindspore.context import ParallelMode
-
- # set device_id and init
- device_id = int(os.getenv('DEVICE_ID'))
- context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
- context.set_context(device_id=device_id)
- init()
- ### Defines whether the task is a training environment or a debugging environment ###
- def WorkEnvironment(environment):
- if environment == 'train':
- workroot = '/home/work/user-job-dir'
- elif environment == 'debug':
- workroot = '/home/work'
- print('current work mode:' + environment + ', workroot:' + workroot)
- return workroot
-
- ### Copy single dataset from obs to training image###
- def ObsToEnv(obs_data_url, data_dir):
- try:
- mox.file.copy_parallel(obs_data_url, data_dir)
- print("Successfully Download {} to {}".format(obs_data_url, data_dir))
- except Exception as e:
- print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
- return
- ### Copy the output model to obs###
- def EnvToObs(train_dir, obs_train_url):
- try:
- mox.file.copy_parallel(train_dir, obs_train_url)
- print("Successfully Upload {} to {}".format(train_dir,obs_train_url))
- except Exception as e:
- print('moxing upload {} to {} failed: '.format(train_dir,obs_train_url) + str(e))
- return
-
-
- def train():
-
- ### defining the training environment
- environment = 'train'
- workroot = WorkEnvironment(environment)
- ###Initialize the data and model directories in the training image###
- data_dir = workroot + '/data'
- train_dir = workroot + '/model'
- if not os.path.exists(data_dir):
- os.makedirs(data_dir)
- if not os.path.exists(train_dir):
- os.makedirs(train_dir)
-
- ### Copy the dataset from obs to the training image ###
- ObsToEnv(args.data_url,data_dir)
-
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
-
- # assert args.crop, f"{args.arch} is only for evaluation"
- # set_seed(args.seed)
- # mode = {
- # 0: context.GRAPH_MODE,
- # 1: context.PYNATIVE_MODE
- # }
- # context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
- # context.set_context(enable_graph_kernel=False)
- # if args.device_target == "Ascend":
- # context.set_context(enable_auto_mixed_precision=True)
- rank = get_rank()
-
- # get model and cast amp_level
- args.data_url = os.path.join(data_dir,'imagenet')
- net = get_model(args)
- cast_amp(net)
- criterion = get_criterion(args)
- net_with_loss = NetWithLoss(net, criterion)
- if args.pretrained:
- pretrained(args, net)
-
- data = get_dataset(args)
- batch_num = data.train_dataset.get_dataset_size()
- optimizer = get_optimizer(args, net, batch_num)
- # save a yaml file to read to record parameters
-
- net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
-
- eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
- eval_indexes = [0, 1, 2]
- model = Model(net_with_loss, metrics={"acc", "loss"},
- eval_network=eval_network,
- eval_indexes=eval_indexes)
-
- config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size(),
- keep_checkpoint_max=args.save_every)
- time_cb = TimeMonitor(data_size=data.train_dataset.get_dataset_size())
-
- ckpt_save_dir = train_dir
- if args.run_modelarts:
- ckpt_save_dir = "/cache/ckpt_" + str(rank)
-
- ckpoint_cb = ModelCheckpoint(prefix=args.arch + str(rank), directory=ckpt_save_dir,
- config=config_ck)
- loss_cb = LossMonitor()
- eval_cb = EvaluateCallBack(model, eval_dataset=data.val_dataset, src_url=ckpt_save_dir,
- train_url=os.path.join(args.train_url, "ckpt_" + str(rank)),
- save_freq=args.save_every)
-
- print("begin train")
- model.train(int(args.epochs - args.start_epoch), data.train_dataset,
- callbacks=[time_cb, ckpoint_cb, loss_cb, eval_cb],
- dataset_sink_mode=True)
- print("train success")
-
- if args.run_modelarts:
- import moxing as mox
- mox.file.copy_parallel(src_url=ckpt_save_dir, dst_url=os.path.join(args.train_url, "ckpt_" + str(rank)))
-
-
- ###Copy the trained model data from the local running environment back to obs,
- ###and download it in the training task corresponding to the Qizhi platform
- EnvToObs(train_dir, args.train_url)
-
-
- if __name__ == '__main__':
- train()
|