|
- import torch, time
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- from medpy import metric
- import torchmetrics
-
class AbsMetric(object):
    r"""An abstract class for the performance metrics of a task.

    Subclasses are expected to override :meth:`update_fun` and
    :meth:`score_fun` as plain methods.

    Attributes:
        record (list): A list of the metric scores in every iteration.
        bs (list): A list of the number of data in every iteration.
    """
    def __init__(self):
        self.record = []
        self.bs = []

    # NOTE: these were previously decorated with @property, which broke the
    # interface -- accessing ``obj.update_fun`` invoked the getter with only
    # ``self`` and raised TypeError, while every subclass overrides both as
    # plain methods.  They are plain (abstract) methods here.
    def update_fun(self, pred, gt):
        r"""Calculate the metric scores in every iteration and update :attr:`record`.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.
        """
        pass

    def score_fun(self):
        r"""Calculate the final score (when an epoch ends).

        Return:
            list: A list of metric scores.
        """
        pass

    def reinit(self):
        r"""Reset :attr:`record` and :attr:`bs` (when an epoch ends)."""
        self.record = []
        self.bs = []
-
- # accuracy
class AccMetric(AbsMetric):
    r"""Compute classification accuracy from raw logits."""

    def __init__(self):
        super(AccMetric, self).__init__()

    def update_fun(self, pred, gt):
        r"""Record the number of correct predictions and the batch size.

        Args:
            pred (torch.Tensor): Raw class scores, shape ``(batch, classes)``.
            gt (torch.Tensor): Integer class labels, shape ``(batch,)``.
        """
        # argmax over the softmax is the predicted class index
        predicted_labels = F.softmax(pred, dim=-1).max(-1)[1]
        num_correct = gt.eq(predicted_labels).sum().item()
        self.record.append(num_correct)
        self.bs.append(pred.size()[0])

    def score_fun(self):
        r"""Return the overall accuracy across all recorded batches.

        Return:
            list: A single-element list holding the accuracy.
        """
        total_correct = sum(self.record)
        total_seen = sum(self.bs)
        return [total_correct / total_seen]
-
- # L1 Error
# L1 Error
class L1Metric(AbsMetric):
    r"""Record the binary Dice coefficient between prediction and ground truth.

    NOTE(review): despite the class name and the original "MAE" docstring,
    this metric computes the Dice score via ``medpy.metric.binary.dc`` --
    the actual L1/MAE computation had been commented out.  The class name
    is kept unchanged so existing callers keep working.
    """
    def __init__(self):
        super(L1Metric, self).__init__()

    def update_fun(self, pred, gt):
        r"""Compute the Dice score of one batch and store it.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.
        """
        # medpy operates on numpy arrays, so detach and move to CPU first.
        npred = pred.detach().cpu().numpy()
        ngt = gt.detach().cpu().numpy()
        dice = metric.binary.dc(npred, ngt)

        self.record.append(dice)
        self.bs.append(pred.size()[0])

    def score_fun(self):
        r"""Return the batch-size-weighted mean Dice score.

        Return:
            list: A single-element list holding the averaged score.
        """
        records = np.array(self.record)
        batch_size = np.array(self.bs)
        return [(records * batch_size).sum() / sum(batch_size)]
-
- #PSNR
# PSNR
class PSNRMetric(AbsMetric):
    r"""Calculate the Peak Signal-to-Noise Ratio (PSNR)."""

    def __init__(self):
        super(PSNRMetric, self).__init__()
        # Built lazily in update_fun so it lives on the same device as the
        # incoming tensors.  The original hard-coded cuda:0 (crashing on
        # CPU-only machines) and rebuilt the metric object on every call.
        self._psnr = None

    def update_fun(self, pred, gt):
        r"""Compute the PSNR of one batch and store it.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor (same device as ``pred``).
        """
        if self._psnr is None:
            self._psnr = torchmetrics.PeakSignalNoiseRatio().to(pred.device)
        # forward() returns the batch-level score as a scalar tensor
        psnr_sc = self._psnr(pred, gt)
        self.record.append(psnr_sc.cpu())
        self.bs.append(pred.size()[0])

    def score_fun(self):
        r"""Return the batch-size-weighted mean PSNR.

        Return:
            list: A single-element list holding the averaged PSNR.
        """
        records = np.array(self.record)
        batch_size = np.array(self.bs)
        return [(records * batch_size).sum() / sum(batch_size)]
-
- #ssim
# ssim
class SSIMMetric(AbsMetric):
    r"""Calculate the Multi-Scale Structural Similarity Index Measure (MS-SSIM)."""

    def __init__(self):
        super(SSIMMetric, self).__init__()
        # Built lazily in update_fun so it lives on the same device as the
        # incoming tensors.  The original hard-coded cuda:0 (crashing on
        # CPU-only machines) and rebuilt the metric object on every call.
        self._ssim = None

    def update_fun(self, pred, gt):
        r"""Compute the MS-SSIM of one batch and store it.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor (same device as ``pred``).
        """
        # runs directly on the tensors' device (GPU or CPU)
        if self._ssim is None:
            self._ssim = torchmetrics.MultiScaleStructuralSimilarityIndexMeasure().to(pred.device)
        # forward() returns the batch-level score as a scalar tensor
        ssim_sc = self._ssim(pred, gt)
        self.record.append(ssim_sc.cpu())
        self.bs.append(pred.size()[0])

    def score_fun(self):
        r"""Return the batch-size-weighted mean MS-SSIM.

        Return:
            list: A single-element list holding the averaged score.
        """
        records = np.array(self.record)
        batch_size = np.array(self.bs)
        return [(records * batch_size).sum() / sum(batch_size)]
|