# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model architecture of VarGFaceNet."""

import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import functional as F
from mindspore.ops import operations as P
-
-
def weight_variable(shape):
    """Return a Xavier-uniform float32 initializer for the given shape."""
    return initializer('XavierUniform', shape, mstype.float32)
-
-
def weight_variable_0(shape):
    """Return an all-zeros float32 Tensor of the given shape."""
    return Tensor(np.zeros(shape, dtype=np.float32))
-
-
def weight_variable_1(shape):
    """Return an all-ones float32 Tensor of the given shape."""
    return Tensor(np.ones(shape, dtype=np.float32))
-
-
def weight_variable_uniform(shape):
    """Return a uniform-distribution float32 initializer for the given shape."""
    return initializer('Uniform', shape, mstype.float32)
-
-
def fc_with_initialize(input_channels, out_channels):
    """Build a fully connected (Dense) layer.

    Weights are Xavier-uniform initialized; the bias is uniform initialized.

    Args:
        input_channels (int): number of input features.
        out_channels (int): number of output features.

    Returns:
        nn.Dense: the initialized dense layer.
    """
    weight_shape = (out_channels, input_channels)
    weight = weight_variable(weight_shape)
    # Bug fix: the original wrote `(out_channels)`, which is just an int,
    # not the intended 1-element tuple (missing trailing comma).
    bias_shape = (out_channels,)
    bias = weight_variable_uniform(bias_shape)
    return nn.Dense(input_channels, out_channels, weight, bias)
