|
- # Copyright 2023 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """retinanet based resnet."""
-
- import mindspore.common.dtype as mstype
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import context, Tensor
- from mindspore.context import ParallelMode
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
- from mindspore.communication.management import get_group_size
- from mindspore import ops
- from mindspore.ops import operations as P
- from mindspore.ops import functional as F
- from mindspore.ops import composite as C
-
- from .bottleneck import FPN
-
- import mindspore.numpy as np
-
-
class FlattenConcat(nn.Cell):
    """
    Flatten per-level feature-map predictions and join them along the box axis.

    Args:
        config (dict): The default config of retinanet; provides
            ``num_retinanet_boxes``, the total anchor count across all levels.

    Returns:
        Tensor, flattened predictions of shape (batch, num_retinanet_boxes, -1).
    """

    def __init__(self, config):
        super(FlattenConcat, self).__init__()
        self.num_retinanet_boxes = config.num_retinanet_boxes
        self.concat = P.Concat(axis=1)
        self.transpose = P.Transpose()

    def construct(self, inputs):
        batch = F.shape(inputs[0])[0]
        flattened = ()
        for feat in inputs:
            # Move to NHWC before flattening so the channel values for each
            # spatial position stay contiguous in the flattened vector.
            nhwc = self.transpose(feat, (0, 2, 3, 1))
            flattened = flattened + (F.reshape(nhwc, (batch, -1)),)
        merged = self.concat(flattened)
        return F.reshape(merged, (batch, self.num_retinanet_boxes, -1))
-
-
def ClassificationModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', num_classes=80,
                        feature_size=256):
    """Build the classification sub-net: 4 conv+ReLU trunk layers followed by a
    conv producing ``num_anchors * num_classes`` channels.

    Args:
        in_channel (int): channels of the incoming FPN feature map.
        num_anchors (int): anchors per spatial position.
        kernel_size (int): conv kernel size for every layer. Default: 3.
        stride (int): conv stride for every layer. Default: 1.
        pad_mod (str): conv pad mode for every layer. Default: 'same'.
        num_classes (int): number of object classes. Default: 80.
        feature_size (int): hidden channel width of the trunk. Default: 256.

    Returns:
        nn.SequentialCell, the classification head.

    Fix: ``kernel_size``, ``stride`` and ``pad_mod`` were previously accepted
    but ignored (every conv hard-coded kernel_size=3, pad_mode='same'); they
    are now honored. The defaults reproduce the old behavior exactly.
    """
    conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv5 = nn.Conv2d(feature_size, num_anchors * num_classes, kernel_size=kernel_size, stride=stride,
                      pad_mode=pad_mod)
    return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU(), conv5])
-
-
def RegressionModel(in_channel, num_anchors, kernel_size=3, stride=1, pad_mod='same', feature_size=256):
    """Build the regression trunk: 4 conv+ReLU layers, output width ``feature_size``.

    Args:
        in_channel (int): channels of the incoming FPN feature map.
        num_anchors (int): anchors per position. NOTE(review): currently unused
            here — the final localization conv lives in ``sablhead.conv_loc``;
            kept for signature symmetry with ``ClassificationModel``.
        kernel_size (int): conv kernel size for every layer. Default: 3.
        stride (int): conv stride for every layer. Default: 1.
        pad_mod (str): conv pad mode for every layer. Default: 'same'.
        feature_size (int): hidden channel width. Default: 256.

    Returns:
        nn.SequentialCell, the regression trunk.

    Fix: ``kernel_size``, ``stride`` and ``pad_mod`` were previously accepted
    but ignored (hard-coded kernel_size=3, pad_mode='same'); they are now
    honored. The defaults reproduce the old behavior exactly.
    """
    conv1 = nn.Conv2d(in_channel, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=kernel_size, stride=stride, pad_mode=pad_mod)
    return nn.SequentialCell([conv1, nn.ReLU(), conv2, nn.ReLU(), conv3, nn.ReLU(), conv4, nn.ReLU()])
