|
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore.ops import operations as ops
- from mindspore.nn.loss.loss import LossBase
- from xbm.models import beit_base_patch16_224
- from xbm.datasets import create_dataset
- from mindspore.train.model import _transfer_tensor_to_tuple ,RunContext
- import time
- import os
- import pdb
-
class NetWithLoss(nn.Cell):
    """Bundle a network and a loss function into a single cell.

    Calling the cell runs ``model`` on the input data and feeds the
    prediction together with the label into ``criterion``. As noted by the
    original author, only classification networks are supported.

    Args:
        model: Network invoked as ``model(data)`` to produce predictions.
        criterion: Loss invoked as ``criterion(predict, label)``.
    """

    def __init__(self, model, criterion):
        super().__init__()
        self.criterion = criterion
        self.model = model

    def construct(self, data, label):
        """Return the loss of ``model(data)`` measured against ``label``."""
        return self.criterion(self.model(data), label)
-
class SoftTargetCrossEntropy(LossBase):
    """Cross entropy against soft targets, intended for MixUp augmentation.

    The loss is ``sum(-label * log_softmax(logit))`` over the last axis,
    followed by a mean over everything that remains (``ReduceMean`` is called
    without an explicit axis).
    """

    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()
        # Build every primitive once here instead of re-creating Cast on each
        # construct() call; this matches how the other operators are handled.
        self.cast_ops = ops.Cast()
        self.mean_ops = ops.ReduceMean(keep_dims=False)
        self.sum_ops = ops.ReduceSum(keep_dims=False)
        self.log_softmax = ops.LogSoftmax()

    def construct(self, logit, label):
        """Compute the soft-target cross entropy.

        Args:
            logit: Raw network outputs; cast to float32 before use.
            label: Soft targets multiplied element-wise with the log-softmax
                of ``logit``; cast to float32 before use.
                NOTE(review): presumably the same shape as ``logit`` - confirm
                against the caller.

        Returns:
            The reduced loss produced by ``ReduceMean``.
        """
        # Both inputs are promoted to float32 before the loss is computed.
        logit = self.cast_ops(logit, ms.float32)
        label = self.cast_ops(label, ms.float32)
        loss = self.sum_ops(-label * self.log_softmax(logit), -1)
        return self.mean_ops(loss)
-
-
def main(
    checkpoint_path="checkpoint/beitv2_base_patch16_224_pt1k_ft21kto1k.ckpt",
    dataset_path="/data/ILSVRC2012/val",
    batch_size=128,
    device_target="Ascend",
):
    """Evaluate the BEiT base (patch16, 224) network and print its accuracy.

    Args:
        checkpoint_path: Checkpoint file loaded into the network.
        dataset_path: Directory handed to ``create_dataset``.
        batch_size: Batch size handed to ``create_dataset``.
        device_target: Device name used for both the MindSpore context and
            ``create_dataset`` so the two cannot drift apart.

    Returns:
        The metrics result returned by ``Model.eval``.
    """
    # The device index comes from the DEVICE_ID environment variable (default 0).
    device_id = int(os.getenv('DEVICE_ID', '0'))
    ms.context.set_context(mode=ms.context.GRAPH_MODE, device_target=device_target, device_id=device_id)

    net = beit_base_patch16_224()
    params_dict = ms.load_checkpoint(checkpoint_path)
    ms.load_param_into_net(net, params_dict)
    net.set_train(False)

    # The sparse cross entropy is used here instead of SoftTargetCrossEntropy.
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # NOTE(review): do_train=True is passed although dataset_path points at a
    # validation directory and the dataset is only used for model.eval().
    # create_dataset is not visible here; confirm whether False is intended.
    dataset = create_dataset(
        dataset_path=dataset_path,
        do_train=True,
        image_height=224,
        image_width=224,
        device_target=device_target,
        batch_size=batch_size,
        run_distribute=False,
    )

    model = ms.Model(net, loss_fn=criterion, metrics={'top_1_accuracy', 'top_5_accuracy'})
    res = model.eval(dataset)
    print("result:", res)
    return res


if __name__ == "__main__":
    main()
-
-
|