- """ cross entropy smooth """
- from mindspore import Tensor, nn
- from mindspore.ops import functional as F
- from mindspore.ops import operations as P
-
-
- class BinaryCrossEntropySmooth(nn.LossBase):
- """
- Binary cross entropy loss with label smoothing.
- Apply sigmoid activation function to input `logits`, and uses the given logits to compute binary cross entropy
- between the logits and the label.
-
- Args:
- smoothing: Label smoothing factor, a regularization tool used to prevent the model
- from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default: 0.0.
- aux_factor: Auxiliary loss factor. Set aux_factor > 0.0 if the model has auxiliary logit outputs
- (i.e., deep supervision), like inception_v3. Default: 0.0.
- reduction: Apply specific reduction method to the output: 'mean' or 'sum'. Default: 'mean'.
- weight (Tensor): Class weight. A rescaling weight applied to the loss of each batch element. Shape [C].
- It can be broadcast to a tensor with shape of `logits`. Data type must be float16 or float32.
- pos_weight (Tensor): Positive weight for each class. A weight of positive examples. Shape [C].
- Must be a vector with length equal to the number of classes.
- It can be broadcast to a tensor with shape of `logits`. Data type must be float16 or float32.
-
- Inputs:
- logits (Tensor or Tuple of Tensor): (1) Input logits. Shape [N, C], where N is # samples, C is # classes.
- Or (2) Tuple of two input logits (main_logits and aux_logits) for auxiliary loss.
- labels (Tensor): Ground truth label, (1) shape [N, C], has the same shape as `logits` or (2) shape [N].
- can be a class probability matrix or one-hot labels. Data type must be float16 or float32.
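
    Examples:
        A minimal illustrative call (shapes assumed here: logits [N, C], one-hot labels [N, C]):

        >>> import numpy as np
        >>> import mindspore as ms
        >>> loss_fn = BinaryCrossEntropySmooth(smoothing=0.1)
        >>> logits = ms.Tensor(np.random.randn(4, 10), ms.float32)
        >>> labels = ms.Tensor(np.eye(10)[[1, 3, 5, 7]], ms.float32)
        >>> loss = loss_fn(logits, labels)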
- """

    def __init__(self, smoothing=0.0, aux_factor=0.0, reduction="mean", weight=None, pos_weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.aux_factor = aux_factor
        self.reduction = reduction
        self.weight = weight
        self.pos_weight = pos_weight
        self.ones = P.OnesLike()
        self.one_hot = P.OneHot()

    def construct(self, logits, labels):
        loss_aux = 0
        aux_logits = None

        if isinstance(logits, tuple):
            main_logits = logits[0]
            # Any remaining entries are auxiliary logits (e.g., from deep supervision).
            aux_logits = logits[1:]
        else:
            main_logits = logits

        if main_logits.size != labels.size:
            # Labels arrived as class indices, so convert them to one-hot explicitly:
            # binary_cross_entropy_with_logits requires logits and labels of the same shape.
            class_dim = 0 if main_logits.ndim == 1 else -1
            n_classes = main_logits.shape[class_dim]
            labels = self.one_hot(labels, n_classes, Tensor(1.0), Tensor(0.0))

        # Fall back to all-ones weights when no explicit weight / pos_weight is given.
        ones_input = self.ones(main_logits)
        if self.weight is not None:
            weight = self.weight
        else:
            weight = ones_input
        if self.pos_weight is not None:
            pos_weight = self.pos_weight
        else:
            pos_weight = ones_input

        if self.smoothing > 0.0:
            class_dim = 0 if main_logits.ndim == 1 else -1
            n_classes = main_logits.shape[class_dim]
            # E.g., with smoothing=0.1 and 10 classes, each one-hot 1 becomes 0.91 and
            # each 0 becomes 0.01, keeping the targets away from hard 0/1 values.
            labels = labels * (1 - self.smoothing) + self.smoothing / n_classes

        if self.aux_factor > 0 and aux_logits is not None:
            for aux in aux_logits:
                loss_aux += F.binary_cross_entropy_with_logits(
                    aux, labels, weight=weight, pos_weight=pos_weight, reduction=self.reduction
                )

        loss_logits = F.binary_cross_entropy_with_logits(
            main_logits, labels, weight=weight, pos_weight=pos_weight, reduction=self.reduction
        )

        loss = loss_logits + self.aux_factor * loss_aux

        return loss
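

# A minimal usage sketch of the auxiliary-loss path (hypothetical shapes and values;
# inception_v3-style models return a (main_logits, aux_logits) tuple in training mode):
if __name__ == "__main__":
    import numpy as np
    import mindspore as ms

    loss_fn = BinaryCrossEntropySmooth(smoothing=0.1, aux_factor=0.4)
    main_logits = ms.Tensor(np.random.randn(4, 10), ms.float32)
    aux_logits = ms.Tensor(np.random.randn(4, 10), ms.float32)
    # Index labels of shape [N]; construct converts them to one-hot [N, C].
    labels = ms.Tensor(np.array([1, 3, 5, 7]), ms.int32)
    loss = loss_fn((main_logits, aux_logits), labels)
    print(loss)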