|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """Eval"""
-
- import os
- import time
-
- import cv2
- import numpy as np
- from mindspore import DatasetHelper, load_checkpoint, context
- from mindspore.nn import Sigmoid
-
- from model_utils.config import base_config
- from src.dataset import create_dataset
- from src.egnet import build_model
-
-
def main(config):
    """
    Set up the MindSpore execution context, build the test dataset and evaluate it.

    :param config: configuration object; this function reads ``device_target``,
        ``test_batch_size``, ``num_thread``, ``test_mode``, ``sal_mode``,
        ``test_path`` and ``test_fold`` from it and forwards the whole object
        to ``evaluate``
    :return: None
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)

    # Keyword arguments for the dataset factory, gathered in the order they are passed.
    dataset_options = {
        "mode": "test",
        "num_thread": config.num_thread,
        "test_mode": config.test_mode,
        "sal_mode": config.sal_mode,
        "test_path": config.test_path,
        "test_fold": config.test_fold,
    }
    created = create_dataset(config.test_batch_size, **dataset_options)

    # create_dataset hands back a pair: the iterable dataset and the dataset object
    # that evaluate() later uses to look up image names.
    batched_ds, source_ds = created
    evaluate(batched_ds, config, source_ds)
-
-
class Metric:
    """
    Accumulator for saliency evaluation metrics.

    Each call to ``update`` adds one prediction / ground-truth pair to running
    sums of MAE, per-threshold precision and recall, and S-measure.
    ``print_result`` reports the averages over all pairs seen so far.
    """

    def __init__(self):
        self.epsilon = 1e-4  # added to numerator and denominator of precision / recall
        self.beta = 0.3  # weight used in the F-measure formula in print_result
        self.thresholds = 256  # number of binarisation thresholds (0 .. 255)
        self.mae = 0  # running sum of per-image MAE
        self.max_f = 0  # kept for interface compatibility; not updated in this class
        self.precision = np.zeros(self.thresholds)  # running sum per threshold
        self.recall = np.zeros(self.thresholds)  # running sum per threshold
        self.q = 0  # running sum of per-image S-measure
        self.cnt = 0  # number of pairs accumulated

    def update(self, pred, gt):
        """
        accumulate all metrics for one prediction / ground truth pair
        :param pred: predict image, same shape as gt; divided by 255 here before
            MAE and S-measure, so a 0-255 scale is assumed
        :param gt: ground truth image, same scale assumption as pred
        :return: None
        """
        assert pred.shape == gt.shape
        pred = pred.astype(np.float32)
        gt = gt.astype(np.float32)
        norm_pred = pred / 255.0
        norm_gt = gt / 255.0
        self.compute_mae(norm_pred, norm_gt)
        # Precision / recall work on the un-normalised arrays.
        self.compute_precision_and_recall(pred, gt)
        self.compute_s_measure(norm_pred, norm_gt)
        self.cnt += 1

    def print_result(self):
        """
        print max F-measure (with its precision and recall), MAE and S-measure,
        each averaged over the accumulated pairs
        :return: None
        """
        if self.cnt == 0:
            # Nothing accumulated yet: averaging would divide by zero.
            print("No samples have been evaluated; no metrics to report.")
            return
        f_measure = (1 + self.beta) * (self.precision * self.recall) / (self.beta * self.precision + self.recall)
        argmax = np.argmax(f_measure)
        print("Max F-measure:", f_measure[argmax] / self.cnt)
        print("Precision:    ", self.precision[argmax] / self.cnt)
        print("Recall:       ", self.recall[argmax] / self.cnt)
        print("MAE:          ", self.mae / self.cnt)
        print("S-measure:    ", self.q / self.cnt)

    def compute_precision_and_recall(self, pred, gt):
        """
        compute the precision and recall for pred
        :param pred: predict image
        :param gt: ground truth image
        :return: None
        """
        # NOTE(review): pred is compared against ``th`` while gt is compared against
        # ``th / self.thresholds``; with gt on a 0-255 scale nearly every non-zero
        # gt pixel counts as foreground at every threshold - confirm this is intended.
        for th in range(self.thresholds):
            pred_mask = pred > th
            gt_mask = gt > th / self.thresholds
            hit = np.count_nonzero(pred_mask & gt_mask)
            pred_total = np.count_nonzero(pred_mask)
            gt_total = np.count_nonzero(gt_mask)
            self.precision[th] += (hit + self.epsilon) / (pred_total + self.epsilon)
            self.recall[th] += (hit + self.epsilon) / (gt_total + self.epsilon)

    def compute_mae(self, pred, gt):
        """
        compute mean absolute error
        :param pred: predict image
        :param gt: ground truth image
        :return: None
        """
        self.mae += np.abs(pred - gt).mean()

    def compute_s_measure(self, pred, gt):
        """
        compute s measure score
        :param pred: predict image
        :param gt: ground truth image (not modified)
        :return: None
        """
        alpha = 0.5
        y = gt.mean()
        if y == 0:
            # No foreground at all: score is how dark the prediction is.
            x = pred.mean()
            q = 1.0 - x
        elif y == 1:
            # Everything is foreground: score is how bright the prediction is.
            x = pred.mean()
            q = x
        else:
            # Binarise a private copy so the caller's array is left untouched.
            gt = gt.copy()
            gt[gt >= 0.5] = 1
            gt[gt < 0.5] = 0
            q = alpha * self._s_object(pred, gt) + (1 - alpha) * self._s_region(pred, gt)
            if q < 0 or np.isnan(q):
                q = 0
        self.q += q

    def _s_object(self, pred, gt):
        """
        score of object
        :param pred: predict image
        :param gt: ground truth image
        :return: score
        """
        # Prediction restricted to gt foreground, and inverted prediction
        # restricted to gt background.
        fg = np.where(gt == 0, np.zeros_like(pred), pred)
        bg = np.where(gt == 1, np.zeros_like(pred), 1 - pred)
        o_fg = self._object(fg, gt)
        o_bg = self._object(bg, 1 - gt)
        u = gt.mean()
        q = u * o_fg + (1 - u) * o_bg
        return q

    @staticmethod
    def _object(pred, gt):
        """
        compute score of object
        :param pred:predict image
        :param gt:ground truth image
        :return:score
        """
        temp = pred[gt == 1]
        if temp.size == 0:
            # No pixel selected by the mask: nothing to score.
            return 0
        x = temp.mean()
        sigma_x = temp.std()
        score = 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)
        return score

    def _s_region(self, pred, gt):
        """
        compute score of region
        :param pred: predict image
        :param gt: ground truth image
        :return: score
        """
        # Split both images into four quadrants around the gt centroid and
        # combine the per-quadrant similarity with area-proportional weights.
        x, y = self._centroid(gt)
        gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divide_gt(gt, x, y)
        p1, p2, p3, p4 = self._divide_prediction(pred, x, y)
        q1 = self._ssim(p1, gt1)
        q2 = self._ssim(p2, gt2)
        q3 = self._ssim(p3, gt3)
        q4 = self._ssim(p4, gt4)
        q = w1 * q1 + w2 * q2 + w3 * q3 + w4 * q4
        return q

    @staticmethod
    def _divide_gt(gt, x, y):
        """
        divide ground truth image
        :param gt: ground truth image
        :param x: center.x
        :param y: center.y
        :return: divided ground truth image (left-top, right-top, left-bottom,
            right-bottom) and the four matching weights
        """
        h, w = gt.shape[-2:]
        area = h * w
        gt = gt.reshape(h, w)
        lt = gt[:y, :x]
        rt = gt[:y, x:w]
        lb = gt[y:h, :x]
        rb = gt[y:h, x:w]
        # np.float32(...) accepts both NumPy scalars and plain Python ints.
        x = np.float32(x)
        y = np.float32(y)
        w1 = x * y / area
        w2 = (w - x) * y / area
        w3 = x * (h - y) / area
        w4 = 1 - w1 - w2 - w3
        return lt, rt, lb, rb, w1, w2, w3, w4

    @staticmethod
    def _divide_prediction(pred, x, y):
        """
        divide predict image
        :param pred: predict image
        :param x: center.x
        :param y: center.y
        :return: divided predict image (left-top, right-top, left-bottom, right-bottom)
        """
        h, w = pred.shape[-2:]
        pred = pred.reshape(h, w)
        lt = pred[:y, :x]
        rt = pred[:y, x:w]
        lb = pred[y:h, :x]
        rb = pred[y:h, x:w]
        return lt, rt, lb, rb

    @staticmethod
    def _ssim(pred, gt):
        """
        structural similarity
        :param pred:predict image
        :param gt:ground truth image
        :return:similarity score
        """
        if pred.size == 0:
            # An empty quadrant (centroid on row / column 0) has weight 0 in
            # _s_region; return 0 so its mean does not turn the sum into NaN.
            return 0.0
        gt = gt.astype(np.float32)
        h, w = pred.shape[-2:]
        n = h * w
        x = pred.mean()
        y = gt.mean()
        # The tiny constant keeps the division defined when n == 1.
        sigma_x2 = ((pred - x) * (pred - x)).sum() / (n - 1 + 1e-20)
        sigma_y2 = ((gt - y) * (gt - y)).sum() / (n - 1 + 1e-20)
        sigma_xy = ((pred - x) * (gt - y)).sum() / (n - 1 + 1e-20)

        alpha = 4 * x * y * sigma_xy
        beta = (x * x + y * y) * (sigma_x2 + sigma_y2)

        if alpha != 0:
            q = alpha / (beta + 1e-20)
        elif alpha == 0 and beta == 0:
            q = 1.0
        else:
            q = 0
        return q

    @staticmethod
    def _centroid(gt):
        """
        compute center of ground truth image
        :param gt: ground truth image
        :return: center (x,y) as integer scalars usable as slice bounds
        """
        rows, cols = gt.shape[-2:]
        gt = gt.reshape(rows, cols)
        total = gt.sum()
        if total == 0:
            # Empty mask: fall back to the image centre. Scalars (not 1x1
            # arrays) are returned so they can be used directly in slices.
            return np.int64(round(cols / 2)), np.int64(round(rows / 2))

        i = np.arange(0, cols).astype(np.float32)
        j = np.arange(0, rows).astype(np.float32)
        # Mass-weighted mean column / row index, rounded to the nearest integer.
        x = np.round((gt.sum(axis=0) * i).sum() / total)
        y = np.round((gt.sum(axis=1) * j).sum() / total)
        return x.astype(np.int64), y.astype(np.int64)
-
-
def evaluate(test_ds, config, dataset):
    """
    Build the network, run it over the test set, save saliency maps and print metrics.

    :param test_ds: dataset iterated through ``DatasetHelper``; each batch is indexed
        as (image, label, name index)
    :param config: configuration object; reads ``base_model``, ``model`` (checkpoint
        path), ``test_save_name`` and ``test_fold``
    :param dataset: object whose ``image_list`` maps a name index to an image file name
    :return: None
    """
    model = build_model(config.base_model)
    # Load pretrained model
    load_checkpoint(config.model, net=model)
    print(f"Loading pre-trained model from {config.model}...")
    sigmoid = Sigmoid()
    # test phase
    test_save_name = config.test_save_name + config.base_model
    test_fold = config.test_fold
    # exist_ok=True already covers the "directory exists" case, so no prior check.
    os.makedirs(os.path.join(test_fold, test_save_name), exist_ok=True)
    dataset_helper = DatasetHelper(test_ds, epoch_num=1, dataset_sink_mode=False)
    time_t = 0.0

    metric = Metric()
    has_label = False
    for i, data_batch in enumerate(dataset_helper):
        sal_image, sal_label, name_index = data_batch[0], data_batch[1], data_batch[2]
        name = dataset.image_list[name_index[0].asnumpy().astype(np.int32)]
        # Strip the real extension instead of assuming it is three characters long.
        stem, _ = os.path.splitext(name)
        save_file = os.path.join(test_fold, test_save_name, stem + "_sal.png")
        # The image name may contain sub-directories, so create them per image.
        os.makedirs(os.path.dirname(save_file), exist_ok=True)

        # Only the forward pass is timed.
        time_start = time.time()
        _, _, up_sal_final = model(sal_image)
        time_end = time.time()
        time_t += time_end - time_start

        # Convert to an 8-bit image once, so the same array is used for the
        # metrics and for the file written to disk.
        pred = sigmoid(up_sal_final[-1]).asnumpy().squeeze() * 255
        pred = np.round(pred).astype(np.uint8)

        if sal_label is not None:
            has_label = True
            sal_label = sal_label.asnumpy().squeeze() * 255
            metric.update(pred, sal_label)
        cv2.imwrite(save_file, pred)
        print(f"process image index {i} done")

    print(f"--- {time_t} seconds ---")
    if has_label:
        metric.print_result()
-
-
# Script entry point: run the evaluation with the imported base_config.
if __name__ == "__main__":
    main(base_config)
|