|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import argparse
- import os
- import time
- import glob
- import numpy as np
- import PIL.Image as Image
- from tabulate import tabulate
-
## Params
# Command-line options, parsed once at import time into the module-level
# ``args`` object that the evaluation code reads.
parser = argparse.ArgumentParser()
# Both input locations are mandatory: without them the path joins in the
# evaluation loop fail with an unhelpful TypeError on ``None``.
parser.add_argument('--label_path', type=str, required=True,
                    help='directory of dataset label')
parser.add_argument('--output_path', type=str, required=True,
                    help='path of the predict files that generated by the model')
parser.add_argument('--image_height', default=768, type=int, help='image_height')
parser.add_argument('--image_width', default=768, type=int, help='image_width')
parser.add_argument('--save_mask', default=0, type=int, help='0 for False, 1 for True')
parser.add_argument('--mask_result_path', default='./mask_result', type=str,
                    help='the folder to save the semantic mask images')

args = parser.parse_args()
-
# RGB colour used for each class id when writing colour masks; entry i is
# the colour for classes[i] (both sequences share the same order).
_CLASS_COLOURS = (
    (128, 64, 128),   # road
    (244, 35, 232),   # sidewalk
    (70, 70, 70),     # building
    (102, 102, 156),  # wall
    (190, 153, 153),  # fence
    (153, 153, 153),  # pole
    (250, 170, 30),   # traffic light
    (220, 220, 0),    # traffic sign
    (107, 142, 35),   # vegetation
    (152, 251, 152),  # terrain
    (0, 130, 180),    # sky
    (220, 20, 60),    # person
    (255, 0, 0),      # rider
    (0, 0, 142),      # car
    (0, 0, 70),       # truck
    (0, 60, 100),     # bus
    (0, 80, 100),     # train
    (0, 0, 230),      # motorcycle
    (119, 11, 32),    # bicycle
)

# Flat [r0, g0, b0, r1, g1, b1, ...] list, the layout PIL's putpalette takes.
cityspallete = [channel for colour in _CLASS_COLOURS for channel in colour]

# Class names indexed by class id.
classes = (
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle',
)
-
class SegmentationMetric():
    """Accumulate pixel accuracy (pixAcc) and mean IoU (mIoU) over batches.

    Running totals are kept across calls to ``update``; ``get`` turns them
    into scores and ``reset`` clears them.
    """

    # Added to every denominator so empty totals yield 0 instead of a
    # division by zero. Numerically equal to np.spacing(1.0).
    _EPS = 2.220446049250313e-16

    def __init__(self, nclass):
        """Create an empty accumulator for ``nclass`` classes."""
        super(SegmentationMetric, self).__init__()
        self.nclass = nclass
        self.reset()

    def update(self, preds, labels):
        """Add the statistics of one or more prediction/label pairs.

        Parameters
        ----------
        preds : numpy.ndarray or list of numpy.ndarray
            Prediction scores laid out (N, C, H, W); the predicted class is
            the argmax over axis 1. A single ndarray is treated as one batch.
        labels : numpy.ndarray or list of numpy.ndarray
            Ground-truth class ids matching ``preds`` (N, H, W), or (H, W)
            when N == 1. Negative ids are ignored.

        Raises
        ------
        ValueError
            If ``preds`` and ``labels`` hold a different number of arrays.
        """
        # Normalise both arguments to sequences of arrays so the documented
        # "list of arrays" form works as well as a single array.
        if isinstance(preds, np.ndarray):
            preds = [preds]
        if isinstance(labels, np.ndarray):
            labels = [labels]
        preds = list(preds)
        labels = list(labels)
        if len(preds) != len(labels):
            raise ValueError(
                'preds and labels must have the same length, got {} and {}'.format(
                    len(preds), len(labels)))

        for pred, label in zip(preds, labels):
            correct, labeled = batch_pix_accuracy(pred, label)
            inter, union = batch_intersection_union(pred, label, self.nclass)
            self.total_correct += correct
            self.total_label += labeled
            self.total_inter += inter
            self.total_union += union

    def get(self, return_category_iou=False):
        """Return the scores for everything accumulated so far.

        Parameters
        ----------
        return_category_iou : bool
            When true, also return the per-class IoU array.

        Returns
        -------
        tuple
            ``(pixAcc, mIoU)`` or ``(pixAcc, mIoU, IoU)``. ``mIoU`` is the
            plain mean over all ``nclass`` classes, so a class that never
            appears (union 0) contributes an IoU of 0.
        """
        pixAcc = 1.0 * self.total_correct / (self._EPS + self.total_label)
        IoU = 1.0 * self.total_inter / (self._EPS + self.total_union)
        mIoU = IoU.mean().item()
        if return_category_iou:
            return pixAcc, mIoU, IoU
        return pixAcc, mIoU

    def reset(self):
        """Reset all running totals to zero."""
        self.total_inter = np.zeros(self.nclass)
        self.total_union = np.zeros(self.nclass)
        self.total_correct = 0
        self.total_label = 0
-
def batch_pix_accuracy(output, target):
    """Count correctly classified and labelled pixels for one batch (pixAcc).

    Args:
        output: numpy array of prediction scores laid out (N, C, H, W); the
            predicted class of each pixel is the argmax over axis 1.
        target: numpy array of ground-truth class ids laid out (N, H, W)
            (or broadcast-compatible). Negative ids, e.g. -1, are ignored.

    Returns:
        tuple: ``(pixel_correct, pixel_labeled)`` pixel counts.
    """
    # Take the argmax of the raw scores. Casting float scores to an integer
    # type first would truncate them (e.g. probabilities in [0, 1) all become
    # 0), collapsing every prediction to class 0.
    predict = np.argmax(output, 1) + 1
    # Shift ids by +1 so any negative (ignore) label ends up <= 0.
    target = target.astype(np.int64) + 1
    valid = target > 0
    pixel_labeled = valid.sum()
    pixel_correct = ((predict == target) & valid).sum()
    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
