# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager

@manager.LOSSES.add_component
class DetailAggregateLoss(nn.Layer):
    """
    DetailAggregateLoss's implementation based on PaddlePaddle.

    The original article refers to Meituan
    Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation."
    (https://arxiv.org/abs/2104.13188)

    The loss builds a multi-scale boundary target from the integer label map
    (Laplacian edge extraction at strides 1, 2 and 4, fused by a learnable
    1x1 convolution) and supervises the detail head with BCE + dice loss
    against that binary boundary target.

    Args:
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
    """

    def __init__(self, ignore_index=255):
        super(DetailAggregateLoss, self).__init__()
        self.ignore_index = ignore_index
        # Fixed 3x3 Laplacian kernel (shape (1, 1, 3, 3)) used to extract
        # boundary (edge) responses from the label map; it is a plain tensor,
        # not a parameter, so it is never trained.
        self.laplacian_kernel = paddle.to_tensor(
            [-1, -1, -1, -1, 8, -1, -1, -1, -1], dtype='float32').reshape(
                (1, 1, 3, 3))
        # Learnable 1x1 conv weight that fuses the 3 pyramid levels
        # (strides 1, 2, 4) into a single-channel boundary target.
        self.fuse_kernel = paddle.create_parameter(
            [1, 3, 1, 1], dtype='float32')

    def _boundary_at_stride(self, label_f, stride, out_size):
        """Boundary map extracted at `stride`, upsampled to `out_size`, binarized.

        Matches the reference pipeline exactly: conv -> clip(min=0) ->
        nearest upsample -> threshold at 0.1 -> cast to float32.
        """
        bt = F.conv2d(
            label_f, self.laplacian_kernel, stride=stride, padding=1)
        bt = paddle.clip(bt, min=0)
        bt = F.interpolate(bt, out_size, mode='nearest')
        return (bt > 0.1).astype('float32')

    def forward(self, logits, label):
        """
        Args:
            logits (Tensor): Logit tensor of the detail head, data type
                float32/float64. NOTE(review): the BCE target built below is
                single-channel (N, 1, H, W), so logits are presumably
                (N, 1, H', W') — confirm against the caller.
            label (Tensor): Label tensor, the data type is int64. Shape is
                (N, H, W) (unsqueezed to (N, 1, H, W) internally).
        Returns: loss (scalar Tensor): BCE + dice detail loss.
        """
        # Hoist the shared float copy of the label map; the original code
        # recomputed unsqueeze+astype once per pyramid level.
        label_f = paddle.unsqueeze(label, axis=1).astype('float32')

        # Full-resolution boundary map, binarized with the 0.1 threshold.
        boundary_targets = F.conv2d(label_f, self.laplacian_kernel, padding=1)
        boundary_targets = paddle.clip(boundary_targets, min=0)
        boundary_targets = (boundary_targets > 0.1).astype('float32')

        out_size = boundary_targets.shape[2:]
        boundary_targets_x2_up = self._boundary_at_stride(label_f, 2, out_size)
        boundary_targets_x4_up = self._boundary_at_stride(label_f, 4, out_size)
        # NOTE: the original implementation also computed a stride-8 boundary
        # map, but it was never fed into the fusion (fuse_kernel has only 3
        # input channels), so that dead computation is dropped here.

        # Stack the 3 levels -> (N, 3, 1, H, W), squeeze the singleton
        # channel -> (N, 3, H, W), then fuse to a single channel.
        boundary_targets_pyramid = paddle.stack(
            (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
            axis=1)
        boundary_targets_pyramid = paddle.squeeze(
            boundary_targets_pyramid, axis=2)
        boundary_targets_pyramid = F.conv2d(boundary_targets_pyramid,
                                            self.fuse_kernel)
        # Binarize the fused map so the BCE/dice target is hard 0/1.
        boundary_targets_pyramid = (
            boundary_targets_pyramid > 0.1).astype('float32')

        # Resize logits to the target resolution when they differ.
        if logits.shape[-1] != boundary_targets.shape[-1]:
            logits = F.interpolate(
                logits, out_size, mode='bilinear', align_corners=True)

        bce_loss = F.binary_cross_entropy_with_logits(
            logits, boundary_targets_pyramid)
        dice_loss = self.fixed_dice_loss_func(
            F.sigmoid(logits), boundary_targets_pyramid)

        # Labels never receive gradients.
        label.stop_gradient = True
        return bce_loss + dice_loss

    def fixed_dice_loss_func(self, input, target):
        """
        Simplified dice loss for DetailAggregateLoss.

        Args:
            input (Tensor): Predicted probabilities, shape (N, ...).
            target (Tensor): Binary targets with the same shape as `input`.
        Returns: mean (1 - dice coefficient) over the batch, smoothed by 1.
        """
        smooth = 1.
        n = input.shape[0]
        iflat = paddle.reshape(input, [n, -1])
        tflat = paddle.reshape(target, [n, -1])
        intersection = paddle.sum((iflat * tflat), axis=1)
        loss = 1 - (
            (2. * intersection + smooth) /
            (paddle.sum(iflat, axis=1) + paddle.sum(tflat, axis=1) + smooth))
        return paddle.mean(loss)
|