|
- # ------------------------------------------------------------------------------ #
- # ------------------------------------------------------------------------------ #
- # OmniPose #
- # Rochester Institute of Technology - Vision and Image Processing Lab #
- # Bruno Artacho (bmartacho@mail.rit.edu) #
- # ------------------------------------------------------------------------------ #
- # ------------------------------------------------------------------------------ #
- import math
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore.common import initializer as init
- from src.utils.init import KaimingNormal
-
class SepConv2d(nn.Cell):
    """Separable 2-D convolution: grouped spatial conv -> ReLU -> 1x1 conv.

    The spatial convolution uses one group per input channel and emits
    ``in_channels * depth_multiplier`` channels; the 1x1 convolution then
    maps those to ``out_channels``.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by the 1x1 convolution.
        kernel_size: Kernel size of the spatial convolution.
        stride: Stride of the spatial convolution.
        padding: Explicit padding of the spatial convolution (it is built
            with ``pad_mode='pad'``).
        dilation: Dilation of the spatial convolution.
        bias: Forwarded as ``has_bias`` to both convolutions.
        depth_multiplier: Spatial-conv output channels per input channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 dilation=1, bias=True, depth_multiplier=1):
        super().__init__()

        mid_channels = in_channels * depth_multiplier

        # Per-channel (grouped) spatial filtering with explicit padding.
        self.spatialConv = nn.Conv2d(
            in_channels,
            mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            pad_mode='pad',
            padding=padding,
            dilation=dilation,
            group=in_channels,
            has_bias=bias,
        )

        # 1x1 convolution that mixes the per-channel responses.
        self.pointConv = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            pad_mode='valid',
            dilation=1,
            has_bias=bias,
        )

        self.relu = nn.ReLU()

        # Re-initialise the parameters of every cell registered so far.
        for _, sub_cell in self.cells_and_names():
            if isinstance(sub_cell, nn.Conv2d):
                kaiming = KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu')
                weight = sub_cell.weight
                weight.set_data(init.initializer(kaiming, weight.shape, weight.dtype))
            elif isinstance(sub_cell, nn.BatchNorm2d):
                sub_cell.gamma.set_data(init.initializer('ones', sub_cell.gamma.shape))
                sub_cell.beta.set_data(init.initializer('zeros', sub_cell.beta.shape))
            elif isinstance(sub_cell, nn.Dense):
                sub_cell.bias.set_data(init.initializer('zeros', sub_cell.bias.shape))

    def construct(self, x):
        """Apply the spatial convolution, a ReLU, then the 1x1 convolution."""
        spatial = self.relu(self.spatialConv(x))
        return self.pointConv(spatial)
-
# Maps a convolution-type name to the cell class that implements it.
conv_dict = dict(CONV2D=nn.Conv2d, SEPARABLE=SepConv2d)
class AdaptiveAvgPool2d(nn.Cell):
    """Average the input over axes 2 and 3, keeping them as size-1 axes.

    NOTE(review): presumably used on NCHW feature maps, where this is a
    global spatial average pool with a 1x1 output -- confirm against callers.
    """

    def __init__(self):
        super().__init__()
        # keep_dims=True retains the reduced axes so the rank is unchanged.
        self.mean = ops.ReduceMean(keep_dims=True)

    def construct(self, x):
        """Return the mean of ``x`` over axes ``(2, 3)``."""
        return self.mean(x, (2, 3))
-
class _AtrousModule(nn.Cell):
    """Dilated convolution followed by a normalisation layer and ReLU.

    Args:
        conv_type: Key into ``conv_dict`` selecting the convolution class.
        inplanes: Channel count of the input.
        planes: Channel count of the output.
        kernel_size: Kernel size of the convolution.
        padding: Explicit padding passed to the convolution.
        dilation: Dilation rate of the convolution.
        BatchNorm: Callable building the normalisation cell from a channel
            count (called as ``BatchNorm(planes)``).

    Raises:
        KeyError: If ``conv_type`` is not a key of ``conv_dict``.
    """

    def __init__(self, conv_type, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_AtrousModule, self).__init__()
        conv_cls = conv_dict[conv_type]
        self.conv = conv_cls

        conv_kwargs = {
            'kernel_size': kernel_size,
            'stride': 1,
            'padding': padding,
            'dilation': dilation,
        }

        # mindspore.nn.Conv2d defaults to pad_mode='same', which raises a
        # ValueError for any non-zero ``padding``; explicit padding is only
        # accepted in 'pad' mode. SepConv2d selects 'pad' internally and does
        # not take this argument. Calls with zero padding keep the default
        # mode so their behaviour is unchanged.
        if isinstance(padding, (tuple, list)):
            has_padding = any(padding)
        else:
            has_padding = padding != 0
        if conv_cls is nn.Conv2d and has_padding:
            conv_kwargs['pad_mode'] = 'pad'

        self.atrous_conv = conv_cls(inplanes, planes, **conv_kwargs)

        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

    def construct(self, x):
        """Apply convolution, normalisation and ReLU in that order."""
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)
-
class WASPv2(nn.Cell):
    """WASPv2 head: cascaded atrous branches, a pooled branch and a decoder.

    Four atrous modules are chained (each consumes the previous one's
    output), a globally pooled branch is computed from the raw input, and
    all five results are concatenated and fused. The fused map is then
    concatenated with channel-reduced low-level features and decoded to
    ``n_classes`` channels.

    Args:
        conv_type: Key into ``conv_dict`` selecting the convolution used by
            the atrous branches.
        inplanes: Channel count of the input ``x``.
        planes: Channel count used by the branches and the decoder.
        n_classes: Channel count of the final output.
        low_level_inplanes: Channel count of ``low_level_features``.
            Defaults to 256, the value that was previously hard-coded.
    """

    def __init__(self, conv_type, inplanes, planes, n_classes=17, low_level_inplanes=256):
        super(WASPv2, self).__init__()

        # WASP
        dilations = [1, 6, 12, 18]

        reduction = planes // 8

        BatchNorm = nn.BatchNorm2d

        # Only the first branch sees the raw input; the rest are cascaded.
        self.aspp1 = _AtrousModule(conv_type, inplanes, planes, 1, padding=0,
                                   dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _AtrousModule(conv_type, planes, planes, 3, padding=dilations[1],
                                   dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _AtrousModule(conv_type, planes, planes, 3, padding=dilations[2],
                                   dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _AtrousModule(conv_type, planes, planes, 3, padding=dilations[3],
                                   dilation=dilations[3], BatchNorm=BatchNorm)

        self.relu = nn.ReLU()

        # This branch is fed the raw input in ``construct``, so its 1x1
        # convolution must accept ``inplanes`` channels (not ``planes``).
        self.global_avg_pool = nn.SequentialCell([
            AdaptiveAvgPool2d(),
            nn.Conv2d(inplanes, planes, 1, stride=1, pad_mode='valid'),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
        ])
        self.interpolate = nn.ResizeBilinear()
        self.conv1 = nn.Conv2d(5 * planes, planes, 1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.concat = ops.Concat(axis=1)

        # 1x1 convolution reducing the low-level features to ``reduction`` channels.
        self.conv2 = nn.Conv2d(low_level_inplanes, reduction, 1)
        self.bn2 = nn.BatchNorm2d(reduction)

        self.last_conv = nn.SequentialCell([
            nn.Conv2d(planes + reduction, planes, kernel_size=3, stride=1, pad_mode='pad', padding=1),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            nn.Conv2d(planes, planes, kernel_size=3, stride=1, pad_mode='pad', padding=1),
            nn.BatchNorm2d(planes),
            nn.ReLU(),
            nn.Conv2d(planes, n_classes, kernel_size=1, stride=1),
        ])

        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv weights, batch-norm scale/shift and dense biases."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(init.initializer(
                    KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
                    cell.weight.shape,
                    cell.weight.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
                cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
            elif isinstance(cell, nn.Dense):
                cell.bias.set_data(init.initializer('zeros', cell.bias.shape))

    def construct(self, x, low_level_features):
        """Fuse the branch outputs with the low-level features and decode.

        Args:
            x: Input feature map with ``inplanes`` channels on axis 1.
            low_level_features: Feature map with ``low_level_inplanes``
                channels on axis 1. It is concatenated with the fused map
                along axis 1, so its remaining dimensions must match.

        Returns:
            The output of the decoder, with ``n_classes`` channels.
        """
        # Each atrous stage consumes the previous stage's output.
        x1 = self.aspp1(x)
        x2 = self.aspp2(x1)
        x3 = self.aspp3(x2)
        x4 = self.aspp4(x3)

        # Pooled branch, resized to the spatial size of the atrous outputs.
        x5 = self.global_avg_pool(x)
        x5 = self.interpolate(x5, size=x4.shape[2:], align_corners=True)

        x = self.concat((x1, x2, x3, x4, x5))

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        x = self.concat((x, low_level_features))
        x = self.last_conv(x)

        return x
|