-
def batch_intersection_union(output, target, nclass):
    """Per-class intersection and union pixel counts for one batch (mIoU).

    Args:
        output: numpy array of prediction scores laid out (N, C, H, W); the
            predicted class is the argmax over axis 1.
        target: numpy array of ground-truth class ids (N, H, W), or
            broadcast-compatible. Negative ids, e.g. -1, are ignored.
        nclass: number of classes.

    Returns:
        tuple: ``(area_inter, area_union)``, two float32 arrays of length
        ``nclass``.
    """
    lo, hi, bins = 1, nclass, nclass

    # Shift every id by +1: class k becomes k + 1 and an ignore label of -1
    # becomes 0, which lies below the histogram range [1, nclass] and is
    # therefore never counted.
    shifted_target = target.astype(np.float32) + 1
    shifted_pred = np.argmax(output.astype(np.float32), 1) + 1

    # Zero out predictions on ignored pixels so they drop out of the counts.
    keep = (shifted_target > 0).astype(np.float32)
    shifted_pred = shifted_pred.astype(np.float32) * keep
    # Matching pixels keep their (shifted) class id; all others become 0.
    hits = shifted_pred * (shifted_pred == shifted_target).astype(np.float32)

    def _class_counts(values):
        # nclass bins over [1, nclass]: each integer id 1..nclass falls in
        # its own bin, while 0 (ignored / non-matching) falls outside.
        counts, _ = np.histogram(values, bins=bins, range=(lo, hi))
        return counts

    area_inter = _class_counts(hits)
    area_pred = _class_counts(shifted_pred)
    area_lab = _class_counts(shifted_target)
    area_union = area_pred + area_lab - area_inter
    assert not (area_inter > area_union).any(), "Intersection area should be smaller than Union area"
    return area_inter.astype(np.float32), area_union.astype(np.float32)
-
def _save_color_mask(predict, filename):
    """Save the argmax class map of ``predict`` as a palettised PNG."""
    class_map = np.argmax(predict[0], axis=0)
    out_img = Image.fromarray(class_map.astype('uint8'))
    out_img.putpalette(cityspallete)
    outname = str(filename) + '.png'
    out_img.save(os.path.join(args.mask_result_path, outname))


def _write_eval_results(pixAcc, mIoU, category_iou):
    """Write overall scores and per-class IoU to ``eval_results.txt``."""
    # The results file always lives in mask_result_path, so the folder must
    # exist even when --save_mask is off (otherwise open() fails).
    os.makedirs(args.mask_result_path, exist_ok=True)
    txtName = os.path.join(args.mask_result_path, "eval_results.txt")
    lines = ['validation pixAcc:' + str(pixAcc * 100) + ', mIoU:' + str(mIoU * 100) + '\n']
    lines.extend('class name: ' + cls_name + ' iou: ' + str(category_iou[i]) + '\n'
                 for i, cls_name in enumerate(classes))
    with open(txtName, "w", encoding="utf-8") as f:
        f.writelines(lines)


def cal_mIoU():
    """Evaluate saved predictions against binary label files.

    For every file in ``args.label_path`` (in sorted order), the matching
    prediction ``<name>_img_0.bin`` is read from ``args.output_path``, the
    running pixAcc/mIoU is updated and printed, and, if ``args.save_mask``
    is set, a colour mask PNG is written. The final scores and the
    per-class IoU table are printed and written to
    ``<args.mask_result_path>/eval_results.txt``.
    """
    label_files = sorted(glob.glob(os.path.join(args.label_path, '*')))
    num_classes = len(classes)
    start_time = time.time()
    metric = SegmentationMetric(num_classes)
    if args.save_mask:
        os.makedirs(args.mask_result_path, exist_ok=True)

    for index, label_file in enumerate(label_files):
        # Each label file is a raw int32 dump reshaped to (H, W).
        label = np.fromfile(label_file, dtype=np.int32)
        label = label.reshape(args.image_height, args.image_width)

        # Strip the last 10 characters of the label file name to recover the
        # image name -- presumably a fixed-length suffix added during
        # preprocessing; TODO confirm against the preprocessing script.
        filename = os.path.basename(label_file)[:-10]
        predict_path = os.path.join(args.output_path, filename + "_img_0.bin")
        # Each prediction is a raw float32 dump reshaped to (1, C, H, W).
        predict = np.fromfile(predict_path, dtype=np.float32)
        predict = predict.reshape(1, num_classes, args.image_height, args.image_width)
        metric.update(predict, label)
        pixAcc, mIoU = metric.get()
        print("[EVAL] Sample: {:d}, pixAcc: {:.3f}, mIoU: {:.3f}".format(index + 1, pixAcc * 100, mIoU * 100))

        if args.save_mask:
            _save_color_mask(predict, filename)

    pixAcc, mIoU, category_iou = metric.get(return_category_iou=True)
    print('End validation pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc * 100, mIoU * 100))
    _write_eval_results(pixAcc, mIoU, category_iou)

    headers = ['class id', 'class name', 'iou']
    # showindex="always" prepends the row number, which fills "class id".
    table = [[cls_name, category_iou[i]] for i, cls_name in enumerate(classes)]
    print('Category iou: \n {}'.format(tabulate(table, headers,
                                                tablefmt='grid', showindex="always",
                                                numalign='center', stralign='center')))
    time_used = time.time() - start_time
    print("Time cost:" + str(time_used) + " seconds!")


if __name__ == '__main__':
    cal_mIoU()
|