# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """rmi loss in PaddlePaddle"""
- import numpy
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
-
- from paddleseg.cvlibs import manager
-
- _euler_num = 2.718281828
- _pi = 3.14159265
- _ln_2_pi = 1.837877
- _CLIP_MIN = 1e-6
- _CLIP_MAX = 1.0
- _POS_ALPHA = 5e-4
- _IS_SUM = 1


@manager.LOSSES.add_component
class RMILoss(nn.Layer):
    """
    Implements the Region Mutual Information (RMI) loss
    (https://arxiv.org/abs/1910.12037) for semantic segmentation.

    Unlike the vanilla RMI loss, which also contains a Cross Entropy term,
    this implementation keeps only the RMI-related part. The motivation is to
    allow a more flexible combination of losses during training: for example,
    a mixed loss that merges RMI loss with Bootstrapped Cross Entropy loss
    performs online hard example mining while also attending to region
    information.

    Args:
        num_classes (int, optional): The number of classes. Default: 19.
        rmi_radius (int, optional): The side length of the square region used
            to build point pairs; must be in [1, 10]. Default: 3.
        rmi_pool_way (int, optional): How the maps are downsampled before the
            RMI computation: 0 for max pooling, 1 for average pooling, 2 for
            interpolation. Default: 0.
        rmi_pool_size (int, optional): The kernel size of the downsampling.
            Default: 3.
        rmi_pool_stride (int, optional): The stride of the downsampling; must
            equal rmi_pool_size. Default: 3.
        loss_weight_lambda (float, optional): The weight that balanced the RMI
            part against the Cross Entropy part in the original formulation;
            stored but unused in this implementation. Default: 0.5.
        ignore_index (int, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default: 255.
    """

    def __init__(self,
                 num_classes=19,
                 rmi_radius=3,
                 rmi_pool_way=0,
                 rmi_pool_size=3,
                 rmi_pool_stride=3,
                 loss_weight_lambda=0.5,
                 ignore_index=255):
        super(RMILoss, self).__init__()

        self.num_classes = num_classes
        assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], \
            'rmi_radius should be in [1, 10]'
        self.rmi_radius = rmi_radius
        assert rmi_pool_way in [0, 1, 2, 3], \
            'rmi_pool_way should be 0, 1, 2 or 3'
        self.rmi_pool_way = rmi_pool_way
        assert rmi_pool_size == rmi_pool_stride, \
            'rmi_pool_size should equal rmi_pool_stride'
        self.rmi_pool_size = rmi_pool_size
        self.rmi_pool_stride = rmi_pool_stride
        self.weight_lambda = loss_weight_lambda
        # Each rmi_radius x rmi_radius region yields half_d points; the joint
        # label-probability vector is twice that long.
        self.half_d = self.rmi_radius * self.rmi_radius
        self.d = 2 * self.half_d
        self.kernel_padding = self.rmi_pool_size // 2
        self.ignore_index = ignore_index

    def forward(self, logits_4D, labels_4D, do_rmi=True):
        """
        Forward computation.

        Args:
            logits_4D (Tensor): Shape [N, C, H, W], logits at each prediction
                (between -infinity and +infinity).
            labels_4D (Tensor): Shape [N, H, W], ground truth labels
                (between 0 and C - 1).
            do_rmi (bool, optional): Passed through to forward_sigmoid.
                Default: True.
        """
        # Carry out the computation in float32 regardless of the input dtype.
        logits_4D = paddle.cast(logits_4D, dtype='float32')
        labels_4D = paddle.cast(labels_4D, dtype='float32')

        loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi)
        return loss

    def forward_sigmoid(self, logits_4D, labels_4D, do_rmi=False):
        """
        Compute the RMI loss with sigmoid probabilities.

        Args:
            logits_4D (Tensor): Shape [N, C, H, W], dtype=float32.
            labels_4D (Tensor): Shape [N, H, W], dtype=int64.
            do_rmi (bool): Kept for API compatibility; unused here.
        """
        # Mask out ignored positions and one-hot encode the valid labels,
        # giving a [N, H, W, C] tensor.
        label_mask_3D = labels_4D != self.ignore_index
        valid_onehot_labels_4D = paddle.cast(
            F.one_hot(
                paddle.cast(labels_4D, dtype='int64') * paddle.cast(
                    label_mask_3D, dtype='int64'),
                num_classes=self.num_classes),
            dtype='float32')

        label_mask_3D = paddle.cast(label_mask_3D, dtype='float32')
        valid_onehot_labels_4D = valid_onehot_labels_4D * paddle.unsqueeze(
            label_mask_3D, axis=3)
        valid_onehot_labels_4D.stop_gradient = True
        probs_4D = F.sigmoid(logits_4D) * paddle.unsqueeze(
            label_mask_3D, axis=1) + _CLIP_MIN

        # [N, H, W, C] -> [N, C, H, W]
        valid_onehot_labels_4D = paddle.transpose(valid_onehot_labels_4D,
                                                  [0, 3, 1, 2])
        valid_onehot_labels_4D.stop_gradient = True
        rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)

        return rmi_loss

    def inverse(self, x):
        """Thin wrapper around paddle.inverse."""
        return paddle.inverse(x)

    def rmi_lower_bound(self, labels_4D, probs_4D):
        """
        Calculate the lower bound of the region mutual information between the
        one-hot labels Y and the probabilities P. Up to terms that do not
        depend on the prediction, maximizing I(Y; P) amounts to minimizing
        log det(var(Y|P)), where var(Y|P) is the conditional covariance of Y
        given P; see the comments below.

        Args:
            labels_4D (Tensor): Shape [N, C, H, W], dtype=float32.
            probs_4D (Tensor): Shape [N, C, H, W], dtype=float32.
        """
        assert labels_4D.shape == probs_4D.shape, \
            'shapes must match: {} vs {}'.format(labels_4D.shape,
                                                 probs_4D.shape)

        p, s = self.rmi_pool_size, self.rmi_pool_stride
        if self.rmi_pool_stride > 1:
            # Downsample both maps first to reduce computation.
            if self.rmi_pool_way == 0:
                labels_4D = F.max_pool2d(
                    labels_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
                probs_4D = F.max_pool2d(
                    probs_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
            elif self.rmi_pool_way == 1:
                labels_4D = F.avg_pool2d(
                    labels_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
                probs_4D = F.avg_pool2d(
                    probs_4D,
                    kernel_size=p,
                    stride=s,
                    padding=self.kernel_padding)
            elif self.rmi_pool_way == 2:
                shape = labels_4D.shape
                new_h, new_w = shape[2] // s, shape[3] // s
                labels_4D = F.interpolate(
                    labels_4D, size=[new_h, new_w], mode='nearest')
                probs_4D = F.interpolate(
                    probs_4D,
                    size=[new_h, new_w],
                    mode='bilinear',
                    align_corners=True)
            else:
                raise NotImplementedError("Pool way of RMI is not defined!")

        label_shape = labels_4D.shape
        n, c = label_shape[0], label_shape[1]

        # Collect, for every output position, the vector of points in the
        # surrounding rmi_radius x rmi_radius region.
        la_vectors, pr_vectors = self.map_get_pairs(
            labels_4D, probs_4D, radius=self.rmi_radius, is_combine=False)

        la_vectors = paddle.reshape(la_vectors, [n, c, self.half_d, -1])
        la_vectors = paddle.cast(la_vectors, dtype='float64')
        la_vectors.stop_gradient = True

        pr_vectors = paddle.reshape(pr_vectors, [n, c, self.half_d, -1])
        pr_vectors = paddle.cast(pr_vectors, dtype='float64')

        diag_matrix = paddle.unsqueeze(
            paddle.unsqueeze(
                paddle.eye(self.half_d), axis=0), axis=0)

        # Center the vectors and form the covariance matrices.
        la_vectors = la_vectors - paddle.mean(la_vectors, axis=3, keepdim=True)
        la_cov = paddle.matmul(la_vectors,
                               paddle.transpose(la_vectors, [0, 1, 3, 2]))
        pr_vectors = pr_vectors - paddle.mean(pr_vectors, axis=3, keepdim=True)
        pr_cov = paddle.matmul(pr_vectors,
                               paddle.transpose(pr_vectors, [0, 1, 3, 2]))

        # Add a small multiple of the identity before inverting, so the matrix
        # stays positive-definite.
        pr_cov_inv = self.inverse(pr_cov + paddle.cast(
            diag_matrix, dtype='float64') * _POS_ALPHA)

        la_pr_cov = paddle.matmul(la_vectors,
                                  paddle.transpose(pr_vectors, [0, 1, 3, 2]))

        # Conditional covariance of Y given P:
        # var(Y|P) = cov(Y, Y) - cov(Y, P) cov(P, P)^{-1} cov(P, Y).
        appro_var = la_cov - paddle.matmul(
            paddle.matmul(la_pr_cov, pr_cov_inv),
            paddle.transpose(la_pr_cov, [0, 1, 3, 2]))

        rmi_now = 0.5 * self.log_det_by_cholesky(appro_var + paddle.cast(
            diag_matrix, dtype='float64') * _POS_ALPHA)

        # Mean over samples, then normalize by the number of points per region.
        rmi_per_class = paddle.cast(
            paddle.mean(
                paddle.reshape(rmi_now, [-1, self.num_classes]), axis=0),
            dtype='float32')
        rmi_per_class = paddle.divide(rmi_per_class,
                                      paddle.to_tensor(float(self.half_d)))

        rmi_loss = paddle.sum(rmi_per_class) if _IS_SUM else paddle.mean(
            rmi_per_class)

        return rmi_loss
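
    # Worked shape example (with the defaults num_classes=19, rmi_radius=3,
    # rmi_pool_size=rmi_pool_stride=3, and a hypothetical 96x96 input):
    # max pooling maps [N, 19, 96, 96] to [N, 19, 32, 32]; map_get_pairs then
    # crops to 30x30 = 900 positions, so la_vectors and pr_vectors are
    # reshaped to [N, 19, 9, 900] and every covariance matrix above is
    # [N, 19, 9, 9].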

    def log_det_by_cholesky(self, matrix):
        """
        Compute log det(A) through the Cholesky factorization A = L L^T,
        using log det(A) = 2 * sum(log(diag(L))).

        Args:
            matrix (Tensor): A positive-definite matrix of shape [N, C, D, D].
        """
        chol = paddle.cholesky(matrix)
        diag = paddle.diagonal(chol, offset=0, axis1=-2, axis2=-1)
        # The small epsilon guards against log(0).
        log_diag = paddle.log(diag + 1e-8)

        return 2.0 * paddle.sum(log_diag, axis=-1)
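
    # A quick numerical check sketch (an assumption, not part of the original
    # file): for a random SPD matrix A = B B^T + I, the result should match
    # the direct computation via paddle.linalg.det:
    #
    #     b = paddle.randn([1, 1, 4, 4], dtype='float64')
    #     a = paddle.matmul(b, paddle.transpose(b, [0, 1, 3, 2])) \
    #         + paddle.eye(4, dtype='float64')
    #     assert paddle.allclose(RMILoss().log_det_by_cholesky(a),
    #                            paddle.log(paddle.linalg.det(a)))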

    def map_get_pairs(self, labels_4D, probs_4D, radius=3, is_combine=True):
        """
        Extract the vectors of point pairs within square regions.

        Args:
            labels_4D (Tensor): Labels of shape [N, C, H, W].
            probs_4D (Tensor): Probabilities of shape [N, C, H, W].
            radius (int): The side length of the square region.
            is_combine (bool): Whether to stack label and probability vectors
                into a single tensor.

        Returns:
            A tensor of shape [N, C, 2 * radius * radius, H - (radius - 1),
            W - (radius - 1)] if is_combine is True; otherwise a tuple of two
            tensors, each of shape [N, C, radius * radius, H - (radius - 1),
            W - (radius - 1)].
        """
        label_shape = labels_4D.shape
        h, w = label_shape[2], label_shape[3]
        new_h, new_w = h - (radius - 1), w - (radius - 1)

        # One shifted crop per offset (y, x) inside the square region, so each
        # output position sees its whole neighborhood.
        la_ns = []
        pr_ns = []
        for y in range(radius):
            for x in range(radius):
                la_now = labels_4D[:, :, y:y + new_h, x:x + new_w]
                pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w]
                la_ns.append(la_now)
                pr_ns.append(pr_now)

        if is_combine:
            pair_ns = la_ns + pr_ns
            p_vectors = paddle.stack(pair_ns, axis=2)
            return p_vectors
        else:
            la_vectors = paddle.stack(la_ns, axis=2)
            pr_vectors = paddle.stack(pr_ns, axis=2)
            return la_vectors, pr_vectors
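

if __name__ == '__main__':
    # A minimal smoke-test sketch (an assumption, not part of the original
    # file): run the loss on random inputs with the shapes documented above,
    # logits [N, C, H, W] and labels [N, H, W].
    paddle.seed(0)
    num_classes = 19
    logits = paddle.randn([2, num_classes, 96, 96])
    labels = paddle.randint(0, num_classes, shape=[2, 96, 96])
    loss_fn = RMILoss(num_classes=num_classes)
    print('rmi loss:', float(loss_fn(logits, labels)))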