# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """train net."""
- import os
- import argparse
- import mindspore
- from mindspore import context
- from mindspore import Tensor
- from mindspore.nn import Adam
- from mindspore.train.model import Model
- from mindspore.context import ParallelMode
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.train.loss_scale_manager import FixedLossScaleManager
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.common import set_seed
- from mindspore.train.summary.summary_record import SummaryRecord
- import mindspore.nn as nn
- import mindspore.common.initializer as weight_init
- from mindspore import Model, nn, DynamicLossScaleManager
- from src.dataset import create_train_dataset
- from src.gaitset import SetNet
- from src.lr_generator import get_lr, get_consine_lr
- from src.triplet_loss import Triplet_loss
- from src.CrossEntropySmooth import CrossEntropySmooth
- from src.callbacks import CustomCheckpointSaver, CustomLossMonitor
- from src.adam_clip import AdamClipped
- from mindspore.train.summary.summary_record import SummaryRecord
- from src.device_adapter import get_device_id, get_device_num
-
-
-
parser = argparse.ArgumentParser(description='Gait recognition')
parser.add_argument("--config", help='path to the config file')

# Two parameters required for running on ModelArts:
# data_url is the dataset path and train_url is the output model path.
parser.add_argument("--data_url", help='path to training/inference dataset folder', default='./data')
parser.add_argument("--train_url", help='model folder to save/load', default='./model')

parser.add_argument('--dataset_path',
                    type=str,
                    default="./Data",
                    help='path where the dataset is saved')
parser.add_argument('--save_checkpoint_path',
                    type=str,
                    default="./ckpt",
                    help='path to the trained ckpt file; must be provided for testing')
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'GPU', 'CPU'],
                    help='device where the code will be implemented (default: Ascend)')
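
# Example invocations (illustrative; exact paths depend on your environment):
#   local run:    python train.py --device_target Ascend --data_url ./data --train_url ./model
#   on ModelArts: data_url and train_url are supplied by the platform from the
#                 configured OBS buckets.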
set_seed(1)


if __name__ == '__main__':
    from src.config import cfg

    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    # ModelArts jobs read inputs from and write outputs to fixed local paths;
    # keep the OBS urls so the data can be copied in (and the model out) later.
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/inputs/data'
    obs_train_url = args.train_url
    args.train_url = '/home/work/user-job-dir/outputs/model'

    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully downloaded {} to {}".format(obs_data_url, args.data_url))
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, args.data_url) + str(e))

    cfg.dataset_path = args.data_url
    cfg.ckpt_save_dir = args.train_url

    # init context
    rank = 0
    if cfg.run_distributed:
        init()
        context.reset_auto_parallel_context()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)

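    # A distributed run is typically launched with one process per device; on
    # Ascend this is normally driven by a rank table file plus the RANK_ID and
    # DEVICE_ID environment variables (an assumption about the launch setup,
    # not something this script configures itself).
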
    # create dataset
    dataset = create_train_dataset(cfg)
    step_size = dataset.get_dataset_size()
    print(f"step_size: {step_size}")

    # define net
    net = SetNet(training=True)

    # init weight: Xavier-uniform initialization for every Conv2d kernel
    for _, cell in net.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))

    # init lr
    #lr = 0.000025
    #lr = get_lr(0.000025, step_size)
    lr = get_consine_lr(0.00002, step_size)
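    # get_consine_lr (spelling as in src/lr_generator.py) presumably returns a
    # per-step cosine-annealed schedule, roughly
    #   lr_t = 0.5 * base_lr * (1 + cos(pi * t / total_steps))
    # (an assumption about the helper, not its verified contents).
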
    # define opt: exclude BatchNorm beta/gamma and bias parameters from weight decay
    decayed_params = []
    no_decayed_params = []
    for param in net.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': 0.0001},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]

    opt = AdamClipped(params=group_params, learning_rate=lr)
    #opt = mindspore.nn.Adam(params=group_params, learning_rate=lr, loss_scale=1024)
    #opt = mindspore.nn.SGD(params=group_params, learning_rate=lr, loss_scale=1024)

    # define loss
    #loss = CrossEntropySmooth()
    loss = Triplet_loss()
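    # Triplet_loss (src/triplet_loss.py) is expected to implement a standard
    # margin-based triplet loss over the embeddings SetNet produces, roughly
    #   loss = mean(relu(margin + d(anchor, positive) - d(anchor, negative)))
    # with triplets mined inside each batch (a sketch of the usual
    # formulation, not the verified src implementation).
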
    # define model: O2 mixed precision with a fixed loss scale that skips the
    # parameter update whenever the scaled gradients overflow
    #loss_scale = DynamicLossScaleManager()
    loss_scale = FixedLossScaleManager(1024, drop_overflow_update=True)

    net_with_criterion = nn.WithLossCell(net, loss)
    model = Model(net_with_criterion,
                  optimizer=opt,
                  loss_scale_manager=loss_scale,
                  amp_level="O2",
                  keep_batchnorm_fp32=False)

    # define callbacks
    summary_dir = cfg.summary_dir
    time_cb = TimeMonitor(data_size=step_size)
    cb = [time_cb]

    if cfg.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                     keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpt_save_dir = cfg.ckpt_save_dir
        if cfg.run_distributed:
            # give every rank its own checkpoint and summary directory
            ckpt_save_dir = os.path.join(cfg.ckpt_save_dir, "ckpt_" + str(rank) + "/")
            summary_dir += str(rank)
        ckpt_cb = ModelCheckpoint(prefix="gaitset", directory=ckpt_save_dir, config=config_ck)
        ckpt_cb2 = CustomCheckpointSaver(step_to_enable=3000, which_step=1, save_dir=ckpt_save_dir)
        cb += [ckpt_cb, ckpt_cb2]

    # MindInsight summary recording; training runs inside the context so the
    # summary record stays open while CustomLossMonitor writes to it
    with SummaryRecord(summary_dir) as summary_record:
        loss_cb = CustomLossMonitor(summary_record=summary_record, frequency=10)
        cb += [loss_cb]

        # train model (sink_size only takes effect when dataset_sink_mode=True)
        model.train(1,
                    dataset,
                    callbacks=cb,
                    sink_size=step_size,
                    dataset_sink_mode=False)

    # copy the trained model back to OBS
    try:
        mox.file.copy_parallel(args.train_url, obs_train_url)
        print("Successfully uploaded {} to {}".format(args.train_url, obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(args.train_url, obs_train_url) + str(e))