|
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import argparse
- import os
- import numpy as np
- import mindspore as ms
- from mindspore import Tensor, Model, save_checkpoint, context
- from mindspore.communication.management import get_rank, init, get_group_size
- from mindspore.context import ParallelMode
- import mindspore.nn as nn
- from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
- from mindspore import load_checkpoint, load_param_into_net
-
- from src.config import config
- from src.loss import CrossEntropyWithWeight, NetWithLoss
- from models.unet3plus import UNet3Plus
- from net import UNETR
- from src.dataset import create_dataset
- from src.evaluation import EvaluateCallBack
-
# NOTE(review): this flag is assigned BEFORE importing moxing_adapter — the
# ordering looks deliberate (presumably the adapter reads config at import
# time to pick the ModelArts code path); confirm before reordering.
config.IS_MODELART = True
from moxing_adapter import sync_data
-
# Command-line interface for the cloud training launcher. All URLs are
# optional because the platform injects them when running on ModelArts.
parser = argparse.ArgumentParser(description='Train keypoints network')
parser.add_argument('--train_url', required=False, default=None,
                    help='Location of training outputs.')
parser.add_argument('--multi_data_url', required=False, default=None,
                    help='Location of data.')
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend'])
parser.add_argument('--ckpt_url', required=False, default=None,
                    help='Location of pretrained model.')
args = parser.parse_args()
-
-
-
-
def main(args):
    """Distributed training entry point for the UNETR segmentation network.

    Configures graph-mode execution on the Ascend device assigned by the
    launcher, downloads and unpacks the dataset into the ModelArts cache,
    trains for the configured number of epochs, and saves a checkpoint.

    Args:
        args: parsed CLI namespace with ``train_url``, ``multi_data_url``,
            ``device_target`` and ``ckpt_url`` attributes.
    """
    # DEVICE_ID is injected into the environment by the distributed launcher.
    context.set_context(mode=context.GRAPH_MODE,
                        device_id=int(os.environ["DEVICE_ID"]),
                        device_target="Ascend")
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True,
                                      parameter_broadcast=True,
                                      auto_parallel_search_mode="recursive_programming")
    init()  # initialize collective communication (HCCL) for data parallelism

    # Optional pretrained-weight loading, kept for reference:
    # sync_data(args.ckpt_url, config.MODELARTS.CACHE_INPUT + "/unet3plus.ckpt")
    # param_dict = load_checkpoint(config.MODELARTS.CACHE_INPUT + "/unet3plus.ckpt")
    # load_param_into_net(net, param_dict)
    net = UNETR(in_channels=1, n_classes=config.num_classes)

    # Create cache directories for dataset input and trained-model output.
    # exist_ok avoids the check-then-create race of the LBYL pattern.
    os.makedirs(config.MODELARTS.CACHE_INPUT, exist_ok=True)
    os.makedirs(config.MODELARTS.CACHE_OUTPUT, exist_ok=True)

    # multi_data_url is a JSON-encoded list of dataset descriptors supplied
    # by the platform; parse it with json.loads instead of eval() — never
    # eval externally supplied text.
    import json
    data_url = json.loads(args.multi_data_url)[0]["dataset_url"]
    sync_data(data_url, config.MODELARTS.CACHE_INPUT + '/data.zip')

    # Unpack the dataset archive into the input cache. The archive comes
    # from the platform's own storage, so extractall is acceptable here.
    import zipfile
    with zipfile.ZipFile(config.MODELARTS.CACHE_INPUT + '/data.zip') as zip_file:
        zip_file.extractall(config.MODELARTS.CACHE_INPUT)

    train_dataset = create_dataset(is_train=True)
    val_dataset = create_dataset(is_train=False)

    optimizer = nn.Adam(net.trainable_params(), learning_rate=config.TRAIN.lr,
                        weight_decay=config.TRAIN.WD)
    loss = CrossEntropyWithWeight(weights=config.TRAIN.loss_weights)

    dataset_size = train_dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    loss_cb = LossMonitor()  # per_print_times=dataset_size
    callback_list = [time_cb, loss_cb]

    if config.TRAIN.with_eval:
        eval_cb = EvaluateCallBack(model=net, eval_dataset=val_dataset)
        callback_list.append(eval_cb)

    # Wrap the network with its loss so Model can drive optimization;
    # O0 keeps everything in full precision.
    net_with_loss = NetWithLoss(net, loss_fn=loss)
    model = Model(network=net_with_loss, optimizer=optimizer, amp_level="O0")

    epoch_size = config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH
    print("************ Start training now ************")
    print(f'start training, epoch size = {epoch_size}')
    model.train(epoch_size, train_dataset, dataset_sink_mode=False,
                callbacks=callback_list)
    print("************ Training complete ************")

    save_checkpoint(net, config.MODELARTS.CACHE_OUTPUT + "/trained_model_param.ckpt")
-
# Script entry point: run training with the CLI arguments parsed above.
if __name__ == '__main__':
    main(args)
|