|
- """
- MindSpore implementation of Xception.
- Refer to Xception: Deep Learning with Depthwise Separable Convolutions.
- """
-
- from mindspore import nn, ops, Tensor
- import mindspore.common.initializer as init
-
- from mindcv.models.registry import register_model
- from mindcv.models.utils import load_pretrained
- from mindcv.models.layers import GlobalAvgPooling
-
- __all__ = [
- 'Xception',
- 'xception'
- ]
-
-
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000,
- 'first_conv': 'conv1', 'classifier': 'classifier',
- **kwargs
- }
-
-
- default_cfgs = {
- 'xception': _cfg(url='')
- }
-
-
- class SeparableConv2d(nn.Cell):
- '''SeparableCon2d module of Xception'''
-
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_size: int = 1,
- stride: int = 1,
- padding: int = 0):
- super().__init__()
- self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels, pad_mode='pad',
- padding=padding)
- self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, pad_mode='valid')
-
- def construct(self, x):
- x = self.conv1(x)
- x = self.pointwise(x)
- return x
-
-
- class Block(nn.Cell):
- '''Basic module of Xception'''
- def __init__(self,
- in_filters: int,
- out_filters: int,
- reps: int,
- strides: int = 1,
- start_with_relu: bool = True,
- grow_first: bool = True):
- super().__init__()
-
- if out_filters != in_filters or strides != 1:
- self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode='valid', has_bias=False)
- self.skipbn = nn.BatchNorm2d(out_filters, momentum=0.9)
- else:
- self.skip = None
-
- self.relu = nn.ReLU()
- rep = []
- filters = in_filters
- if grow_first:
- rep.append(nn.ReLU())
- rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
- rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))
- filters = out_filters
-
- for _ in range(reps - 1):
- rep.append(nn.ReLU())
- rep.append(SeparableConv2d(filters, filters, kernel_size=3, stride=1, padding=1))
- rep.append(nn.BatchNorm2d(filters, momentum=0.9))
-
- if not grow_first:
- rep.append(nn.ReLU())
- rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
- rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))
-
- if not start_with_relu:
- rep = rep[1:]
- else:
- rep[0] = nn.ReLU()
-
- if strides != 1:
- rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
- self.rep = nn.SequentialCell(*rep)
-
- def construct(self, inp):
- x = self.rep(inp)
-
- if self.skip is not None:
- skip = self.skip(inp)
- skip = self.skipbn(skip)
- else:
- skip = inp
- x = ops.add(x, skip)
- return x
-
-
- class Xception(nn.Cell):
- r"""Xception model architecture from
- `"Deep Learning with Depthwise Separable Convolutions" <https://arxiv.org/abs/1610.02357>`_.
-
- Args:
- num_classes (int) : number of classification classes. Default: 1000.
- in_channels (int): number the channels of the input. Default: 3.
- """
-
- def __init__(self,
- num_classes: int = 1000,
- in_channels: int = 3):
- super().__init__()
- self.num_classes = num_classes
- blocks = []
- self.conv1 = nn.Conv2d(in_channels, 32, 3, 2, pad_mode='valid')
- self.bn1 = nn.BatchNorm2d(32, momentum=0.9)
- self.relu = nn.ReLU()
- self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid')
- self.bn2 = nn.BatchNorm2d(64, momentum=0.9)
-
- # Entry flow
- blocks.append(Block(64, 128, 2, 2, start_with_relu=False, grow_first=True))
- blocks.append(Block(128, 256, 2, 2, start_with_relu=True, grow_first=True))
- blocks.append(Block(256, 728, 2, 2, start_with_relu=True, grow_first=True))
-
- # Middle flow
- for _ in range(8):
- blocks.append(Block(728, 728, 3, 1, start_with_relu=True, grow_first=True))
-
- # Exit flow
- blocks.append(Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False))
-
- self.blocks = nn.SequentialCell(blocks)
-
- self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
- self.bn3 = nn.BatchNorm2d(1536, momentum=0.9)
- self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
- self.bn4 = nn.BatchNorm2d(2048, momentum=0.9)
-
- self.pool = GlobalAvgPooling()
- self.dropout = nn.Dropout()
- self.classifier = nn.Dense(2048, num_classes)
-
- self._initialize_weights()
-
- def forward_features(self, x: Tensor) -> Tensor:
- """forward the backbone of Xception"""
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.conv2(x)
- x = self.bn2(x)
- x = self.relu(x)
- x = self.blocks(x)
- x = self.conv3(x)
- x = self.bn3(x)
- x = self.relu(x)
- x = self.conv4(x)
- x = self.bn4(x)
- x = self.relu(x)
- return x
-
- def forward_head(self, x: Tensor) -> Tensor:
- x = self.pool(x)
- x = self.dropout(x)
- x = self.classifier(x)
- return x
-
- def construct(self, x: Tensor) -> Tensor:
- x = self.forward_features(x)
- x = self.forward_head(x)
- return x
-
- def _initialize_weights(self) -> None:
- """Initialize weights for cells."""
- for _, cell in self.cells_and_names():
- if isinstance(cell, nn.Conv2d):
- cell.weight.set_data(
- init.initializer(init.XavierUniform(),
- cell.weight.shape, cell.weight.dtype))
- elif isinstance(cell, nn.Dense):
- cell.weight.set_data(init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype))
- if cell.bias is not None:
- cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.weight.dtype))
-
-
- @register_model
- def xception(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs) -> Xception:
- """Get Xception model.
- Refer to the base class `models.Xception` for more details."""
- default_cfg = default_cfgs['xception']
- model = Xception(num_classes=num_classes, in_channels=in_channels, **kwargs)
-
- if pretrained:
- load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
-
- return model
|