-
-
class MultiBox(nn.Cell):
    """
    Multibox conv layers. A single localization trunk and a single
    classification head are shared across all five pyramid levels.

    Args:
        config (dict): The default config of retinanet.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.
    """

    def __init__(self, config):
        super(MultiBox, self).__init__()
        out_channels = config.extras_out_channels
        num_default = config.num_default
        self.multi_loc_layer = RegressionModel(in_channel=out_channels[0], num_anchors=num_default[0])
        self.multi_cls_layer = ClassificationModel(in_channel=out_channels[0], num_anchors=num_default[0])

    def construct(self, inputs):
        loc_outputs = ()
        cls_outputs = ()
        # Apply the shared heads to each of the five pyramid levels in turn.
        for level in range(5):
            feat = inputs[level]
            loc_outputs = loc_outputs + (self.multi_loc_layer(feat),)
            cls_outputs = cls_outputs + (self.multi_cls_layer(feat),)
        return loc_outputs, cls_outputs
-
class SigmoidFocalClassificationLoss(nn.Cell):
    """
    Sigmoid focal loss for multi-class classification (Lin et al., RetinaNet).

    Args:
        gamma (float): focusing parameter that down-weights easy examples. Default: 2.0.
        alpha (float): positive/negative balancing factor. Default: 0.25.
        reduction (str): stored for API compatibility; the loss is returned
            unreduced regardless. Default: 'mean'.
    """

    def __init__(self, gamma=2.0, alpha=0.25, reduction='mean'):
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.sigmoid = ops.Sigmoid()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.binary_cross_entropy_with_logits = nn.BCEWithLogitsLoss(reduction="none")
        self.onehot = P.OneHot()

    def construct(self, pred, target):
        """Return the element-wise focal loss for logits `pred` and integer labels `target`."""
        probs = self.sigmoid(pred)
        # One-hot with an extra leading class so label 0 (background) can be
        # dropped by the slice below, leaving pred.shape[2] foreground classes.
        one_hot = self.onehot(target, pred.shape[2] + 1, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32))
        one_hot = one_hot[..., 1:]
        one_hot = ops.cast(one_hot, pred.dtype)
        # pt is the probability mass assigned to the wrong side of the label.
        pt = (1 - probs) * one_hot + probs * (1 - one_hot)
        modulator = (self.alpha * one_hot + (1 - self.alpha) * (1 - one_hot)) * ops.pow(pt, self.gamma)
        return self.binary_cross_entropy_with_logits(pred, one_hot) * modulator
-
-
class sablhead(nn.Cell):
    """
    SABL retina head: FPN features -> shared multibox heads -> per-level
    localization / bucket convs -> flattened (batch, num_boxes, -1) tensors.

    Args:
        backbone (Cell): feature-extraction backbone passed to the FPN.
        config (dict): the default config of retinanet.
        is_training (bool): when False a Sigmoid activation cell is created
            (applied by downstream inference wrappers). Default: True.

    Returns:
        Tensor, flattened localization predictions.
        Tensor, flattened classification scores.
        Tensor, flattened bucket predictions.
    """

    def __init__(self, backbone, config, is_training=True):
        super(sablhead, self).__init__()
        self.config = config
        self.fpn = FPN(backbone=backbone, config=config)
        self.multi_box = MultiBox(config)
        self.is_training = is_training
        # 28 output channels: presumably num_anchors * bucket params — TODO confirm.
        self.conv_loc = nn.Conv2d(256, 28, kernel_size=3, pad_mode='same', has_bias=True)
        self.conv_buk = nn.Conv2d(256, 28, kernel_size=3, pad_mode='same', has_bias=True)
        # Fix: build the FlattenConcat cell once here instead of constructing a
        # fresh Cell on every forward pass inside construct(), which allocated
        # per call and is not supported in MindSpore graph mode.
        self.flatten_concat = FlattenConcat(config)
        if not is_training:
            self.activation = P.Sigmoid()

    def construct(self, inputs):
        """construct"""
        features = self.fpn(inputs)
        pred_loc, pred_label = self.multi_box(features)
        pred_loc_res = ()
        pred_buk_res = ()
        # Both loc and bucket convs consume the regression trunk output.
        for i in range(len(pred_loc)):
            pred_loc_res += (self.conv_loc(pred_loc[i]),)
            pred_buk_res += (self.conv_buk(pred_loc[i]),)
        pred_loc_res = self.flatten_concat(pred_loc_res)
        pred_label_res = self.flatten_concat(pred_label)
        pred_buk_res = self.flatten_concat(pred_buk_res)
        return pred_loc_res, pred_label_res, pred_buk_res
