|
- import os
- import argparse
- import moxing as mox
- import numpy as np
- import mindspore.nn as nn
- from mindspore import Tensor, export, load_checkpoint
- from mindspore.train.serialization import load
-
# Command-line interface of the converter. --model is joined onto the local
# data folder that --data_url is copied into; --n/--c/--h/--w give the shape
# of the float32 dummy input used when exporting to ONNX.
parser = argparse.ArgumentParser(
    description='Export a MindSpore MindIR model to ONNX')

# Required: without it the model path cannot be built, and the script would
# otherwise fail much later with an unrelated TypeError.
parser.add_argument('--model',
                    type=str,
                    required=True,
                    help='file name of the MindIR model inside the downloaded data folder'
                    )
parser.add_argument('--n',
                    type=int,
                    default=256,
                    help='batch size for input shape type'
                    )
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type'
                    )
# Note: '--h' is distinct from argparse's built-in '-h' help flag.
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type'
                    )
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type'
                    )
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder')

parser.add_argument('--train_url',
                    help='model folder to save/load')
-
-
# Local working directory used by the training job (hard-coded); the data
# and model sub-folders are created beneath it.
workroot = '/home/work/user-job-dir'
print('workroot:', workroot, sep='')
-
-
def onnx_output_path(out_dir, model_name):
    """Return the ONNX output path for *model_name* inside *out_dir*.

    The model name's last extension (e.g. ``.mindir``) is replaced by
    ``.onnx``; a name without an extension simply gets ``.onnx`` appended.
    Dots in directory components are ignored, unlike a plain ``rfind('.')``.
    """
    stem, _ext = os.path.splitext(model_name)
    return os.path.join(out_dir, stem + '.onnx')


def main():
    """Download a MindIR model, export it to ONNX, and upload the result.

    Steps:
      1. Parse the command-line arguments from the module-level ``parser``.
      2. Copy ``--data_url`` into ``<workroot>/data`` with moxing.
      3. Load ``<workroot>/data/<--model>`` and wrap it in ``nn.GraphCell``.
      4. Export it as ONNX into ``<workroot>/model`` using a random float32
         input of shape (n, c, h, w).
      5. Copy ``<workroot>/model`` to ``--train_url`` with moxing.

    Raises:
        Exception: whatever moxing raised if the download or upload fails.
            The error is printed first, then re-raised so the job neither
            continues without its input nor reports success without output.
    """
    args = parser.parse_args()
    print('args:')
    print(args)

    data_dir = os.path.join(workroot, 'data')    # local copy of the input files
    train_dir = os.path.join(workroot, 'model')  # local output folder, uploaded at the end
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # race between an exists() check and mkdir().
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(train_dir, exist_ok=True)

    obs_train_url = args.train_url
    obs_data_url = args.data_url

    # Copy the input data into the training environment.
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
    except Exception as e:
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
        # Loading the model below cannot succeed without the data; fail here
        # with the real cause instead of a confusing load error.
        raise
    print("Successfully Download {} to {}".format(obs_data_url, data_dir))

    model_file = os.path.join(data_dir, args.model)
    print(model_file)
    graph = load(model_file)
    net = nn.GraphCell(graph)
    print(net)

    out_file = onnx_output_path(train_dir, args.model)
    print(out_file)
    # Dummy input for the export; its values are random, presumably only its
    # shape/dtype matter for tracing the graph.
    input_np = np.random.uniform(0.0, 1.0, size=[args.n, args.c, args.h, args.w]).astype(np.float32)

    # Write the ONNX file into train_dir so it is included in the upload below.
    export(net, Tensor(input_np), file_name=out_file, file_format='ONNX')

    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
        # The exported model is the job's only output; surface the failure
        # instead of exiting successfully.
        raise
    print("Successfully Upload {} to {}".format(train_dir, obs_train_url))


if __name__ == "__main__":
    main()
-
|