- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import argparse
- import os
- from mindspore import Model, save_checkpoint, context
- from mindspore.communication.management import get_rank, init, get_group_size
- from mindspore.context import ParallelMode
- import mindspore.nn as nn
- from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
-
- from src.config import config
- from src.loss import CrossEntropyWithWeight, NetWithLoss, MultiLabelLoss
- from models.vnet import VNet
- from src.dataset import create_dataset
- from src.evaluation import EvaluateCallBack
-
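- # Force ModelArts mode and distributed training for this entry script.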
- config.IS_MODELART = True
- config.run_distribute = True
-
- from moxing_adapter import sync_data
-
- parser = argparse.ArgumentParser(description='V-net')
-
- parser.add_argument('--train_url', required=False,
- default=None, help='Location of training outputs.')
- parser.add_argument('--data_url', required=False,
- default=None, help='Location of data.')
- parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend'])
- parser.add_argument('--ckpt_url', required=False, default=None, help='Location of pretrained model.')
- args = parser.parse_args()
-
-
- def main(args):
-     # Graph mode on the Ascend device assigned by the launch environment;
-     # data-parallel training with gradient averaging across all devices.
-     context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
-                         device_id=int(os.environ["DEVICE_ID"]))
-     init()
-     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
-                                       device_num=get_group_size(),
-                                       gradients_mean=True,
-                                       parameter_broadcast=True,
-                                       search_mode="recursive_programming")
-
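-     # Build the V-Net network; the elu flag toggles the activation type used inside models/vnet.py.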
-     net = VNet(elu=False)
-
-     # Initialize the local directory that stores the dataset
-     if not os.path.exists(config.MODELARTS.CACHE_INPUT):
-         os.makedirs(config.MODELARTS.CACHE_INPUT)
-     # Initialize the local directory that stores the trained model
-     if not os.path.exists(config.MODELARTS.CACHE_OUTPUT):
-         os.makedirs(config.MODELARTS.CACHE_OUTPUT)
-
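-     # Copy the training data from OBS (data_url) into the local ModelArts input cache.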
-     sync_data(args.data_url, config.MODELARTS.CACHE_INPUT)
-
-     train_dataset = create_dataset(is_train=True)
-     val_dataset = create_dataset(is_train=False)
-
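-     # Optimizer, loss function, and per-epoch timing / loss monitoring callbacks.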
-     optimizer = nn.Adam(net.trainable_params(), learning_rate=config.TRAIN.lr)  # weight_decay=config.TRAIN.WD
-     # Alternative: loss = CrossEntropyWithWeight(config.TRAIN.loss_weights)
-     loss = MultiLabelLoss()
-     dataset_size = train_dataset.get_dataset_size()
-     time_cb = TimeMonitor(data_size=dataset_size)
-     loss_cb = LossMonitor()  # per_print_times=dataset_size
-     callback_list = [time_cb, loss_cb]
-
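-     # Optionally evaluate on the validation set during training.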
-     if config.TRAIN.with_eval:
-         eval_cb = EvaluateCallBack(model=net, eval_dataset=val_dataset)
-         callback_list.append(eval_cb)
-
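-     # Wrap the network with the loss and build the Model; amp_level="O0" keeps training in full precision.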
-     net_with_loss = NetWithLoss(net, loss_fn=loss)
-     model = Model(network=net_with_loss, optimizer=optimizer, amp_level="O0")
-
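-     # Train for END_EPOCH - BEGIN_EPOCH epochs without dataset sink mode.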
-     epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
-     print("************ Start training now ************")
-     print("epoch size = %d" % epoch_size)
-     model.train(epoch_size, train_dataset, dataset_sink_mode=False, callbacks=callback_list)
-     print("************ Training complete ************")
-
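-     # Save the trained weights for this rank and sync the output cache back to OBS (train_url).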
-     rank_dir = os.path.join(config.MODELARTS.CACHE_OUTPUT, str(get_rank()))
-     os.makedirs(rank_dir, exist_ok=True)
-     save_checkpoint(net, os.path.join(rank_dir, "trained_model_param.ckpt"))
-     sync_data(config.MODELARTS.CACHE_OUTPUT, args.train_url)
-
- if __name__ == '__main__':
-     main(args)