'''
# @time:2023/2/27 16:17
# Author:Tuan
# @File:loss.py
'''
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------------------- BINARY LOSSES ---------------------------
# class FocalLoss(nn.Module):
#     def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=255):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.weight = weight
#         self.ignore_index = ignore_index
#         self.bce_fn = nn.BCEWithLogitsLoss(weight=self.weight)
#
#     def forward(self, preds, labels):
#         if self.ignore_index is not None:
#             mask = labels != self.ignore_index
#             labels = labels[mask]
#             preds = preds[mask]
#
#         logpt = -self.bce_fn(preds, labels)
#         pt = torch.exp(logpt)
#         loss = -((1 - pt) ** self.gamma) * self.alpha * logpt
#         return loss
# --------------------------- MULTICLASS LOSSES ---------------------------


# class FocalLoss(nn.Module):
#     def __init__(self, alpha=0.25, gamma=2, logits=False, sampling='mean'):
#         super(FocalLoss, self).__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.logits = logits
#         self.sampling = sampling
#         self.softmax = nn.Softmax(dim=1)
#
#     def forward(self, y_pred, y_true):
#         alpha = self.alpha
#         alpha_ = (1 - self.alpha)
#         y_pred = self.softmax(y_pred)
#         # if self.logits:
#         #     y_pred = torch.sigmoid(y_pred)
#
#         pt_positive = torch.where(y_true == 1, y_pred, torch.ones_like(y_pred))
#         pt_negative = torch.where(y_true == 0, y_pred, torch.zeros_like(y_pred))
#         pt_positive = torch.clamp(pt_positive, 1e-3, .999)
#         pt_negative = torch.clamp(pt_negative, 1e-3, .999)
#         pos_ = (1 - pt_positive) ** self.gamma
#         neg_ = pt_negative ** self.gamma
#
#         pos_loss = -alpha * pos_ * torch.log(pt_positive)
#         neg_loss = -alpha_ * neg_ * torch.log(1 - pt_negative)
#         loss = pos_loss + neg_loss
#
#         if self.sampling == "mean":
#             return loss.mean()
#         elif self.sampling == "sum":
#             return loss.sum()
#         elif self.sampling is None:
#             return loss


def diceCoeff(pred, gt, smooth=1e-5, activation='softmax2d'):
    r"""Dice coefficient:
    dice = (2 * |pred ∩ gt| + smooth) / (|pred| + |gt| + smooth)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation is only implemented for sigmoid and softmax2d")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    # the denominator is the sum of both cardinalities, not the set union
    total = pred_flat.sum(1) + gt_flat.sum(1)
    dice = (2 * intersection + smooth) / (total + smooth)

    return dice.sum() / N
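
# A minimal check (a sketch; shapes and values are illustrative assumptions):
# a prediction that matches the ground truth exactly should give a dice
# coefficient close to 1.
#
# pred = torch.tensor([[[[0., 1.], [1., 0.]]]])   # [N=1, C=1, H=2, W=2]
# gt = torch.tensor([[[[0., 1.], [1., 0.]]]])
# print(diceCoeff(pred, gt, activation=None))     # -> tensor close to 1.0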


class SoftDiceLossV2(nn.Module):
    __name__ = 'dice_loss'

    def __init__(self, num_classes, activation='softmax2d', reduction='mean'):
        super(SoftDiceLossV2, self).__init__()
        self.activation = activation
        self.num_classes = num_classes
        self.reduction = reduction  # kept for API symmetry; the loss always averages over classes

    def forward(self, y_pred, y_true):
        # mean dice over the foreground classes; channel 0 (background) is skipped
        class_dice = []
        for i in range(1, self.num_classes):
            class_dice.append(diceCoeff(y_pred[:, i:i + 1, :], y_true[:, i:i + 1, :], activation=self.activation))
        mean_dice = sum(class_dice) / len(class_dice)
        return 1 - mean_dice
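
# Example usage (a sketch; the 4-class shapes are assumptions). y_pred and
# y_true are [N, C, H, W]; channel 0 (background) is skipped, so the mean
# dice is taken over the C - 1 foreground classes only.
#
# criterion = SoftDiceLossV2(num_classes=4)
# y_pred = torch.randn(2, 4, 64, 64)  # raw logits
# y_true = F.one_hot(torch.randint(0, 4, (2, 64, 64)), 4).permute(0, 3, 1, 2).float()
# loss = criterion(y_pred, y_true)    # scalar in [0, 1]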


class DSCLoss(torch.nn.Module):

    def __init__(self, alpha: float = 1.0, smooth: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.smooth = smooth
        self.reduction = reduction

    def forward(self, logits, targets):
        probs = torch.softmax(logits, dim=1)
        # pick the probability of the target class for each sample/pixel
        probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1))

        probs_with_factor = ((1 - probs) ** self.alpha) * probs
        loss = 1 - (2 * probs_with_factor + self.smooth) / (probs_with_factor + 1 + self.smooth)

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss
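
# Example usage (a sketch; shapes are assumptions). DSCLoss expects raw
# logits and integer class targets: [N, C] logits with [N] targets for
# classification, or [N, C, H, W] logits with [N, H, W] targets for
# segmentation.
#
# criterion = DSCLoss(alpha=1.0, smooth=1.0)
# logits = torch.randn(8, 3)              # [N, C]
# targets = torch.randint(0, 3, (8,))     # [N]
# loss = criterion(logits, targets)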


def softmax_focalloss(y_pred, y_true, ignore_index=255, gamma=2.0, normalize=False):
    """
    Args:
        y_pred: [N, #class, H, W] raw logits
        y_true: [N, H, W] integer labels in [0, #class), or ignore_index
        gamma: focusing parameter (scalar)
        normalize: if True, rescale so the focal-weighted loss keeps the
            same total magnitude as the unweighted cross-entropy

    Returns:
        scalar loss
    """
    losses = F.cross_entropy(y_pred, y_true, ignore_index=ignore_index, reduction='none')
    with torch.no_grad():
        p = y_pred.softmax(dim=1)
        modulating_factor = (1 - p).pow(gamma)
        valid_mask = ~y_true.eq(ignore_index)
        masked_y_true = torch.where(valid_mask, y_true, torch.zeros_like(y_true))
        modulating_factor = torch.gather(modulating_factor, dim=1, index=masked_y_true.unsqueeze(dim=1)).squeeze_(dim=1)
        scale = 1.
        if normalize:
            scale = losses.sum() / (losses * modulating_factor).sum()
    losses = scale * (losses * modulating_factor).sum() / (valid_mask.sum() + p.size(0))

    return losses
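
# Example usage (a sketch; shapes are assumptions). Pixels labelled with
# ignore_index contribute no loss.
#
# y_pred = torch.randn(2, 5, 32, 32)          # [N, #class, H, W] logits
# y_true = torch.randint(0, 5, (2, 32, 32))   # [N, H, W] labels
# y_true[0, :4, :4] = 255                     # ignored region
# loss = softmax_focalloss(y_pred, y_true, gamma=2.0)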


def cosine_annealing(lower_bound, upper_bound, _t, _t_max):
    # cosine ramp from lower_bound at _t = 0 to upper_bound at _t = _t_max
    return upper_bound + 0.5 * (lower_bound - upper_bound) * (math.cos(math.pi * _t / _t_max) + 1)


def poly_annealing(lower_bound, upper_bound, _t, _t_max):
    # polynomial (power 0.9) ramp from lower_bound to upper_bound
    factor = (1 - _t / _t_max) ** 0.9
    return upper_bound + factor * (lower_bound - upper_bound)


def linear_annealing(lower_bound, upper_bound, _t, _t_max):
    # linear ramp from lower_bound to upper_bound
    factor = 1 - _t / _t_max
    return upper_bound + factor * (lower_bound - upper_bound)
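
# The three schedules above ramp a value from lower_bound (at _t = 0) to
# upper_bound (at _t = _t_max). A quick endpoint check (a sketch):
#
# for fn in (cosine_annealing, poly_annealing, linear_annealing):
#     print(fn(1.0, 2.0, 0, 100), fn(1.0, 2.0, 100, 100))  # -> ~1.0, ~2.0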


def annealing_softmax_focalloss(y_pred, y_true, t, t_max, ignore_index=255, gamma=2.0,
                                annealing_function=cosine_annealing):
    losses = F.cross_entropy(y_pred, y_true, ignore_index=ignore_index, reduction='none')
    with torch.no_grad():
        p = y_pred.softmax(dim=1)
        modulating_factor = (1 - p).pow(gamma)
        valid_mask = ~y_true.eq(ignore_index)
        masked_y_true = torch.where(valid_mask, y_true, torch.zeros_like(y_true))
        modulating_factor = torch.gather(modulating_factor, dim=1, index=masked_y_true.unsqueeze(dim=1)).squeeze_(dim=1)
        normalizer = losses.sum() / (losses * modulating_factor).sum()
        scales = modulating_factor * normalizer
        if t > t_max:
            scale = scales
        else:
            # anneal the per-pixel scales from 1 (plain cross-entropy) toward
            # the normalized focal scales as t approaches t_max
            scale = annealing_function(1, scales, t, t_max)
    losses = (losses * scale).sum() / (valid_mask.sum() + p.size(0))
    return losses
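
# Example usage (a sketch; shapes and schedule length are assumptions).
# Early in training (small t) the loss behaves like plain cross-entropy;
# as t approaches t_max it anneals toward the normalized focal loss.
#
# y_pred = torch.randn(2, 5, 32, 32)
# y_true = torch.randint(0, 5, (2, 32, 32))
# loss = annealing_softmax_focalloss(y_pred, y_true, t=10, t_max=100)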


# class FocalLoss(nn.Module):
#     def __init__(self, weight=None, gamma=0):
#         super(FocalLoss, self).__init__()
#         self.weight = weight
#         self.gamma = gamma
#         self.eps = 1e-8
#
#     def forward(self, predict, target):
#         if self.weight is not None:
#             weights = self.weight.unsqueeze(0).unsqueeze(1).repeat(predict.shape[0], predict.shape[2], 1)
#         target_onehot = F.one_hot(target.long(), predict.shape[1])
#         if self.weight is not None:
#             weights = torch.sum(target_onehot * weights, -1)
#         input_soft = F.softmax(predict, dim=1)
#         # the range must be clamped here, otherwise the loss can become NaN
#         probs = torch.sum(input_soft.transpose(2, 1) * target_onehot, -1).clamp(min=0.001, max=0.999)
#         focal_weight = (1 + self.eps - probs) ** self.gamma
#         if self.weight is not None:
#             return torch.sum(-torch.log(probs) * weights * focal_weight) / torch.sum(weights)
#         else:
#             return torch.mean(-torch.log(probs) * focal_weight)


# class FocalLoss(torch.nn.Module):
#     """
#     Binary focal loss with a fixed alpha
#     """
#     def __init__(self, gamma, alpha, reduction='elementwise_mean'):
#         super().__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         self.reduction = reduction
#
#     def forward(self, _input, target):
#         pt = torch.softmax(_input, dim=1)
#         alpha = self.alpha
#         loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
#                (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
#         if self.reduction == 'elementwise_mean':
#             loss = torch.mean(loss)
#         elif self.reduction == 'sum':
#             loss = torch.sum(loss)
#         return loss

# class FocalLoss(nn.Module):
#     def __init__(self, alpha=0.25, gamma=2, weight=None, ignore_index=255):
#         super().__init__()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.weight = weight
#         self.ignore_index = ignore_index
#         self.ce_fn = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index)
#
#     def forward(self, preds, labels):
#         logpt = -self.ce_fn(preds, labels)
#         pt = torch.exp(logpt)
#         loss = -((1 - pt) ** self.gamma) * self.alpha * logpt
#         return loss


### From https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65938
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # the deprecated `reduce=False` argument is replaced by reduction='none'
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
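

if __name__ == "__main__":
    # Smoke test (a minimal sketch; shapes are assumptions). With
    # logits=False the loss wraps multi-class cross-entropy on class
    # logits; with logits=True it wraps binary cross-entropy on raw logits.
    criterion = FocalLoss(gamma=2, alpha=0.25, logits=False)
    inputs = torch.randn(8, 3)               # [N, C] class logits
    targets = torch.randint(0, 3, (8,))      # [N] integer labels
    print(criterion(inputs, targets))

    bce_criterion = FocalLoss(gamma=2, alpha=0.25, logits=True)
    print(bce_criterion(torch.randn(8), torch.rand(8)))  # binary targets in [0, 1]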