|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
Evaluate VoVNet-19-slim-eSE (src.VoVNet._VoVNet19_slim_eSE) on the ImageNet2012 validation set.
- python eval.py
- """
- import os
- import argparse
- import moxing as mox
- from mindspore import context
- from mindspore.common import set_seed
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from src.cross_entropy_smooth import CrossEntropySmooth
- import mindspore.nn as nn
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.train import Model
- from mindspore.nn.metrics import Accuracy
- from mindspore import Tensor
- import numpy as np
- from glob import glob
-
- from src.VoVNet import _VoVNet19_slim_eSE as vovnet
- from src.config import config
- from src.dataset import create_dataset
-
def _mox_transfer(copy_name, src, dst, verb):
    """Copy ``src`` to ``dst`` with ``mox.file.<copy_name>`` and report the outcome.

    Shared implementation of the three OBS transfer helpers below, which
    previously duplicated this try/print logic.

    Args:
        copy_name: Name of the ``mox.file`` function to call: ``'copy_parallel'``
            for folders, ``'copy'`` for a single file.
        src: Source path/URL.
        dst: Destination path/URL.
        verb: ``'Download'`` or ``'Upload'``; only used in the printed messages.

    Returns:
        True if the copy succeeded, False if it raised. Failures are printed
        rather than raised (best-effort), so a transfer problem does not abort
        the job by itself.
    """
    try:
        # The lookup happens inside the try so that a missing or broken moxing
        # installation is reported like any other transfer failure.
        getattr(mox.file, copy_name)(src, dst)
    except Exception as e:  # best-effort: moxing can raise many error types
        print('moxing {} {} to {} failed: '.format(verb.lower(), src, dst) + str(e))
        return False
    print("Successfully {} {} to {}".format(verb, src, dst))
    return True


### Copy single dataset from obs to inference image ###
def ObsToEnv(obs_data_url, data_dir):
    """Copy the dataset folder ``obs_data_url`` from OBS into ``data_dir``.

    Returns True on success, False on failure (the error is printed, not raised).
    """
    return _mox_transfer('copy_parallel', obs_data_url, data_dir, 'Download')


### Copy ckpt file from obs to inference image ###
### mox.file.copy_parallel copies folders; a single file such as a checkpoint
### must be copied with mox.file.copy instead.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy the single file ``obs_ckpt_url`` from OBS to the local path ``ckpt_url``.

    Returns True on success, False on failure (the error is printed, not raised).
    """
    return _mox_transfer('copy', obs_ckpt_url, ckpt_url, 'Download')


### Copy the output result to obs ###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local folder ``train_dir`` to ``obs_train_url`` on OBS.

    Returns True on success, False on failure (the error is printed, not raised).
    """
    return _mox_transfer('copy_parallel', train_dir, obs_train_url, 'Upload')
-
- ### --data_url,--ckpt_url,--result_url,--device_target,These 4 parameters must be defined first in a inference task,
- ### otherwise an error will be reported.
- ### There is no need to add these parameters to the running parameters of the Qizhi platform,
- ### because they are predefined in the background, you only need to define them in your code.
- parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
- parser.add_argument('--data_url',
- type=str,
- default= '/cache/data/',
- help='path where the dataset is saved')
- parser.add_argument('--ckpt_url',
- help='model to save/load',
- default= '/cache/checkpoint.ckpt')
- parser.add_argument('--result_url',
- help='result folder to save/load',
- default= '/cache/result/')
- parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
- help='device where the code will be implemented (default: Ascend)')
-
set_seed(1)


def main():
    """Evaluate a VoVNet-19-slim-eSE checkpoint on the ImageNet2012 validation set.

    Steps:
      1. copy the dataset (``--data_url``) and checkpoint (``--ckpt_url``) from OBS
         into fixed local ``/cache`` paths;
      2. build the network, load the checkpoint and switch it to inference mode;
      3. run ``Model.eval`` with top-1 / top-5 accuracy metrics;
      4. print the metrics, save them under the local result folder and upload
         that folder to ``--result_url``.
    """
    # parse_known_args tolerates extra arguments the platform may inject; they are unused.
    args, _ = parser.parse_known_args()

    # Fixed local paths inside the inference image; the OBS sources come from args.
    data_dir = '/cache/data'
    result_dir = '/cache/result'
    local_ckpt = '/cache/checkpoint.ckpt'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    # Copy the dataset (a folder) and the checkpoint (a single file) from OBS.
    ObsToEnv(args.data_url, data_dir)
    ObsUrlToEnv(args.ckpt_url, local_ckpt)

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # Build the network and load the trained weights for inference.
    net = vovnet(num_class=config.class_num)
    param_dict = load_checkpoint(local_ckpt)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # The evaluation images are expected under <data_dir>/imagenet/val after the copy.
    data_path = os.path.join(data_dir, 'imagenet/val')
    dataset = create_dataset(dataset_path=data_path, do_train=False, batch_size=config.batch_size)

    # Use an unsmoothed cross-entropy unless label smoothing is enabled in config.
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropySmooth(sparse=True, reduction='mean',
                              smooth_factor=config.label_smooth_factor,
                              num_classes=config.class_num)

    model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
    result = model.eval(dataset)
    print("result:", result)

    # Persist the metrics and push them back to OBS. Previously result_dir was
    # created but never written and --result_url was parsed but never used.
    with open(os.path.join(result_dir, 'result.txt'), 'w', encoding='utf-8') as f:
        f.write('result: {}\n'.format(result))
    EnvToObs(result_dir, args.result_url)


if __name__ == "__main__":
    main()
|