|
- # Copyright 2022 Huawei Technologies Co., Ltd
- # Copyright 2022 Aerospace Information Research Institute,
- # Chinese Academy of Sciences.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """eval of swintransformerv2"""
- import os
- import argparse
- from mindspore.communication.management import init, get_rank
- from easydict import EasyDict as edict
- import logging
- from mindspore.train.model import Model
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore import nn
- from swintransformerv2.datasets.build_dataset import build_dataset
- from swintransformerv2.models import build_eval_engine
- from swintransformerv2.parallel_config import build_parallel_config
- from swintransformerv2.models.swin_transformer_v2 import build_swin_v2
- from swintransformerv2.models import build_eval_engine
- from swintransformerv2.parallel_config import build_parallel_config
- from config import Config, ActionDict
- import logging
-
def str2bool(b):
    """Convert a textual boolean flag to a bool (argparse ``type=`` helper).

    Args:
        b (str): case-insensitive boolean string, "true" or "false".

    Returns:
        bool: the parsed value.

    Raises:
        argparse.ArgumentTypeError: if ``b`` is not a recognized boolean
            string. ArgumentTypeError lets argparse report a clean usage
            error instead of an unhandled traceback; it subclasses
            Exception, so any caller catching the old generic Exception
            still works.
    """
    value = b.lower()
    if value == "true":
        return True
    if value == "false":
        return False
    raise argparse.ArgumentTypeError("Invalid Bool Value")
-
def main(args):
    """Run single-pass evaluation of SwinTransformerV2 and log top-1 accuracy.

    Args:
        args: Config object carrying dataset, parallel-context, model and
            checkpoint settings. A ``logger`` attribute is attached here and
            used for all progress reporting.
    """
    args.logger = logging.getLogger()

    # Build the evaluation dataset (no pretraining augmentations, eval split).
    args.logger.info(".........Build Eval Dataset..........")
    eval_dataset = build_dataset(args, is_pretrain=False, is_train=False)

    # Build the parallel/context configuration; build_parallel_config
    # populates args.parallel_config and args.moe_config in place.
    args.logger.info(".........Build context config..........")
    build_parallel_config(args)
    # Lazy %-style args: the message is only formatted if the record passes
    # the logger's level filter. Output text is identical to the original.
    args.logger.info("context config is:%s", args.parallel_config)
    args.logger.info("moe config is:%s", args.moe_config)

    # Build the network and the evaluation engine wrapping it.
    args.logger.info(".........Build Net..........")
    net = build_swin_v2(args)
    eval_engine = build_eval_engine(net, eval_dataset, args)

    # Restore weights when a checkpoint path is configured.
    resume_ckpt = args.train_config.resume_ckpt
    if resume_ckpt:
        args.logger.info(".........Load Task Checkpoint..........")
        # Fix: the original wrapped the path in a single-argument
        # os.path.join(), which is a no-op — pass the path directly.
        param_dict = load_checkpoint(resume_ckpt)
        load_param_into_net(net, param_dict)

    # Wrap the network in a Model and hand it to the eval engine.
    args.logger.info(".........Starting Init Eval Model..........")
    model = Model(net, metrics=eval_engine.metric, eval_network=eval_engine.eval_network)
    eval_engine.set_model(model)

    # Run evaluation and report top-1 accuracy.
    args.logger.info(".........Starting Eval Model..........")
    eval_engine.eval()
    output = eval_engine.get_result()
    last_metric = 'Top1 accuracy={:.6f}'.format(float(output))
    args.logger.info(last_metric)
-
-
-
if __name__ == "__main__":
    # Resolve paths relative to this script's directory so the run does not
    # depend on the caller's working directory.
    work_path = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser()

    # Cloud-platform style data/checkpoint/output URLs (ModelArts / Qizhi).
    parser.add_argument('--data_url',
                        help='path to training/inference dataset folder',
                        default= '/cache/data/')
    parser.add_argument('--ckpt_url',
                        help='model to save/load',
                        default= '/cache/checkpoint.ckpt')

    parser.add_argument('--train_url',
                        help='output folder to save/load',
                        default= '/cache/output/')

    parser.add_argument(
        '--device_target',
        type=str,
        default="Ascend",
        choices=['Ascend', 'CPU'],
        help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')

    parser.add_argument('--device_num', default=None, type=int, help='device num')

    # NOTE(review): the default below is the literal placeholder
    # "config path" — a real run must pass --config explicitly; confirm
    # against the launch scripts.
    parser.add_argument(
        '--config',
        default=os.path.join(work_path, "config path"),
        help='YAML config files')
    parser.add_argument('--device_id', default=None, type=int, help='device id')
    parser.add_argument('--seed', default=None, type=int, help='random seed')
    parser.add_argument('--batch_size', default=None, type=int, help='batch size')
    parser.add_argument('--use_parallel', default=None, type=str2bool, help='whether use parallel mode')
    parser.add_argument('--eval_path', default=None, type=str, help='checkpoint path for eval')

    args_ = parser.parse_args()
    # Load the YAML config; a relative --config is resolved against the
    # script directory (os.path.join leaves absolute paths untouched).
    config = Config(os.path.join(work_path, args_.config))

    # Fixed local cache locations used when running on ModelArts.
    data_dir = '/cache/data'
    result_dir = '/cache/result'
    ckpt_dir = '/cache/checkpoint_81.ckpt'



    if config.enable_modelarts:
        # moxing is only available inside the ModelArts environment, hence
        # the deferred import.
        import moxing as mox
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, exist_ok=True)
            print(f'successfully os.makedirs {data_dir}')
        if not os.path.exists(result_dir):
            os.makedirs(result_dir, exist_ok=True)
            print(f'successfully os.makedirs {result_dir}')

        # Stage the dataset and the checkpoint from OBS into local cache,
        # then point the config at the downloaded checkpoint.
        mox.file.copy_parallel(src_url= os.path.join(args_.data_url, "imagenet") , dst_url= data_dir)
        print("Successfully Download {} to {}".format(args_.data_url, data_dir))
        mox.file.copy(args_.ckpt_url, ckpt_dir)
        print("Successfully Download {} to {}".format(args_.ckpt_url,ckpt_dir))
        config.train_config.resume_ckpt = ckpt_dir


    # CLI flags override the YAML config only when explicitly provided.
    # Note: --eval_path overrides the ModelArts checkpoint staged above.
    if args_.device_id is not None:
        config.context.device_id = args_.device_id
    if args_.seed is not None:
        config.seed = args_.seed
    if args_.use_parallel is not None:
        config.use_parallel = args_.use_parallel
    if args_.eval_path is not None:
        config.train_config.resume_ckpt = args_.eval_path
    if args_.batch_size is not None:
        config.train_config.batch_size = args_.batch_size

    # A negative eval_offset means "auto": align the first eval with the
    # final epoch by taking epoch count modulo the eval interval.
    if config.finetune_dataset.eval_offset < 0:
        config.finetune_dataset.eval_offset = config.train_config.epoch % config.finetune_dataset.eval_interval

    config.aicc_config.obs_path = args_.data_url



    main(config)
|