|
- import argparse
- import datetime
- import os
- import warnings
- from pathlib import Path
- import time
-
- import moxing as mox
-
- from mindspore import context
- from mindspore import dataset as de
- from mindspore import nn
- from mindspore import set_seed
- from mindspore import dtype as mstype
- from mindspore.communication.management import get_group_size
- from mindspore.communication.management import get_rank
- from mindspore.communication.management import init
- from mindspore.context import ParallelMode
- from mindspore.nn.optim import Adam
- from mindspore.train.callback import CheckpointConfig
- from mindspore.train.callback import ModelCheckpoint
- from mindspore.train.callback import RunContext
- from mindspore.train.callback import _InternalCallbackParam
- from mindspore.profiler import Profiler
-
- from src.pointpillars import PointPillarsWithLossCell
- from src.pointpillars import TrainingWrapper
- from src.utils import get_config
- from src.utils import get_model_dataset
-
- warnings.filterwarnings('ignore')
-
-
def set_default():
    """Fix the random seed, load the YAML config and set up the MindSpore context.

    Returns:
        tuple: (cfg dict, rank id of this process, number of devices).
    """
    set_seed(0)

    # Config path is fixed for the ModelArts job image.
    config_file = Path("/home/work/user-job-dir/V0026/configs/car_xyres16_modelarts.yaml")
    output_dir = Path("train_output/")
    output_dir.mkdir(exist_ok=True, parents=True)

    cfg = get_config(config_file)

    context.set_context(mode=context.GRAPH_MODE, device_target=cfg["device_target"])

    if not cfg["is_distributed"]:
        # Single-device run: device id comes from the environment (default 0).
        rank, device_num = 0, 1
        context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
    else:
        # Distributed run: initialise collectives, then configure data parallelism.
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num,
        )

    return cfg, rank, device_num
-
-
def train(args):
    """Train PointPillars on ModelArts.

    Downloads the dataset from OBS into the training image, runs the manual
    training loop (with per-rank-0 checkpointing and logging), then uploads
    the produced checkpoints back to OBS.

    Args:
        args: parsed CLI namespace with ``data_url`` (OBS dataset location)
            and ``train_url`` (OBS output location).
    """
    data_dir = '/home/work/user-job-dir/data'    # local dataset directory
    train_dir = '/home/work/user-job-dir/model'  # local checkpoint directory
    # Create local directories idempotently (os.mkdir would raise on a
    # re-run or if another process created the directory first).
    os.makedirs(data_dir, exist_ok=True)
    obs_train_url = args.train_url
    os.makedirs(train_dir, exist_ok=True)

    # ---- Copy the dataset from OBS into the training image (ModelArts boilerplate) ----
    obs_data_url = args.data_url
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as e:
        # Best-effort: surface the failure in the job log and keep going so
        # the real error (missing data) is reported by the training code.
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))

    cfg, rank, device_num = set_default()

    # Only rank 0 writes checkpoints and prints progress.
    save_ckpt_log_flag = rank == 0

    train_cfg = cfg['train_config']

    profiler = Profiler(output_path='./profiler_data')
    pointpillarsnet, dataset = get_model_dataset(cfg, True)
    if save_ckpt_log_flag:
        print('PointPillarsNet created', flush=True)

    input_cfg = cfg['train_input_reader']
    n_epochs = input_cfg['max_num_epochs']
    batch_size = input_cfg['batch_size']

    # Global steps per epoch after sharding across devices.
    steps_per_epoch = int(len(dataset) / batch_size / device_num)
    lr_cfg = train_cfg['learning_rate']
    lr = nn.exponential_decay_lr(
        learning_rate=lr_cfg['initial_learning_rate'],
        decay_rate=lr_cfg['decay_rate'],
        total_step=n_epochs * steps_per_epoch,
        step_per_epoch=steps_per_epoch,
        decay_epoch=lr_cfg['decay_epoch'],
        is_stair=lr_cfg['is_stair']
    )
    optimizer = Adam(
        pointpillarsnet.trainable_params(),
        learning_rate=lr,
        weight_decay=train_cfg['weight_decay']
    )

    pointpillarsnet_wloss = PointPillarsWithLossCell(pointpillarsnet, cfg['model'])
    # Mixed precision: run the loss cell in fp16.
    pointpillarsnet_wloss.to_float(mstype.float16)
    network = TrainingWrapper(pointpillarsnet_wloss, optimizer)

    train_column_names = dataset.data_keys
    sampler = de.DistributedSampler(device_num, rank)
    ds = de.GeneratorDataset(
        dataset,
        column_names=train_column_names,
        python_multiprocessing=True,
        num_parallel_workers=1,
        max_rowsize=100,
        sampler=sampler
    )
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(n_epochs)
    data_loader = ds.create_dict_iterator(num_epochs=n_epochs)
    network.set_train()

    if save_ckpt_log_flag:
        # Manual checkpoint callback wiring (no Model.train loop is used).
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=steps_per_epoch,
            keep_checkpoint_max=train_cfg['keep_checkpoint_max']
        )
        ckpt_cb = ModelCheckpoint(
            config=ckpt_config,
            directory=train_dir,
            prefix='pointpillars'
        )
        cb_params = _InternalCallbackParam()
        cb_params.train_network = pointpillarsnet
        cb_params.epoch_num = n_epochs
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    log_freq = train_cfg['log_frequency_step']
    old_progress = -1
    start = time.time()
    for i, data in enumerate(data_loader):
        voxels = data["voxels"]
        num_points = data["num_points"]
        coors = data["coordinates"]
        labels = data['labels']
        reg_targets = data['reg_targets']
        batch_anchors = data["anchors"]
        bev_map = data.get('bev_map', False)  # value not used if use_bev = False

        loss = network(voxels, num_points, coors, bev_map, labels, reg_targets, batch_anchors)
        if save_ckpt_log_flag:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

            if i % log_freq == 0:
                time_used = time.time() - start
                epoch = i // steps_per_epoch
                # Number of steps completed since the previous log line
                # (first iteration: old_progress == -1, so this is 1).
                n_steps = max(i - old_progress, 1)
                fps = n_steps * batch_size * device_num / time_used
                date_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                # Step time: elapsed interval divided by the steps in that
                # interval, converted to milliseconds (previously this
                # divided by steps_per_epoch and printed seconds as "ms").
                print(f'{date_time} epoch:{epoch}, iter:{i}, '
                      f'loss:{loss}, fps:{round(fps, 2)} imgs/sec, '
                      f'step time: {time_used / n_steps * 1000} ms',
                      flush=True)
                start = time.time()
                old_progress = i

            if (i + 1) % steps_per_epoch == 0:
                cb_params.cur_epoch_num += 1
    profiler.analyse()

    # ---- Copy the trained model back to OBS (ModelArts boilerplate) ----
    # Checkpoints uploaded here are offered for download by the platform.
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as e:
        # Best-effort: log the upload failure rather than crash after training.
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
-
if __name__ == '__main__':
    # ModelArts entry point: --data_url / --train_url are OBS locations
    # injected by the platform when the training job is launched.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--device_target', default='Ascend', help='device target')
    arg_parser.add_argument('--data_url', required=True, help='')
    arg_parser.add_argument('--train_url', required=True, help='')
    train(arg_parser.parse_args())
|