|
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import os
-
- import numpy as np
- import time
- import paddle
- import paddle.nn.functional as F
-
- from paddleseg.utils import metrics, TimeAverager, calculate_eta, logger, progbar
- from core import infer_crop
-
- np.set_printoptions(suppress=True)
-
-
def evaluate(model,
             eval_dataset,
             aug_eval=False,
             scales=1.0,
             flip_horizontal=True,
             flip_vertical=False,
             is_slide=False,
             stride=None,
             crop_size=None,
             num_workers=0,
             print_detail=True):
    """
    Launch evaluation.

    Each dataset sample is expected to yield three overlapping image crops
    (data[0..2]) plus the full-width label (data[3]); the three per-crop
    predictions are stitched back together by averaging the 640-pixel
    overlap regions before computing metrics.

    Args:
        model (nn.Layer): A semantic segmentation model.
        eval_dataset (paddle.io.Dataset): Used to read and process validation datasets.
        aug_eval (bool, optional): Whether to use multi-scale and flip augment for evaluation. Default: False.
        scales (list|float, optional): Scales for augment. It is valid when `aug_eval` is True. Default: 1.0.
        flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_eval` is True. Default: True.
        flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_eval` is True. Default: False.
        is_slide (bool, optional): Whether to evaluate by sliding window. Default: False.
        stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height.
            It should be provided when `is_slide` is True.
        num_workers (int, optional): Num workers for data loader. Default: 0.
        print_detail (bool, optional): Whether to print detailed information about the evaluation process. Default: True.

    Returns:
        float: The mIoU of validation datasets.
        float: The accuracy of validation datasets.
    """
    model.eval()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank
    if nranks > 1:
        # Initialize parallel environment if not done.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
    batch_sampler = paddle.io.DistributedBatchSampler(
        eval_dataset, batch_size=1, shuffle=False, drop_last=False)
    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True, )

    total_iters = len(loader)
    intersect_area_all = 0
    pred_area_all = 0
    label_area_all = 0

    if print_detail:
        logger.info("Start evaluating (total_samples={}, total_iters={})...".
                    format(len(eval_dataset), total_iters))
    progbar_val = progbar.Progbar(target=total_iters, verbose=1)
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    batch_start = time.time()
    with paddle.no_grad():
        # NOTE: renamed loop variable from `iter` (shadowed the builtin).
        for batch_index, data in enumerate(loader):
            reader_cost_averager.record(time.time() - batch_start)
            preds = []
            label = data[3].astype('int64')
            # Run inference on each of the three overlapping crops.
            for crop_idx in range(3):
                im = data[crop_idx]
                ori_shape = im.shape[-2:]
                if aug_eval:
                    pred = infer_crop.aug_inference(
                        model,
                        im,
                        ori_shape=ori_shape,
                        transforms=eval_dataset.transforms.transforms,
                        scales=scales,
                        flip_horizontal=flip_horizontal,
                        flip_vertical=flip_vertical,
                        is_slide=is_slide,
                        stride=stride,
                        crop_size=crop_size)
                else:
                    pred = infer_crop.inference(
                        model,
                        im,
                        ori_shape=ori_shape,
                        transforms=eval_dataset.transforms.transforms,
                        is_slide=is_slide,
                        stride=stride,
                        crop_size=crop_size)
                preds.append(pred)

            # Stitch the three crop predictions into one full-width logit map.
            # Adjacent crops overlap by 640 px (each crop is 1280 px wide);
            # the overlap regions are averaged to smooth the seams.
            left_ensem = (
                preds[0][:, :, :, 640:1280] + preds[1][:, :, :, 0:640]) / 2
            right_ensem = (
                preds[1][:, :, :, 640:1280] + preds[2][:, :, :, 0:640]) / 2
            pred_ensem = paddle.concat(
                [
                    preds[0][:, :, :, 0:640], left_ensem, right_ensem,
                    preds[2][:, :, :, 640:1280]
                ],
                axis=3)
            pred = paddle.argmax(
                pred_ensem, axis=1, keepdim=True, dtype='int32')

            intersect_area, pred_area, label_area = metrics.calculate_area(
                pred,
                label,
                eval_dataset.num_classes,
                ignore_index=eval_dataset.ignore_index)

            # Gather from all ranks
            if nranks > 1:
                intersect_area_list = []
                pred_area_list = []
                label_area_list = []
                paddle.distributed.all_gather(intersect_area_list,
                                              intersect_area)
                paddle.distributed.all_gather(pred_area_list, pred_area)
                paddle.distributed.all_gather(label_area_list, label_area)

                # The DistributedBatchSampler pads the last batch so every
                # rank gets a sample; drop the duplicated tail results so
                # each image is counted exactly once.
                if (batch_index + 1) * nranks > len(eval_dataset):
                    valid = len(eval_dataset) - batch_index * nranks
                    intersect_area_list = intersect_area_list[:valid]
                    pred_area_list = pred_area_list[:valid]
                    label_area_list = label_area_list[:valid]

                for i in range(len(intersect_area_list)):
                    intersect_area_all = intersect_area_all + intersect_area_list[
                        i]
                    pred_area_all = pred_area_all + pred_area_list[i]
                    label_area_all = label_area_all + label_area_list[i]
            else:
                intersect_area_all = intersect_area_all + intersect_area
                pred_area_all = pred_area_all + pred_area
                label_area_all = label_area_all + label_area

            batch_cost_averager.record(
                time.time() - batch_start, num_samples=len(label))
            batch_cost = batch_cost_averager.get_average()
            reader_cost = reader_cost_averager.get_average()

            # Only rank 0 renders the progress bar to avoid interleaved output.
            if local_rank == 0 and print_detail:
                progbar_val.update(batch_index + 1,
                                   [('batch_cost', batch_cost),
                                    ('reader cost', reader_cost)])
            reader_cost_averager.reset()
            batch_cost_averager.reset()
            batch_start = time.time()

    class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
                                       label_area_all)
    class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all)
    kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all)

    if print_detail:
        logger.info("[EVAL] #Images={} mIoU={:.4f} Acc={:.4f} Kappa={:.4f} ".
                    format(len(eval_dataset), miou, acc, kappa))
        logger.info("[EVAL] Class IoU: \n" + str(np.round(class_iou, 4)))
        logger.info("[EVAL] Class Acc: \n" + str(np.round(class_acc, 4)))
    return miou, acc
|