|
- import os
- from time import time
-
- from mindspore import context, nn
- from mindspore.train.model import Model
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
-
- import mindspore.ops as ops
-
-
- from src.dataset.ModelNet import create_modelnet40_dataset
- from src.model.PT_classification import PointTransformerCls
- from src.util.common import timeit
- from src.config.default import get_config
-
-
-
class CustomTrainOneStepCell(nn.Cell):
    """Train-one-step wrapper: forward pass, gradient computation and the
    optimizer update fused into a single `construct` call.

    Args:
        network: forward network; its output is treated as the loss value.
        optimizer: optimizer that updates ``network``'s trainable parameters.
        sens (float): sensitivity (loss-scale) factor applied during
            back-propagation. Defaults to 1.0.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(CustomTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network                      # forward network producing the loss
        self.network.set_grad()                     # enable backward graph construction
        self.optimizer = optimizer
        self.weights = self.optimizer.parameters    # parameters to be updated
        # Gradient op that returns grads w.r.t. the weight list and accepts an
        # externally supplied sensitivity tensor (sens_param=True).
        self.grad = ops.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens

    def construct(self, *inputs):
        loss = self.network(*inputs)
        # Broadcast the scale factor to the loss' dtype and shape.
        scale = ops.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, scale)
        self.optimizer(grads)   # apply the parameter update
        return loss
-
-
def train(cfg):
    """Train the PointTransformer classifier on ModelNet40.

    Args:
        cfg: configuration object providing at least ``data_url``,
            ``train_url``, ``batch_size`` and ``device_target``. Optional
            attributes ``epochs`` (default 200) and ``learning_rate``
            (default 0.001) override the training schedule.

    Side effects:
        Writes checkpoints to ``cfg.train_url`` once per epoch.
    """
    # ---- data loading ----
    print('Load dataset ...')
    data_path = os.path.join(cfg.data_url, "modelnet40_normal_resampled")
    train_dataset = create_modelnet40_dataset(mode='train',
                                              data_root=data_path,
                                              num_points=1024,
                                              batch_size=cfg.batch_size)
    # The test split is built for parity with the original pipeline but is
    # not consumed here; evaluation happens elsewhere.
    _ = create_modelnet40_dataset(mode='test',
                                  data_root=data_path,
                                  num_points=1024)

    step_size = train_dataset.get_dataset_size()

    # ---- model / loss / optimizer ----
    net = PointTransformerCls()
    # TODO(owner): restore checkpoint resuming (load best model, continue
    # from its epoch) once a MindSpore checkpoint file is available.
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

    opt = nn.Adam(params=net.trainable_params(),
                  learning_rate=getattr(cfg, 'learning_rate', 0.001),
                  beta1=0.9,
                  beta2=0.999,
                  eps=1e-08,
                  weight_decay=1e-4)

    # Mixed precision (O2) only on Ascend; other targets run full precision.
    if cfg.device_target != "Ascend":
        model = Model(net, loss, opt)
    else:
        model = Model(net, loss, opt, amp_level="O2")

    # Save one checkpoint per epoch; keep at most the 10 most recent.
    config_ck = CheckpointConfig(save_checkpoint_steps=step_size,
                                 keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_pointcloud_cls",
                                 directory=cfg.train_url,
                                 config=config_ck)
    time_cb = TimeMonitor(step_size)

    # ---- training ----
    print("============== Starting Training ==============")
    epochs = getattr(cfg, 'epochs', 200)
    model.train(epochs,
                train_dataset,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])
-
if __name__ == '__main__':

    cfg = get_config()

    environment = 'train'
    workroot = '/home/work/user-job-dir'  # fixed working root of the training job

    print('current work mode:' + environment + ', workroot:' + workroot)

    # The device id is conventionally injected via the DEVICE_ID environment
    # variable on Ascend clusters; fall back to the previously hard-coded 5
    # so standalone runs behave exactly as before.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=int(os.getenv('DEVICE_ID', '5')),
                        enable_compile_cache=True)

    print(cfg.is_modelarts)

    # ---- copy the dataset from OBS into the training image (ModelArts boilerplate) ----
    # In the training environment, redefine data_url/train_url to the fixed
    # local paths and copy the dataset down from OBS.
    if cfg.is_modelarts:
        import moxing as mox
        obs_data_url = cfg.data_url
        cfg.data_url = '/home/work/user-job-dir/inputs/data/'
        obs_train_url = cfg.train_url
        cfg.train_url = '/home/work/user-job-dir/outputs/model/'
        try:
            mox.file.copy_parallel(obs_data_url, cfg.data_url)
            print("Successfully Download {} to {}".format(obs_data_url,
                                                          cfg.data_url))
        except Exception as e:  # best-effort transfer: report and continue
            print('moxing download {} to {} failed: '.format(
                obs_data_url, cfg.data_url) + str(e))

    train(cfg)

    # ---- copy the trained model back to OBS (ModelArts boilerplate) ----
    # Upload the locally written checkpoints so the platform can serve them
    # as downloadable artifacts of the training task.
    if cfg.is_modelarts:
        try:
            mox.file.copy_parallel(cfg.train_url, obs_train_url)
            print("Successfully Upload {} to {}".format(cfg.train_url,
                                                        obs_train_url))
        except Exception as e:  # best-effort transfer: report and continue
            print('moxing upload {} to {} failed: '.format(cfg.train_url,
                                                           obs_train_url) + str(e))
|