-
-
class sablWithLossCell(nn.Cell):
    """
    Wrap the sabl network with its training loss: sigmoid focal classification
    loss + smooth-L1 localization loss + BCE bucket loss.

    Args:
        network (Cell): the sablhead network producing (loc, label, bucket) preds.
        config (dict): the default config of retinanet (currently unused here).

    Returns:
        Tensor, the combined scalar training loss.
    """

    def __init__(self, network, config):
        super(sablWithLossCell, self).__init__()
        self.network = network
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.loc_loss = nn.SmoothL1Loss(beta=1 / 9, reduction='none')
        # Fix: instantiate the loss cells once here instead of rebuilding them
        # on every forward pass inside construct() (per-call Cell construction
        # is not supported in MindSpore graph mode).
        self.focal_loss = SigmoidFocalClassificationLoss()
        self.buk_loss = nn.BCEWithLogitsLoss(reduction='none')

    def construct(self, x, gt_loc, gt_loc_weights, gt_label, gt_buk, gt_buk_weights, num_matched_boxes):
        """construct"""
        pred_loc, pred_label, pred_buk = self.network(x)
        num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))

        # Localization Loss (8.0 divisor: presumably normalizes over the
        # per-box coordinate/bucket terms — TODO confirm against target encoding)
        smooth_l1 = self.loc_loss(pred_loc, gt_loc) * gt_loc_weights
        loss_loc = self.reduce_sum(smooth_l1) / 8.0 / num_matched_boxes

        # Classification Loss; labels > -1 are valid, 80 is the class count
        # (matches ClassificationModel's num_classes default).
        mask_cls = F.cast(self.less(-1, gt_label), mstype.float32)
        mask_cls = self.tile(self.expand_dims(mask_cls, -1), (1, 1, 80))
        loss_cls = self.focal_loss(pred_label, gt_label)
        loss_cls = loss_cls * mask_cls
        loss_cls = loss_cls.sum() / num_matched_boxes

        # Bucket Loss; /4/7 matches the 28 = 4*7 bucket channels of conv_buk.
        loss_buk = self.buk_loss(pred_buk, gt_buk)
        loss_buk = loss_buk * gt_buk_weights
        loss_buk = loss_buk.sum()
        loss_buk = loss_buk / num_matched_boxes / 4 / 7

        return (loss_cls + 1.5 * loss_loc + 1.5 * loss_buk)
-
class TrainingWrapper(nn.Cell):
    """
    One-step training cell: forward pass, backward pass seeded with a constant
    sensitivity, optional distributed gradient all-reduce, optimizer update.

    Args:
        network (Cell): the loss-producing network to train.
        optimizer (Optimizer): optimizer applied to the network's parameters.
        sens (float): backpropagation sensitivity (loss scale). Default: 1.0.

    Returns:
        Tensor, the loss value of the current step.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        # A gradient reducer is only needed when running data/hybrid parallel.
        self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        loss = self.network(*args)
        # Seed the backward pass with a tensor of self.sens matching the loss.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*args, sens)
        if self.reducer_flag:
            # All-reduce gradients across devices before the update.
            grads = self.grad_reducer(grads)
        self.optimizer(grads)
        return loss
-
class SablInferWithDecoder(nn.Cell):
    """
    Inference wrapper for the sabl network.

    NOTE(review): despite the original description, no bbox decoding happens in
    this cell — ``default_boxes`` is stored but never read in construct(), and
    ``config`` is not kept at all. Presumably decoding against the stored
    default boxes is done by the caller — verify downstream usage.

    Args:
        network (Cell): the origin retinanet infer network without bbox decoder.
        default_boxes (Tensor): the default_boxes from anchor generator
            (stored for callers; unused inside this cell).
        config (dict): retinanet config (accepted but unused).

    Returns:
        Tensor, squeezed bbox regression predictions.
        Tensor, squeezed class scores after sigmoid.
        Tensor, squeezed bbox bucket (classification) predictions.
    """
    def __init__(self, network, default_boxes, config):
        super(SablInferWithDecoder, self).__init__()
        self.network = network
        self.default_boxes = default_boxes  # kept for external decoding; unused here
    def construct(self, x):
        """construct"""
        # Run the head; drop size-1 dims (assumes batch size 1 — TODO confirm).
        bbox_reg_pred, scores, bbox_cls_pred = self.network(x)
        bbox_reg_pred, scores, bbox_cls_pred = bbox_reg_pred.squeeze(), scores.squeeze(), bbox_cls_pred.squeeze()
        # Convert class logits to probabilities.
        scores = ops.Sigmoid()(scores)
        return bbox_reg_pred, scores, bbox_cls_pred
|