|
- import torch
- from torch.nn.functional import cross_entropy
- from torch.nn.modules.loss import _WeightedLoss
-
-
- EPSILON = 1e-32
-
-
- class LogNLLLoss(_WeightedLoss):
- __constants__ = ['weight', 'reduction', 'ignore_index']
-
- def __init__(self, weight=None, size_average=None, reduce=None, reduction=None,
- ignore_index=-100):
- super(LogNLLLoss, self).__init__(weight, size_average, reduce, reduction)
- self.ignore_index = ignore_index
-
- def forward(self, y_input, y_target):
- # y_input = torch.log(y_input + EPSILON)
- return cross_entropy(y_input, y_target, weight=self.weight,
- ignore_index=self.ignore_index)
-
-
- def classwise_iou(output, gt):
- """
- Args:
- output: torch.Tensor of shape (n_batch, n_classes, image.shape)
- gt: torch.LongTensor of shape (n_batch, image.shape)
- """
- dims = (0, *range(2, len(output.shape)))
- gt = torch.zeros_like(output).scatter_(1, gt[:, None, :], 1)
- intersection = output*gt
- union = output + gt - intersection
- classwise_iou = (intersection.sum(dim=dims).float() + EPSILON) / (union.sum(dim=dims) + EPSILON)
-
- return classwise_iou
-
-
- def classwise_f1(output, gt):
- """
- Args:
- output: torch.Tensor of shape (n_batch, n_classes, image.shape)
- gt: torch.LongTensor of shape (n_batch, image.shape)
- """
-
- epsilon = 1e-20
- n_classes = output.shape[1]
-
- output = torch.argmax(output, dim=1)
- true_positives = torch.tensor([((output == i) * (gt == i)).sum() for i in range(n_classes)]).float()
- selected = torch.tensor([(output == i).sum() for i in range(n_classes)]).float()
- relevant = torch.tensor([(gt == i).sum() for i in range(n_classes)]).float()
-
- precision = (true_positives + epsilon) / (selected + epsilon)
- recall = (true_positives + epsilon) / (relevant + epsilon)
- classwise_f1 = 2 * (precision * recall) / (precision + recall)
-
- return classwise_f1
-
-
- def make_weighted_metric(classwise_metric):
- """
- Args:
- classwise_metric: classwise metric like classwise_IOU or classwise_F1
- """
-
- def weighted_metric(output, gt, weights=None):
-
- # dimensions to sum over
- dims = (0, *range(2, len(output.shape)))
-
- # default weights
- if weights == None:
- weights = torch.ones(output.shape[1]) / output.shape[1]
- else:
- # creating tensor if needed
- if len(weights) != output.shape[1]:
- raise ValueError("The number of weights must match with the number of classes")
- if not isinstance(weights, torch.Tensor):
- weights = torch.tensor(weights)
- # normalizing weights
- weights /= torch.sum(weights)
-
- classwise_scores = classwise_metric(output, gt).cpu()
-
- return classwise_scores
-
- return weighted_metric
-
-
- jaccard_index = make_weighted_metric(classwise_iou)
- f1_score = make_weighted_metric(classwise_f1)
-
-
- if __name__ == '__main__':
- output, gt = torch.zeros(3, 2, 5, 5), torch.zeros(3, 5, 5).long()
- print(classwise_iou(output, gt))
|