|
- """
- MindSpore implementation of `RepMLP`.
- Refer to RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality.
- """
-
- import numpy as np
- from collections import OrderedDict
-
- from mindspore import nn, ops, Tensor
- import mindspore.common.initializer as init
-
- from .registry import register_model
- from .utils import load_pretrained
-
-
# Public API: the base network class plus one factory per registered variant.
__all__ = [
    "RepMLPNet",
    "RepMLPNet_T224",
    "RepMLPNet_T256",
    "RepMLPNet_B224",
    "RepMLPNet_B256",
    "RepMLPNet_D256",
    "RepMLPNet_L256",
]
-
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000,
- 'first_conv': 'features.0', 'classifier': 'classifier',
- **kwargs
- }
-
# Per-variant pretrained configs; the *256 variants expect 256x256 inputs.
default_cfgs = dict(
    RepMLPNet_T224=_cfg(url=''),
    RepMLPNet_T256=_cfg(url='', input_size=(3, 256, 256)),
    RepMLPNet_B224=_cfg(url=''),
    RepMLPNet_B256=_cfg(url='', input_size=(3, 256, 256)),
    RepMLPNet_D256=_cfg(url='', input_size=(3, 256, 256)),
    RepMLPNet_L256=_cfg(url='', input_size=(3, 256, 256)),
)
-
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, group=1, momentum=0.9, has_bias=False):
    """Build a convolution followed by BatchNorm as a ``SequentialCell``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm feature count).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit padding (the conv uses ``pad_mode="pad"``).
        group: number of convolution groups. Default: 1.
        momentum: momentum of the BatchNorm moving statistics. Default: 0.9.
        has_bias: whether the convolution has a bias. Default: False.

    Returns:
        nn.SequentialCell with children named ``conv`` (nn.Conv2d) and ``bn`` (nn.BatchNorm2d).
    """
    d = OrderedDict()
    d['conv'] = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, pad_mode="pad", padding=padding, group=group,
                          has_bias=has_bias)
    # `momentum` used to be accepted but silently dropped; forward it to the BatchNorm.
    d['bn'] = nn.BatchNorm2d(num_features=out_channels, momentum=momentum).set_train()
    return nn.SequentialCell(d)
-
-
def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, group=1, has_bias=False):
    """Build ``conv_bn`` followed by ReLU as a ``SequentialCell``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: explicit convolution padding.
        group: number of convolution groups. Default: 1.
        has_bias: whether the convolution has a bias. Default: False.

    Returns:
        nn.SequentialCell with children named ``conv`` (the ``conv_bn`` cell) and ``relu``.
    """
    d = OrderedDict()
    # `has_bias` used to be ignored (hard-coded to False); pass it through.
    d['conv'] = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                        padding=padding, group=group, has_bias=has_bias)
    d['relu'] = nn.ReLU()
    return nn.SequentialCell(d)
-
-
def fuse_bn(conv_or_fc, bn):
    """Fold a BatchNorm (using its moving statistics) into the preceding conv/fc layer.

    MindSpore ``nn.BatchNorm2d`` stores its statistics in ``moving_mean`` / ``moving_variance`` and its affine
    parameters in ``gamma`` / ``beta``. When the layer has ``r`` times as many output channels as the BN has
    features (as for ``fc3``, whose ``S*h*w`` outputs are normalized by a ``BatchNorm2d(S)``), every BN
    channel is repeated ``r`` times consecutively.

    Args:
        conv_or_fc: layer with a ``weight`` parameter whose first axis indexes output channels.
        bn: BatchNorm layer applied to the layer's output.

    Returns:
        Tuple ``(weight, bias)`` of Tensors: the BN-scaled weight (same shape as ``conv_or_fc.weight``) and
        the equivalent bias with one entry per output channel.

    Raises:
        ValueError: if the number of output channels is not a multiple of the BN feature count.
    """
    weight = conv_or_fc.weight.asnumpy()
    gamma = bn.gamma.asnumpy()
    beta = bn.beta.asnumpy()
    mean = bn.moving_mean.asnumpy()
    var = bn.moving_variance.asnumpy()

    scale = gamma / np.sqrt(var + bn.eps)
    bias = beta - mean * scale

    out_channels, num_features = weight.shape[0], scale.shape[0]
    if out_channels % num_features != 0:
        raise ValueError("cannot fuse BatchNorm with {} features into a layer with {} output channels"
                         .format(num_features, out_channels))
    repeat_times = out_channels // num_features
    if repeat_times > 1:
        scale = np.repeat(scale, repeat_times)
        bias = np.repeat(bias, repeat_times)

    # Broadcast the per-output-channel scale over the remaining weight axes (works for conv and fc weights).
    scale = scale.reshape((-1,) + (1,) * (weight.ndim - 1))
    fused_weight = (weight * scale).astype(weight.dtype)
    return Tensor(fused_weight), Tensor(bias.astype(weight.dtype))