-
-
def g_conv(inp, oup, k, s, pad, groups, bias=False):
    """Create a 2-D convolution with Xavier-uniform weight init.

    NOTE(review): the `groups` argument is ignored — the group count is
    forced to `inp` when inp == oup (depthwise-style) and to 1 otherwise.
    Presumably intentional; confirm against callers, which still pass
    `inp // S` here.
    """
    groups = inp if inp == oup else 1
    if isinstance(k, int):
        kh = kw = k
    else:
        kh, kw = k[0], k[1]
    weight = weight_variable((oup, inp // groups, kh, kw))
    return nn.Conv2d(inp, oup, kernel_size=k,
                     stride=s, padding=pad, weight_init=weight,
                     has_bias=bias, pad_mode="pad", group=groups)
-
-
def bn(oup):
    """Return a BatchNorm2d layer over `oup` channels with default settings."""
    return nn.BatchNorm2d(num_features=oup)
-
-
class VarGConv(nn.Cell):
    """Variable-group convolution: grouped conv + BatchNorm + optional ReLU.

    When `linear` is True the activation is skipped (a "linear" conv).
    NOTE(review): the activation attribute is named `prelu` but is a plain
    ReLU op — kept as-is for checkpoint/name compatibility.
    """

    def __init__(self, inp, oup, k, s, S, linear=False):
        super(VarGConv, self).__init__()
        self.linear = linear
        # Group count inp // S is what callers request; see g_conv for the
        # actual group selection.
        self.conv = g_conv(inp, oup, k, s, pad=k // 2, groups=inp // S, bias=False)
        self.bn = bn(oup)
        if not linear:
            self.prelu = P.ReLU()

    def construct(self, x):
        out = self.bn(self.conv(x))
        if not self.linear:
            out = self.prelu(out)
        return out
-
-
class PointConv(nn.Cell):
    """Pointwise (1x1) group convolution + BatchNorm + optional ReLU.

    Args:
        inp (int): input channels.
        oup (int): output channels.
        s (int): stride.
        S (int): group divisor; requested group count is inp // S.
        isPReLU (bool): apply the activation after BN when True.
    """

    def __init__(self, inp, oup, s, S, isPReLU):
        # Bug fix: the original called `super(PointConv).__init__()` (an
        # unbound super — a no-op) and assigned `self.isPReLU` BEFORE the
        # real `super(PointConv, self).__init__()`. Cell attributes must
        # not be set before Cell.__init__ has run.
        super(PointConv, self).__init__()
        self.isPReLU = isPReLU
        self.conv = g_conv(inp, oup, 1, s, pad=0, groups=inp // S, bias=False)
        self.bn = bn(oup)
        if isPReLU:
            # Named `prelu` but a plain ReLU, matching VarGConv.
            self.prelu = P.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if not self.isPReLU:
            output = x
        else:
            output = self.prelu(x)
        return output
-
-
class SqueezeAndExcite(nn.Cell):
    """Squeeze-and-Excitation channel attention.

    NOTE(review): the gating activation here is ReLU6 (applied after both
    linears), not the sigmoid of the original SE paper — kept as-is.
    """

    def __init__(self, inp, oup, divide=4):
        super(SqueezeAndExcite, self).__init__()
        hidden = inp // divide
        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        self.linear1 = fc_with_initialize(inp, hidden)
        self.linear2 = fc_with_initialize(hidden, oup)
        self.relu = nn.ReLU6()

    def construct(self, x):
        w = self.pool(x, (2, 3))   # global average pool -> (N, C, 1, 1)
        w = self.squeeze(w)        # -> (N, C)
        w = self.relu(self.linear1(w))
        w = self.relu(self.linear2(w))
        # Back to (N, C, 1, 1) so it broadcasts over the spatial dims.
        w = F.reshape(w, (x.shape[0], x.shape[1], 1, 1))
        return w * x
-
-
class NormalBlock(nn.Cell):
    """Stride-1 residual VarG block: two VarGConv/PointConv pairs + SE.

    Channel count is preserved (each pair expands to 2*inp then projects
    back to inp), so the identity shortcut adds directly.
    """

    def __init__(self, inp, k, s=1, S=8):
        super(NormalBlock, self).__init__()
        oup = 2 * inp
        self.vargconv1 = VarGConv(inp, oup, k, s, S)
        self.pointconv1 = PointConv(oup, inp, s, S, isPReLU=True)
        self.vargconv2 = VarGConv(inp, oup, k, s, S)
        self.pointconv2 = PointConv(oup, inp, s, S, isPReLU=False)
        self.se = SqueezeAndExcite(inp, inp)
        self.prelu = nn.ReLU()

    def construct(self, x):
        identity = x
        y = self.pointconv1(self.vargconv1(x))
        y = self.pointconv2(self.vargconv2(y))
        y = self.se(y)
        return self.prelu(identity + y)
-
-
class DownSampling(nn.Cell):
    """Down-sampling VarG block: two strided branches summed, plus a
    strided shortcut branch, joined by an element-wise add and ReLU.
    Output channels = 2 * inp.
    """

    def __init__(self, inp, k, s=2, S=8):
        super(DownSampling, self).__init__()
        oup = 2 * inp
        # Main branch A.
        self.vargconv1 = VarGConv(inp, oup, k, s, S)
        self.pointconv1 = PointConv(oup, oup, 1, S, isPReLU=True)
        # Main branch B.
        self.vargconv2 = VarGConv(inp, oup, k, s, S)
        self.pointconv2 = PointConv(oup, oup, 1, S, isPReLU=True)
        # Fusion pair applied to A + B (stride 1).
        self.vargconv3 = VarGConv(oup, 2 * oup, k, 1, S)
        self.pointconv3 = PointConv(2 * oup, oup, 1, S, isPReLU=False)
        # Strided shortcut.
        self.vargconv4 = VarGConv(inp, oup, k, s, S)
        self.pointconv4 = PointConv(oup, oup, 1, S, isPReLU=False)
        self.prelu = nn.ReLU()

    def construct(self, x):
        shortcut = self.pointconv4(self.vargconv4(x))
        branch_a = self.pointconv1(self.vargconv1(x))
        branch_b = self.pointconv2(self.vargconv2(x))
        fused = self.pointconv3(self.vargconv3(branch_a + branch_b))
        return self.prelu(shortcut + fused)
-
-
class HeadSetting(nn.Cell):
    """Head block: a stride-2 main path (two conv pairs) added to a
    stride-2 shortcut conv pair. Channel count is unchanged.
    """

    def __init__(self, inp, k, S=8):
        super(HeadSetting, self).__init__()
        # Main path: stride-2 pair then stride-1 pair.
        self.vargconv1 = VarGConv(inp, inp, k, 2, S)
        self.pointconv1 = PointConv(inp, inp, 1, S, isPReLU=True)
        self.vargconv2 = VarGConv(inp, inp, k, 1, S)
        self.pointconv2 = PointConv(inp, inp, 1, S, isPReLU=False)
        # Shortcut path: single stride-2 pair.
        self.vargconv3 = VarGConv(inp, inp, k, 2, S)
        self.pointconv3 = PointConv(inp, inp, 1, S, isPReLU=False)

    def construct(self, x):
        shortcut = self.pointconv3(self.vargconv3(x))
        main = self.pointconv1(self.vargconv1(x))
        main = self.pointconv2(self.vargconv2(main))
        return shortcut + main
-
-
class Embedding(nn.Cell):
    """Project the final feature map to an `oup`-dimensional embedding.

    NOTE(review): `S` is accepted but never used — kept for interface
    parity with the other blocks. The 7x7 conv collapses the spatial
    dims, so this presumably expects a 7x7 input feature map — confirm
    against the backbone's output resolution.
    """

    def __init__(self, inp, oup=128, S=8):
        super(Embedding, self).__init__()
        self.gconv1 = g_conv(inp, 1024, k=1, s=1, pad=0, groups=1, bias=False)
        self.bn = nn.BatchNorm2d(1024)
        self.relu = nn.ReLU6()
        self.gconv2 = g_conv(1024, 1024, 7, 1, pad=0, groups=1024 // 8, bias=False)
        self.gconv3 = g_conv(1024, 512, 1, 1, pad=0, groups=512, bias=False)
        self.fc = fc_with_initialize(input_channels=512, out_channels=oup)

    def construct(self, x):
        y = self.relu(self.bn(self.gconv1(x)))
        y = self.gconv3(self.gconv2(y))
        y = F.reshape(y, (x.shape[0], -1))  # flatten to (N, 512)
        return self.fc(y)
-
-
class VarGFaceNet(nn.Cell):
    """VarGFaceNet backbone: stem conv, head block, three down-sampling
    stages of VarG blocks, then a 128-d embedding head.

    NOTE(review): `num_classes` is accepted but never used — the network
    emits embeddings only; classification is done by a separate head.
    Kept for interface compatibility.
    """

    def __init__(self, num_classes=512):
        super(VarGFaceNet, self).__init__()
        # Stem: 3 -> 40 channels, stride 1.
        self.conv1 = nn.SequentialCell(
            g_conv(inp=3, oup=40, k=3, s=1, pad=1, groups=1, bias=False),
            nn.BatchNorm2d(40),
            nn.ReLU6(),
        )
        self.head = HeadSetting(40, 3)
        # Each stage halves the resolution and doubles the channels.
        self.stage2 = nn.SequentialCell(
            DownSampling(40, 3, 2),
            *[NormalBlock(80, 3, 1) for _ in range(2)],
        )
        self.stage3 = nn.SequentialCell(
            DownSampling(80, 3, 2),
            *[NormalBlock(160, 3, 1) for _ in range(6)],
        )
        self.stage4 = nn.SequentialCell(
            DownSampling(160, 3, 2),
            *[NormalBlock(320, 3, 1) for _ in range(3)],
        )
        self.embedding = Embedding(320, 128)

    def construct(self, x):
        x = self.conv1(x)
        x = self.head(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        return self.embedding(x)
-
-
class CosMarginProduct(nn.Cell):
    """CosFace (large-margin cosine) classification head.

    Computes s * (cos(theta) - m) for the target class and s * cos(theta)
    for all others, where cos(theta) is the cosine between the L2-normalized
    feature and class weight vectors.

    Args:
        in_features (int): embedding dimension.
        out_features (int): number of classes.
        s (float): feature scale.
        m (float): cosine margin subtracted on the target class.
        easy_margin (bool): accepted but unused — kept for interface
            compatibility (NOTE(review): confirm whether it was ever meant
            to gate the margin as in ArcFace).
    """

    def __init__(self, in_features=128, out_features=200, s=32.0, m=0.50, easy_margin=False):
        super(CosMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        shape = (out_features, in_features)
        self.weight = Parameter(initializer('XavierUniform', shape=shape, dtype=mstype.float32), name='weight')
        self.matmul = P.MatMul(transpose_b=True)
        self.one_hot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.l2_norm = P.L2Normalize(axis=1)

    def construct(self, x, label):
        # Bug fix: the original cast with `astype(np.float16)` — a numpy
        # dtype passed inside the graph-compiled construct(); Tensor.astype
        # expects a MindSpore dtype. The half-precision matmul itself
        # (presumably an Ascend performance workaround) is preserved.
        cosine = self.matmul(self.l2_norm(x).astype(mstype.float16),
                             self.l2_norm(self.weight).astype(mstype.float16))
        phi = cosine - self.m
        one_hot = self.one_hot(label, phi.shape[1], self.on_value, self.off_value)
        # Margin only on the target class; plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
-
-
class WholeNet(nn.Cell):
    """VarGFaceNet backbone plus CosFace margin head.

    In training mode (`train_phase=True`) returns class logits from the
    margin head; otherwise returns the raw 128-d embedding (the label
    argument is ignored in that case).
    """

    def __init__(self, train_phase=True, num_class=10571, num_s=32.0, num_m=0.50):
        super(WholeNet, self).__init__()
        self.train_phase = train_phase
        self.backbone = VarGFaceNet()
        self.product = CosMarginProduct(out_features=num_class, s=num_s, m=num_m)

    def construct(self, x, y):
        features = self.backbone(x)
        if self.train_phase:
            return self.product(features, y)
        return features