|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """eval"""
-
- from mindspore import Model
- from mindspore import context
- from mindspore import nn
- from mindspore.common import set_seed
-
- from src.args import args
- from src.tools.cell import cast_amp
- from src.tools.criterion import get_criterion, NetWithLoss
- from src.tools.get_misc import get_dataset, set_device, get_model, pretrained, get_train_one_step
- from src.tools.optimizer import get_optimizer
- import moxing as mox
-
- import os
-
- set_seed(args.seed)
-
### Copy a single dataset folder from OBS to the inference image ###
def ObsToEnv(obs_data_url, data_dir):
    """Copy the OBS folder ``obs_data_url`` into the local ``data_dir``.

    Uses ``mox.file.copy_parallel`` (the folder-copy API). This is
    best-effort: any exception raised while copying or reporting success is
    caught and printed instead of propagated. The function always returns
    ``None``, so callers that need the data must check ``data_dir`` themselves.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    except Exception as exc:  # deliberately broad: report and keep going
        failure_prefix = 'moxing download {} to {} failed: '.format(obs_data_url, data_dir)
        print(failure_prefix + str(exc))
### Copy the ckpt file from OBS to the inference image ###
### Folders are copied with mox.file.copy_parallel; a single file, as here,
### must be copied with mox.file.copy.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Copy the single OBS file ``obs_ckpt_url`` to the local path ``ckpt_url``.

    This is best-effort: any exception raised while copying or reporting
    success is caught and printed instead of propagated. The function always
    returns ``None``, so callers must check that ``ckpt_url`` exists if they
    depend on it.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    except Exception as exc:  # deliberately broad: report and keep going
        failure_prefix = 'moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url)
        print(failure_prefix + str(exc))
### Copy the output results back to OBS ###
def EnvToObs(train_dir, obs_train_url):
    """Upload the local folder ``train_dir`` to the OBS location ``obs_train_url``.

    Uses ``mox.file.copy_parallel`` (the folder-copy API). This is
    best-effort: any exception raised while uploading or reporting success is
    caught and printed instead of propagated, and the function returns
    ``None`` either way.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    except Exception as exc:  # deliberately broad: report and keep going
        failure_prefix = 'moxing upload {} to {} failed: '.format(train_dir, obs_train_url)
        print(failure_prefix + str(exc))
-
-
def main():
    """Evaluate a checkpoint on the validation split and upload the metrics.

    Steps:
      1. Download the checkpoint ``args.ckpt_url`` from OBS to
         ``/cache/checkpoint.ckpt``.
      2. Configure the MindSpore context (graph/pynative mode, device target,
         auto mixed precision on Ascend).
      3. Build the network and criterion, load the checkpoint through
         ``pretrained``, and run ``Model.eval`` on ``data.val_dataset``,
         reporting loss, top-1 and top-5 accuracy.
      4. Append the metrics to ``/cache/result/acc.log`` and copy
         ``/cache/result`` to ``args.result_url`` on OBS.
    """
    result_dir = '/cache/result'
    ckpt_url = '/cache/checkpoint.ckpt'

    # exist_ok avoids the separate exists-check-then-create step.
    os.makedirs(result_dir, exist_ok=True)

    # Copy the ckpt file from OBS to the inference image.
    ObsUrlToEnv(args.ckpt_url, ckpt_url)
    if not os.path.isfile(ckpt_url):
        # ObsUrlToEnv only prints copy failures; surface the consequence here
        # so a missing checkpoint is not silently evaluated past.
        print(f"=> warning: checkpoint {ckpt_url} not found after download from {args.ckpt_url}")

    # args.graph_mode: 0 -> graph mode, 1 -> PyNative mode.
    mode = {
        0: context.GRAPH_MODE,
        1: context.PYNATIVE_MODE,
    }
    context.set_context(mode=mode[args.graph_mode], device_target=args.device_target)
    if args.device_target == "Ascend":
        context.set_context(enable_auto_mixed_precision=True)
    set_device(args)

    # Build the network and loss.
    net = get_model(args)
    cast_amp(net)
    criterion = get_criterion(args)
    net_with_loss = NetWithLoss(net, criterion)

    # Load the downloaded checkpoint weights into the network for evaluation.
    pretrained(args, net, ckpt_url)

    data = get_dataset(args, training=False)
    batch_num = data.val_dataset.get_dataset_size()
    # The optimizer and one-step train wrapper are only used to construct
    # Model; evaluation itself runs through eval_network below.
    optimizer = get_optimizer(args, net, batch_num)
    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)

    # Third argument (add_cast_fp32): cast outputs to fp32 before the loss
    # when a reduced-precision AMP level is configured.
    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
    # Positions of (loss, logits, label) in eval_network's output tuple.
    eval_indexes = [0, 1, 2]
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net_with_loss, metrics=eval_metrics,
                  eval_network=eval_network,
                  eval_indexes=eval_indexes)
    print("=> begin eval")
    results = model.eval(data.val_dataset)
    print(f"=> eval results:{results}")
    print("=> eval success")

    acc_file_path = os.path.join(result_dir, 'acc.log')
    with open(acc_file_path, 'a+', encoding='utf-8') as acc_log:
        # results is the metrics dict from Model.eval; file.write() only
        # accepts str, so format it (passing the dict raised TypeError).
        # Trailing newline keeps appended runs on separate lines.
        acc_log.write(f"The acc of imagenet: {results}\n")

    # Copy the results from the local running environment back to OBS so they
    # can be downloaded from the corresponding inference task on the Qizhi platform.
    EnvToObs(result_dir, args.result_url)
-
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
|