-
-
class GlobalPerceptron(nn.Cell):
    """Global Perceptron of RepMLPBlock (one of its three components).

    Global average pooling, a 1x1-conv bottleneck with ReLU, and a sigmoid produce one gate per channel,
    shaped ``(N, input_channels, 1, 1)``, which RepMLPBlock multiplies into its output.

    Args:
        input_channels: number of channels of the input feature map.
        internal_neurons: number of channels of the hidden 1x1 convolution.
    """
    def __init__(self, input_channels, internal_neurons):
        super(GlobalPerceptron, self).__init__()
        self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=(1, 1), stride=1,
                             has_bias=True)
        self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=(1, 1), stride=1,
                             has_bias=True)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.input_channels = input_channels
        self.shape = ops.Shape()
        # Global average pooling as a reduction over H and W. This equals an AvgPool2d whose kernel covers the
        # whole map, without constructing a new Cell inside construct() on every call.
        self.global_pool = ops.ReduceMean(keep_dims=True)

    def construct(self, x):
        x = self.global_pool(x, (2, 3))  # (N, C, 1, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = x.view(-1, self.input_channels, 1, 1)
        return x
-
-
class RepMLPBlock(nn.Cell):
    """Basic RepMLP block: Global Perceptron, Channel Perceptron (``fc3``) and Local Perceptron (conv branches).

    The input is cut into ``h x w`` partitions. ``fc3`` is a grouped 1x1 conv over the flattened partitions.
    In training form it is followed by ``fc3_bn`` and summed with the conv-BN branches listed in
    ``reparam_conv_k``. ``local_inject`` folds the BN and the branches into a single biased ``fc3``.

    Args:
        in_channels: number of input channels; must equal ``out_channels``.
        out_channels: number of output channels.
        h: partition height.
        w: partition width.
        reparam_conv_k: kernel sizes of the Local Perceptron branches, or None for no branches.
        globalperceptron_reduce: the Global Perceptron's hidden width is ``in_channels // globalperceptron_reduce``.
        num_sharesets: number of share sets ``S`` (the group count of ``fc3``).
        deploy: build the re-parameterized form directly (biased ``fc3``, identity ``fc3_bn``, no branches).
    """
    def __init__(self, in_channels, out_channels,
                 h, w,
                 reparam_conv_k=None,
                 globalperceptron_reduce=4,
                 num_sharesets=1,
                 deploy=False):
        super().__init__()

        self.C = in_channels
        self.O = out_channels
        self.S = num_sharesets

        self.h, self.w = h, w

        self.deploy = deploy
        self.transpose = ops.Transpose()
        self.shape = ops.Shape()
        self.reshape = ops.Reshape()

        assert in_channels == out_channels
        self.gp = GlobalPerceptron(input_channels=in_channels, internal_neurons=in_channels // globalperceptron_reduce)

        self.fc3 = nn.Conv2d(in_channels=self.h * self.w * num_sharesets, out_channels=self.h * self.w * num_sharesets,
                             kernel_size=(1, 1),
                             stride=1, padding=0, has_bias=deploy, group=num_sharesets)
        if deploy:
            self.fc3_bn = ops.Identity()
        else:
            self.fc3_bn = nn.BatchNorm2d(num_sharesets).set_train()

        self.reparam_conv_k = reparam_conv_k
        self.conv_branch_k = []
        if not deploy and reparam_conv_k is not None:
            for k in reparam_conv_k:
                conv_branch = conv_bn(num_sharesets, num_sharesets, kernel_size=k, stride=1, padding=k // 2,
                                      group=num_sharesets, momentum=0.9, has_bias=False)
                # registered under a per-kernel name so get_equivalent_fc3/local_inject can find it again
                setattr(self, 'repconv{}'.format(k), conv_branch)
                self.conv_branch_k.append(conv_branch)

    def partition(self, x, h_parts, w_parts):
        # (N, C, H, W) -> (N, h_parts, w_parts, C, h, w)
        x = x.reshape(-1, self.C, h_parts, self.h, w_parts, self.w)
        input_perm = (0, 2, 4, 1, 3, 5)
        x = self.transpose(x, input_perm)
        return x

    def partition_affine(self, x, h_parts, w_parts):
        fc_inputs = x.reshape(-1, self.S * self.h * self.w, 1, 1)
        out = self.fc3(fc_inputs)
        out = out.reshape(-1, self.S, self.h, self.w)
        out = self.fc3_bn(out)
        out = out.reshape(-1, h_parts, w_parts, self.S, self.h, self.w)
        return out

    def construct(self, inputs):
        # Global Perceptron
        global_vec = self.gp(inputs)

        origin_shape = self.shape(inputs)

        h_parts = origin_shape[2] // self.h
        w_parts = origin_shape[3] // self.w

        partitions = self.partition(inputs, h_parts, w_parts)

        # Channel Perceptron
        fc3_out = self.partition_affine(partitions, h_parts, w_parts)

        # Local Perceptron
        if self.reparam_conv_k is not None and not self.deploy:
            conv_inputs = self.reshape(partitions, (-1, self.S, self.h, self.w))
            conv_out = 0
            for k in self.conv_branch_k:
                conv_out += k(conv_inputs)
            conv_out = self.reshape(conv_out, (-1, h_parts, w_parts, self.S, self.h, self.w))
            fc3_out += conv_out

        input_perm = (0, 3, 1, 4, 2, 5)
        fc3_out = self.transpose(fc3_out, input_perm)  # N, O, h_parts, out_h, w_parts, out_w
        out = fc3_out.reshape(*origin_shape)
        out = out * global_vec
        return out

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a numpy array (accepts MindSpore Tensors/Parameters and array-likes)."""
        if hasattr(value, 'asnumpy'):
            return value.asnumpy()
        return np.asarray(value)

    def get_equivalent_fc3(self):
        """Compute the single ``fc3`` weight/bias equivalent to ``fc3 + fc3_bn + all Local Perceptron branches``.

        Returns:
            Tuple ``(weight, bias)`` of Tensors shaped like ``fc3.weight`` (``(S*h*w, h*w, 1, 1)``) and
            ``(S*h*w,)``.
        """
        fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn)
        fc_weight = self._to_numpy(fc_weight)
        fc_bias = self._to_numpy(fc_bias)
        if not self.reparam_conv_k:
            return Tensor(np.ascontiguousarray(fc_weight)), Tensor(np.ascontiguousarray(fc_bias))

        largest_k = max(self.reparam_conv_k)
        largest_branch = getattr(self, 'repconv{}'.format(largest_k))
        total_kernel, total_bias = fuse_bn(largest_branch.conv, largest_branch.bn)
        total_kernel = self._to_numpy(total_kernel).copy()
        total_bias = self._to_numpy(total_bias).copy()
        for k in self.reparam_conv_k:
            if k == largest_k:
                continue
            k_branch = getattr(self, 'repconv{}'.format(k))
            kernel, bias = fuse_bn(k_branch.conv, k_branch.bn)
            # zero-pad the smaller kernel to the largest size (centred), which leaves its response unchanged
            pad = (largest_k - k) // 2
            total_kernel += np.pad(self._to_numpy(kernel), ((0, 0), (0, 0), (pad, pad), (pad, pad)))
            total_bias += self._to_numpy(bias)

        rep_weight, rep_bias = self._convert_conv_to_fc(total_kernel, total_bias)
        final_fc3_weight = rep_weight.reshape(fc_weight.shape) + fc_weight
        final_fc3_bias = rep_bias + fc_bias
        return (Tensor(np.ascontiguousarray(final_fc3_weight.astype(fc_weight.dtype))),
                Tensor(np.ascontiguousarray(final_fc3_bias.astype(fc_weight.dtype))))

    def local_inject(self):
        """Re-parameterize this block in place (Locality Injection).

        Folds ``fc3_bn`` and every Local Perceptron branch into a new biased ``fc3``, replaces ``fc3_bn`` with
        an identity and removes the branches. Does nothing if the block is already in deploy form.
        """
        if self.deploy:
            # already re-parameterized (or built with deploy=True): no BN or branch is left to fold
            return
        # Compute the fused parameters before touching any sub-cell, so a failure leaves the block intact.
        fc3_weight, fc3_bias = self.get_equivalent_fc3()
        weight_name = self.fc3.weight.name
        self.deploy = True

        # Remove Local Perceptron
        if self.reparam_conv_k is not None:
            for k in dict.fromkeys(self.reparam_conv_k):  # de-duplicated: each name was registered once
                delattr(self, 'repconv{}'.format(k))
        self.conv_branch_k = []
        delattr(self, 'fc3')
        delattr(self, 'fc3_bn')

        # Keyword arguments: positionally, Conv2d's 5th parameter is pad_mode, not padding.
        self.fc3 = nn.Conv2d(in_channels=self.S * self.h * self.w, out_channels=self.S * self.h * self.w,
                             kernel_size=(1, 1), stride=1, padding=0, has_bias=True, group=self.S)
        self.fc3_bn = ops.Identity()
        self.fc3.weight.set_data(fc3_weight)
        self.fc3.bias.set_data(fc3_bias)
        # NOTE(review): reuse the replaced weight's name so the new parameters presumably keep the block's
        # full name prefix instead of bare defaults — confirm against MindSpore's auto-prefix behaviour.
        self.fc3.weight.name = weight_name
        if weight_name.endswith('weight'):
            self.fc3.bias.name = weight_name[:-len('weight')] + 'bias'

    def _convert_conv_to_fc(self, conv_kernel, conv_bias):
        """Express the grouped ``(S, 1, kh, kw)`` Local Perceptron conv as an ``fc3``-shaped weight.

        The conv (cross-correlation with ``kh // 2`` / ``kw // 2`` zero padding, one input channel per group)
        is applied to an identity basis: sample ``i`` is a one-hot ``h x w`` image at pixel ``i`` in every
        share set. Row ``j`` of the result is therefore the weight of output unit ``j`` on input pixel ``i``.

        Returns:
            Tuple ``(fc_k, fc_bias)`` of numpy arrays shaped ``(S*h*w, h*w)`` and ``(S*h*w,)``.
        """
        kernel = self._to_numpy(conv_kernel)
        bias = self._to_numpy(conv_bias)
        hw = self.h * self.w
        kh, kw = kernel.shape[2], kernel.shape[3]
        ph, pw = kh // 2, kw // 2

        basis = np.eye(hw, dtype=kernel.dtype).reshape(hw, 1, self.h, self.w)
        basis = np.pad(basis, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
        responses = np.zeros((hw, self.S, self.h, self.w), dtype=kernel.dtype)
        for dy in range(kh):
            for dx in range(kw):
                tap = kernel[:, 0, dy, dx].reshape(1, self.S, 1, 1)
                responses += basis[:, :, dy:dy + self.h, dx:dx + self.w] * tap

        fc_k = responses.reshape(hw, self.S * hw).T
        fc_bias = np.repeat(bias, hw)
        return fc_k, fc_bias
-
-
class FFNBlock(nn.Cell):
    """Feed-forward block: 1x1 conv-BN, activation, 1x1 conv-BN.

    Args:
        in_channels: number of input channels.
        hidden_channels: width of the hidden layer; falls back to ``in_channels`` when falsy.
        out_channels: number of output channels; falls back to ``in_channels`` when falsy.
        act_layer: activation Cell class, instantiated with no arguments. Default: nn.GELU.
    """
    def __init__(self, in_channels, hidden_channels=None, out_channels=None, act_layer=nn.GELU):
        super().__init__()
        hidden = hidden_channels if hidden_channels else in_channels
        output = out_channels if out_channels else in_channels
        self.ffn_fc1 = conv_bn(in_channels, hidden, 1, 1, 0, has_bias=False)
        self.ffn_fc2 = conv_bn(hidden, output, 1, 1, 0, has_bias=False)
        self.act = act_layer()

    def construct(self, inputs):
        return self.ffn_fc2(self.act(self.ffn_fc1(inputs)))
-
-
class RepMLPNetUnit(nn.Cell):
    """Basic unit of RepMLPNet: two pre-BN residual branches, a RepMLPBlock and then an FFNBlock.

    Args:
        channels: number of input/output channels.
        h: partition height of the RepMLPBlock.
        w: partition width of the RepMLPBlock.
        reparam_conv_k: Local Perceptron kernel sizes forwarded to the RepMLPBlock.
        globalperceptron_reduce: Global Perceptron reduction ratio forwarded to the RepMLPBlock.
        ffn_expand: hidden-width multiplier of the FFNBlock. Default: 4.
        num_sharesets: number of share sets of the RepMLPBlock. Default: 1.
        deploy: build the RepMLPBlock in re-parameterized form. Default: False.
    """
    def __init__(self, channels, h, w, reparam_conv_k, globalperceptron_reduce, ffn_expand=4,
                 num_sharesets=1, deploy=False):
        super().__init__()
        self.repmlp_block = RepMLPBlock(
            in_channels=channels,
            out_channels=channels,
            h=h,
            w=w,
            reparam_conv_k=reparam_conv_k,
            globalperceptron_reduce=globalperceptron_reduce,
            num_sharesets=num_sharesets,
            deploy=deploy,
        )
        self.ffn_block = FFNBlock(channels, channels * ffn_expand)
        self.prebn1 = nn.BatchNorm2d(channels).set_train()
        self.prebn2 = nn.BatchNorm2d(channels).set_train()

    def construct(self, x):
        # residual RepMLP branch
        x = x + self.repmlp_block(self.prebn1(x))
        # residual FFN branch
        return x + self.ffn_block(self.prebn2(x))
-
-
class RepMLPNet(nn.Cell):
    r"""RepMLPNet model class, based on
    `"RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality" <https://arxiv.org/pdf/2112.11081v2.pdf>`_

    Args:
        in_channels: number of input channels. Default: 3.
        num_class: number of classification classes. Default: 1000.
        patch_size: kernel size and stride of the patch-embedding convolution. Default: (4, 4).
        num_blocks: number of RepMLPNetUnits per stage. Default: (2, 2, 6, 2).
        channels: number of channels per stage; the embedding between stage ``i`` and ``i + 1`` maps
            ``channels[i]`` to ``channels[i + 1]``. Default: (192, 384, 768, 1536).
        hs: partition height ``h`` of the RepMLPBlocks in each stage. Default: (64, 32, 16, 8).
        ws: partition width ``w`` of the RepMLPBlocks in each stage. Default: (64, 32, 16, 8).
        sharesets_nums: number of share sets per stage. Default: (4, 8, 16, 32).
        reparam_conv_k: kernel sizes of the Local Perceptron conv branches. Default: (3,).
        globalperceptron_reduce: channel reduction ratio of the Global Perceptron (its hidden width is
            ``channels // globalperceptron_reduce``). Default: 4.
        use_checkpoint: stored as ``self.use_checkpoint``; currently not used by the model. Default: False.
        deploy: build every RepMLPBlock directly in re-parameterized (inference) form. Default: False.
    """
    def __init__(self,
                 in_channels=3, num_class=1000,
                 patch_size=(4, 4),
                 num_blocks=(2, 2, 6, 2), channels=(192, 384, 768, 1536),
                 hs=(64, 32, 16, 8), ws=(64, 32, 16, 8),
                 sharesets_nums=(4, 8, 16, 32),
                 reparam_conv_k=(3,),
                 globalperceptron_reduce=4, use_checkpoint=False,
                 deploy=False):
        super().__init__()
        num_stages = len(num_blocks)
        assert num_stages == len(channels)
        assert num_stages == len(hs)
        assert num_stages == len(ws)
        assert num_stages == len(sharesets_nums)

        self.conv_embedding = conv_bn_relu(in_channels, channels[0], kernel_size=patch_size, stride=patch_size,
                                           padding=0, has_bias=False)
        # NOTE(review): `conv2d` is never used by construct(); it is kept only so the parameter set (and thus
        # existing checkpoints) stays unchanged. Consider removing it together with a checkpoint migration.
        self.conv2d = nn.Conv2d(in_channels, channels[0], kernel_size=patch_size, stride=patch_size, padding=0)

        stages = []
        embeds = []
        for stage_idx in range(num_stages):
            stage_blocks = [RepMLPNetUnit(channels=channels[stage_idx], h=hs[stage_idx], w=ws[stage_idx],
                                          reparam_conv_k=reparam_conv_k,
                                          globalperceptron_reduce=globalperceptron_reduce, ffn_expand=4,
                                          num_sharesets=sharesets_nums[stage_idx],
                                          deploy=deploy) for _ in range(num_blocks[stage_idx])]
            stages.append(nn.CellList(stage_blocks))
            if stage_idx < num_stages - 1:
                # 2x2 stride-2 conv: halves the resolution and moves to the next stage's width
                embeds.append(
                    conv_bn_relu(in_channels=channels[stage_idx], out_channels=channels[stage_idx + 1], kernel_size=2,
                                 stride=2, padding=0))
        self.stages = nn.CellList(stages)
        self.embeds = nn.CellList(embeds)
        self.head_norm = nn.BatchNorm2d(channels[-1]).set_train()
        self.head = nn.Dense(channels[-1], num_class)

        self.use_checkpoint = use_checkpoint
        self.shape = ops.Shape()
        self.reshape = ops.Reshape()
        # Global average pooling over H and W -> (N, C); replaces an AvgPool2d Cell built on every forward.
        self.global_pool = ops.ReduceMean(keep_dims=False)
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize Conv2d and Dense weights/biases from Uniform(sqrt(1 / fan_in))."""
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                # fan_in of a grouped conv is (in_channels / group) * kh * kw
                k = cell.group / (cell.in_channels * cell.kernel_size[0] * cell.kernel_size[1])
                k = k ** 0.5

                cell.weight.set_data(
                    init.initializer(init.Uniform(k), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(
                        init.initializer(init.Uniform(k), cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.Dense):
                k = 1 / cell.in_channels
                k = k ** 0.5

                cell.weight.set_data(
                    init.initializer(init.Uniform(k), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(
                        init.initializer(init.Uniform(k), cell.bias.shape, cell.bias.dtype))

    def forward_features(self, x: Tensor) -> Tensor:
        """Run the embedding and all stages, then BN and global average pooling; returns an (N, C) tensor."""
        x = self.conv_embedding(x)

        num_stages = len(self.stages)
        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)

            if i < num_stages - 1:
                x = self.embeds[i](x)
        x = self.head_norm(x)
        return self.global_pool(x, (2, 3))

    def forward_head(self, x: Tensor) -> Tensor:
        return self.head(x)

    def construct(self, x: Tensor) -> Tensor:
        x = self.forward_features(x)
        return self.forward_head(x)

    def locality_injection(self):
        """Re-parameterize, in place, every sub-cell that provides ``local_inject`` (the RepMLPBlocks)."""
        # Collect first: local_inject replaces child cells, which must not happen while cells_and_names()
        # is still walking the cell tree.
        targets = [cell for _, cell in self.cells_and_names() if hasattr(cell, 'local_inject')]
        for cell in targets:
            cell.local_inject()
-
@register_model
def RepMLPNet_T224(pretrained: bool = False, image_size: int = 224, num_classes: int = 1000, in_channels=3, deploy=False, **kwargs):
    """Get RepMLPNet_T224 model (tiny widths, partitions sized for 224x224 inputs).
    Refer to the base class `models.RepMLPNet` for more details."""
    cfg = default_cfgs['RepMLPNet_T224']
    partition_sizes = (56, 28, 14, 7)
    model = RepMLPNet(
        in_channels=in_channels,
        num_class=num_classes,
        channels=(64, 128, 256, 512),
        hs=partition_sizes,
        ws=partition_sizes,
        num_blocks=(2, 2, 6, 2),
        reparam_conv_k=(1, 3),
        sharesets_nums=(1, 4, 16, 128),
        deploy=deploy,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model
-
@register_model
def RepMLPNet_T256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3, deploy=False, **kwargs):
    """Get RepMLPNet_T256 model (tiny widths, partitions sized for 256x256 inputs).
    Refer to the base class `models.RepMLPNet` for more details."""
    cfg = default_cfgs['RepMLPNet_T256']
    partition_sizes = (64, 32, 16, 8)
    model = RepMLPNet(
        in_channels=in_channels,
        num_class=num_classes,
        channels=(64, 128, 256, 512),
        hs=partition_sizes,
        ws=partition_sizes,
        num_blocks=(2, 2, 6, 2),
        reparam_conv_k=(1, 3),
        sharesets_nums=(1, 4, 16, 128),
        deploy=deploy,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model
-
@register_model
def RepMLPNet_B224(pretrained: bool = False, image_size: int = 224, num_classes: int = 1000, in_channels=3, deploy=False, **kwargs):
    """Get RepMLPNet_B224 model (base widths, partitions sized for 224x224 inputs).
    Refer to the base class `models.RepMLPNet` for more details."""
    cfg = default_cfgs['RepMLPNet_B224']
    partition_sizes = (56, 28, 14, 7)
    model = RepMLPNet(
        in_channels=in_channels,
        num_class=num_classes,
        channels=(96, 192, 384, 768),
        hs=partition_sizes,
        ws=partition_sizes,
        num_blocks=(2, 2, 12, 2),
        reparam_conv_k=(1, 3),
        sharesets_nums=(1, 4, 32, 128),
        deploy=deploy,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model
-
@register_model
def RepMLPNet_B256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3, deploy=False, **kwargs):
    """Get RepMLPNet_B256 model (base widths, partitions sized for 256x256 inputs).
    Refer to the base class `models.RepMLPNet` for more details."""
    cfg = default_cfgs['RepMLPNet_B256']
    partition_sizes = (64, 32, 16, 8)
    model = RepMLPNet(
        in_channels=in_channels,
        num_class=num_classes,
        channels=(96, 192, 384, 768),
        hs=partition_sizes,
        ws=partition_sizes,
        num_blocks=(2, 2, 12, 2),
        reparam_conv_k=(1, 3),
        sharesets_nums=(1, 4, 32, 128),
        deploy=deploy,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model
-
@register_model
def RepMLPNet_D256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3, deploy=False, **kwargs):
    """Get RepMLPNet_D256 model (deeper third stage, partitions sized for 256x256 inputs).
    Refer to the base class `models.RepMLPNet` for more details."""
    cfg = default_cfgs['RepMLPNet_D256']
    partition_sizes = (64, 32, 16, 8)
    model = RepMLPNet(
        in_channels=in_channels,
        num_class=num_classes,
        channels=(80, 160, 320, 640),
        hs=partition_sizes,
        ws=partition_sizes,
        num_blocks=(2, 2, 18, 2),
        reparam_conv_k=(1, 3),
        sharesets_nums=(1, 4, 16, 128),
        deploy=deploy,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model
-
@register_model
def RepMLPNet_L256(pretrained: bool = False, image_size: int = 256, num_classes: int = 1000, in_channels=3, deploy=False, **kwargs):
    """Get RepMLPNet_L256 model (large variant, partitions sized for 256x256 inputs).
    Refer to the base class `models.RepMLPNet` for more details."""
    cfg = default_cfgs['RepMLPNet_L256']
    partition_sizes = (64, 32, 16, 8)
    model = RepMLPNet(
        in_channels=in_channels,
        num_class=num_classes,
        channels=(96, 192, 384, 768),
        hs=partition_sizes,
        ws=partition_sizes,
        num_blocks=(2, 2, 18, 2),
        reparam_conv_k=(1, 3),
        sharesets_nums=(1, 4, 32, 256),
        deploy=deploy,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    model.image_size = image_size
    return model
-
-
# Smoke test: build RepMLPNet_B256, run one all-ones 256x256 image through it and print the net and logits.
if __name__ == '__main__':
    dummy_input = Tensor(np.ones([1, 3, 256, 256]).astype(np.float32))
    net = RepMLPNet_B256()
    logits = net(dummy_input)
    print(net)
    print(logits)
|