|
- # Copyright 2021 Pengcheng Laboratory
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import argparse
-
- import suwen.utils as su
- from suwen.algorithm.nets.resnet import resnet50, resnet101
- from suwen.engine import Engine
- from suwen.losses import SoftmaxCrossEntropyWithLogits
- from suwen.losses.cross_entropy_with_logits import CrossEntropySmooth
-
# Command-line options for evaluating a ResNet classifier.
parser = argparse.ArgumentParser(description = 'Image classification')
parser.add_argument('--net', type = str, default = "resnet50", choices = ['resnet50', 'resnet101'],
                    help = 'Resnet Model, either resnet50 or resnet101.')
# Restrict --dataset up front (like --net / --device_target) so a typo fails during argument parsing.
parser.add_argument('--dataset', type = str, default = 'covid', choices = ['cifar10', 'imagenet2012', 'covid'],
                    help = 'Dataset, one of cifar10, imagenet2012 or covid.')
parser.add_argument('--device_id', type = int, default = 0, help = 'Device ID.')

parser.add_argument('--ckpt_path', type = str, default = None, help = 'Checkpoint file path')
parser.add_argument('--data_path', type = str, default = "./", help = 'Dataset path')
parser.add_argument('--device_target', type = str, default = 'Ascend', choices = ("Ascend", "GPU", "CPU"),
                    help = "Device target, support Ascend, GPU and CPU.")
parser.add_argument('--class_num', type = int, default = 2, help = 'Number of classes.')
parser.add_argument('--batch_size', type = int, default = 32, help = 'Batch size for evaluation.')
# Label-smoothing options read by the imagenet2012 loss setup in the __main__ section.
# Declaring them here prevents an AttributeError when --dataset imagenet2012 is used.
parser.add_argument('--use_label_smooth', action = 'store_true',
                    help = 'Apply label smoothing to the imagenet2012 evaluation loss.')
parser.add_argument('--label_smooth_factor', type = float, default = 0.1,
                    help = 'Label smoothing factor; only used together with --use_label_smooth.')

args_opt = parser.parse_args()
-
# Select the dataset factory for --dataset and bind it to the common name
# `create_dataset`. covid reuses the imagenet2012 loader (create_dataset2).
if args_opt.dataset == "cifar10":
    from dataset import create_dataset1 as create_dataset
elif args_opt.dataset in ("imagenet2012", "covid"):
    from dataset import create_dataset2 as create_dataset
else:
    # KeyError is kept for compatibility; the message now lists every supported dataset.
    raise KeyError("dataset should be in [cifar10, imagenet2012, covid], got '{}'.".format(args_opt.dataset))
-
if __name__ == '__main__':
    # Graph-mode execution context on the requested device; seed fixed for reproducibility.
    su.initial_context(su.GRAPH_MODE, device_id = args_opt.device_id, device_target = args_opt.device_target,
                       save_graphs = False, seed = 1)

    # do_train=False requests the evaluation dataset from the selected factory.
    dataset = create_dataset(dataset_path = args_opt.data_path, do_train = False, batch_size = args_opt.batch_size,
                             target = args_opt.device_target)

    # --net is any other value than resnet50 only when resnet101 was chosen (argparse choices).
    if args_opt.net == 'resnet50':
        net = resnet50(class_num = args_opt.class_num)
    else:
        net = resnet101(class_num = args_opt.class_num)

    if args_opt.dataset == "imagenet2012":
        # Read the label-smoothing options defensively: if the parser does not declare
        # them, plain attribute access would raise AttributeError here.
        use_label_smooth = getattr(args_opt, 'use_label_smooth', False)
        smooth_factor = getattr(args_opt, 'label_smooth_factor', 0.0) if use_label_smooth else 0.0
        loss = CrossEntropySmooth(sparse = True, reduction = 'mean',
                                  smooth_factor = smooth_factor, num_classes = args_opt.class_num)
    else:
        loss = SoftmaxCrossEntropyWithLogits(sparse = True, reduction = 'mean')

    engine = Engine(network = net,
                    loss_fn = loss,
                    metrics = {'top_1_accuracy'})
    # Put the network in inference mode before evaluating the checkpoint.
    net.set_train(False)
    engine.eval(dataset, load_ckpt_path = args_opt.ckpt_path)
|