import os
import json
import zipfile
import argparse

import matplotlib.pyplot as plt
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, save_checkpoint
from mindspore.train.callback import TimeMonitor, LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.context import ParallelMode
from mindspore.communication import get_rank, init, get_group_size

from model import Resnet50
from loss import MultiLabelLoss, NetWithLoss
from dataset import create_plant_dataset
import config as cfg
from moxing_adapter import sync_data
from val import EvaluateCallBack

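# Command-line arguments are injected by the cloud training platform;
# the *_url arguments point at object-storage locations.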
parser = argparse.ArgumentParser(description='Train plant classification network')
parser.add_argument('--train_url', required=False, default=None,
                    help='Location of training outputs.')
parser.add_argument('--multi_data_url', required=False, default=None,
                    help='Location of data.')
parser.add_argument('--data_url', required=False, default=None,
                    help='Location of data.')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend'])
parser.add_argument('--ckpt_url', required=False, default=None,
                    help='Location of pretrained model.')
args = parser.parse_args()

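# Graph-mode execution on the Ascend device assigned to this process;
# data-parallel training with gradients averaged across all devices.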
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                    device_id=int(os.getenv("DEVICE_ID")))
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True, parameter_broadcast=True,
                                  search_mode="recursive_programming")
init()

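# Rank 0 prepares the shared cache directories before data sync begins.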
if get_rank() == 0:
    # Create the dataset cache directory.
    if not os.path.exists(cfg.CACHE_INPUT):
        os.makedirs(cfg.CACHE_INPUT)
    # Create the model output directory.
    if not os.path.exists(cfg.CACHE_OUTPUT):
        os.makedirs(cfg.CACHE_OUTPUT)

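# --multi_data_url arrives as a JSON string holding a list of dataset
# descriptors; download the first dataset's archive into the local cache.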
- false = "False"
- print(eval(args.multi_data_url)[0])
- data_url = eval(args.multi_data_url)[0]["dataset_url"]
- # data_url = args.data_url
- sync_data(data_url, cfg.CACHE_INPUT + '/dataset.zip')
- # sync_data(data_url, cfg.CACHE_INPUT)
-
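# Unpack the downloaded archive and sanity-check the dataset layout.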
with zipfile.ZipFile(cfg.CACHE_INPUT + '/dataset.zip') as zip_file:
    zip_file.extractall(cfg.CACHE_INPUT)
print(os.listdir(cfg.DATASET_PATH))

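# Build the data pipelines; the train split receives the device-group size
# (so each rank can read its own shard).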
train_ds = create_plant_dataset('train', get_group_size())
test_ds = create_plant_dataset('test')
val_ds = create_plant_dataset('val')
dataset_size = train_ds.get_dataset_size()
print('dataset_size:', dataset_size)

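# Piecewise-constant LR schedule: cfg.epoch lists the epoch boundaries and
# cfg.lr the rate applied up to each boundary; boundaries become step counts.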
net = Resnet50()
milestone = [i * dataset_size for i in cfg.epoch]
print(milestone)
pw_lr = nn.piecewise_constant_lr(milestone, cfg.lr)
optimizer = nn.Adam(net.trainable_params(), learning_rate=pw_lr,
                    weight_decay=cfg.weight_decay)

loss = MultiLabelLoss()
net_with_loss = NetWithLoss(net, loss)

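# Callbacks: step timing, loss logging, per-epoch evaluation, and a
# checkpoint saved every three epochs into a per-rank directory.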
time_cb = TimeMonitor()
loss_cb = LossMonitor()
eval_cb = EvaluateCallBack(model=net, val_ds=val_ds, train_ds=train_ds)
ckpt_config = CheckpointConfig(save_checkpoint_steps=3 * dataset_size)
ckpt_save_dir = cfg.CACHE_OUTPUT + "/rank_" + str(get_rank()) + "/"
ckpoint_cb = ModelCheckpoint(prefix="Resnet50", directory=ckpt_save_dir, config=ckpt_config)
callback_list = [time_cb, loss_cb, eval_cb, ckpoint_cb]

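# Assemble the Model and launch training with the callback list.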
model = ms.Model(net_with_loss, optimizer=optimizer)
print("************ Start training now ************")
model.train(cfg.total_epoch, train_ds, callback_list)
print("************ Training complete ************")

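# Persist the final weights and plot the per-epoch accuracy curves
# accumulated in cfg.acc_val_iter / cfg.acc_train_iter during training.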
save_checkpoint(net, cfg.CACHE_OUTPUT + '/' + str(get_rank()) + '_trained_model_param.ckpt')
plt.plot(list(range(1, cfg.total_epoch + 1)), cfg.acc_val_iter)
plt.plot(list(range(1, cfg.total_epoch + 1)), cfg.acc_train_iter)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend(labels=["Val", "Train"])
plt.savefig(cfg.CACHE_OUTPUT + '/' + str(get_rank()) + '_val_acc_iter.png')