|
- import os
- import time
- import moxing as mox
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import Model, context
- from mindspore.communication import init
- from mindspore.context import ParallelMode
-
- from mindcv.models import create_model
- from mindcv.data import create_dataset, create_transforms, create_loader
- from mindcv.loss import create_loss
- from config import parse_args
- from mindcv.utils.utils import check_batch_size
-
### Download a whole dataset directory from OBS into the inference image ###
def ObsToEnv(obs_data_url, data_dir):
    """Best-effort recursive copy of *obs_data_url* to local *data_dir*.

    Failures are reported on stdout and swallowed so the job can continue.
    """
    try:
        mox.file.copy_parallel(obs_data_url, data_dir)
    except Exception as e:  # deliberate best-effort: report, never raise
        print('moxing download {} to {} failed: '.format(obs_data_url, data_dir) + str(e))
    else:
        print("Successfully Download {} to {}".format(obs_data_url, data_dir))
    return
### Download a single ckpt file from OBS into the inference image ###
### mox.file.copy_parallel is for folders; mox.file.copy is for a single
### file, which is what a checkpoint is — hence copy() here.
def ObsUrlToEnv(obs_ckpt_url, ckpt_url):
    """Best-effort copy of the single file *obs_ckpt_url* to *ckpt_url*.

    Failures are reported on stdout and swallowed so the job can continue.
    """
    try:
        mox.file.copy(obs_ckpt_url, ckpt_url)
    except Exception as e:  # deliberate best-effort: report, never raise
        print('moxing download {} to {} failed: '.format(obs_ckpt_url, ckpt_url) + str(e))
    else:
        print("Successfully Download {} to {}".format(obs_ckpt_url, ckpt_url))
    return
### Upload the local result directory back to OBS ###
def EnvToObs(train_dir, obs_train_url):
    """Best-effort recursive upload of *train_dir* to *obs_train_url*.

    Failures are reported on stdout and swallowed so the job can continue.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
    except Exception as e:  # deliberate best-effort: report, never raise
        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
    else:
        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
    return
-
### --data_url, --ckpt_url, --result_url, --device_target: these 4 parameters must be defined first in an inference task,
### otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background; you only need to define them in your code.
-
def validate(args):
    """Evaluate a (pretrained) model on the validation split.

    Builds the eval dataset, transforms and loader from ``args``, creates the
    network and loss, then runs ``Model.eval`` and prints the resulting
    metrics (top-1 accuracy, top-5 when there are >= 5 classes, and loss).

    Args:
        args: parsed command-line namespace from ``parse_args()``; must carry
            dataset, model, loss and loader settings (see usages below).
    """
    ms.set_context(mode=args.mode)

    # create the evaluation dataset
    dataset_eval = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split=args.val_split,
        num_parallel_workers=args.num_parallel_workers,
        download=args.dataset_download)

    # create the (non-training) preprocessing pipeline
    transform_list = create_transforms(
        dataset_name=args.dataset,
        is_training=False,
        image_resize=args.image_resize,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
        mean=args.mean,
        std=args.std
    )

    # infer the number of classes from the dataset unless given explicitly
    # (fixed: compare to None with `is`, not `==` — PEP 8 / E711)
    num_classes = dataset_eval.num_classes() if args.num_classes is None else args.num_classes

    # clamp/validate the batch size against the dataset size
    batch_size = check_batch_size(dataset_eval.get_dataset_size(), args.batch_size)

    # build the data loader (no augmentation mixing at eval time beyond args)
    loader_eval = create_loader(
        dataset=dataset_eval,
        batch_size=batch_size,
        drop_remainder=False,
        is_training=False,
        mixup=args.mixup,
        cutmix=args.cutmix,
        transform=transform_list,
        num_parallel_workers=args.num_parallel_workers,
    )

    # create the network, optionally loading pretrained/checkpoint weights
    network = create_model(model_name=args.model,
                           num_classes=num_classes,
                           drop_rate=args.drop_rate,
                           drop_path_rate=args.drop_path_rate,
                           pretrained=args.pretrained,
                           checkpoint_path=args.ckpt_path)
    network.set_train(False)

    # create the loss used to report the eval 'loss' metric
    loss = create_loss(name=args.loss,
                       reduction=args.reduction,
                       label_smoothing=args.label_smoothing,
                       aux_factor=args.aux_factor)

    # Define eval metrics; top-5 accuracy only makes sense with >= 5 classes.
    if num_classes >= 5:
        eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy(),
                        'Top_5_Accuracy': nn.Top5CategoricalAccuracy(),
                        'loss': nn.metrics.Loss()
                        }
    else:
        eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy(),
                        'loss': nn.metrics.Loss()}

    # wrap network + loss + metrics for evaluation
    model = Model(network, loss_fn=loss, metrics=eval_metrics)

    # run validation and print the metric dict
    result = model.eval(loader_eval)
    print(result)
-
-
if __name__ == '__main__':
    args = parse_args()
    # Initialize the data, result and checkpoint paths inside the inference image.
    data_dir = '/cache/data'
    result_dir = '/cache/result'
    ckpt_url = '/cache/checkpoint.ckpt'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    # Copy the dataset from OBS to the inference image.
    # NOTE(review): the dataset appears to be downloaded twice — once here
    # into /cache/data (unused afterwards) and once below into /cache/dataset
    # via sync_data, which is what args.data_dir ends up pointing at.
    # Confirm whether this first copy can be dropped.
    ObsToEnv(args.data_url, data_dir)
    data_url = args.data_url
    local_data_path = '/cache/dataset'
    os.makedirs(local_data_path, exist_ok=True)
    # Imported here (not at top of file) — presumably only available in the
    # platform runtime image; confirm before hoisting.
    from moxing_adapter import sync_data
    sync_data(data_url, local_data_path, threads=256)
    print(f"local_data_path:{os.listdir(local_data_path)}")
    # If the data unpacked into an "imagenet" subfolder, point at that subfolder.
    if "imagenet" in os.listdir(local_data_path):
        local_data_path = os.path.join(local_data_path, "imagenet")
    args.data_dir = local_data_path

    # Copy the checkpoint file from OBS to the inference image.
    ObsUrlToEnv(args.ckpt_url, ckpt_url)

    validate(args)

    # Copy result data from the local environment back to OBS so it can be
    # downloaded from the corresponding Qizhi-platform inference task.
    EnvToObs(result_dir, args.result_url)
|