|
-
- from typing import Optional, Sequence, Tuple, Union
- import mindspore.nn as nn
-
- from monai.networks.blocks import ADN
- from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding
- from monai.networks.layers.factories import Conv
-
-
class Convolution(nn.Cell):
    """
    Constructs a convolution with optional normalization, activation and dropout
    (applied via an ``ADN`` block, in the order given by ``adn_ordering``).

    Args:
        spatial_dims: number of spatial dimensions of the input (selects the
            1D/2D/3D convolution type from the ``Conv`` factory).
        in_channels: number of input channels.
        out_channels: number of output channels.
        strides: convolution stride. Defaults to 1.
        kernel_size: convolution kernel size. Defaults to 3.
        adn_ordering: ordering of activation (A), normalization (N) and
            dropout (D) inside the ADN block. Defaults to "NDA".
        act: activation type and arguments. Defaults to PReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        dropout: dropout ratio. Defaults to no dropout.
        dropout_dim: dimension of the dropout. Defaults to 1.
        dilation: dilation rate. Defaults to 1.
        groups: controls connections between inputs and outputs. Defaults to 1.
        bias: whether the convolution has a bias term. Defaults to True.
        conv_only: whether to use the convolution only (skip the ADN block).
        is_transposed: if True, use a transposed convolution.
        padding: controls implicit zero-padding on both sides. When None,
            ``same_padding(kernel_size, dilation)`` is used so spatial size is
            preserved for stride 1.
        output_padding: additional size added to one side of the output shape
            (transposed convolution only). When None, derived from ``strides``.
        dimensions: deprecated alias for ``spatial_dims``; takes precedence when
            given, for backward compatibility.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        strides: Union[Sequence[int], int] = 1,
        kernel_size: Union[Sequence[int], int] = 3,
        adn_ordering: str = "NDA",
        act: Optional[Union[Tuple, str]] = "PRELU",
        norm: Optional[Union[Tuple, str]] = "INSTANCE",
        dropout: Optional[Union[Tuple, str, float]] = None,
        dropout_dim: Optional[int] = 1,
        dilation: Union[Sequence[int], int] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_only: bool = False,
        is_transposed: bool = False,
        padding: Optional[Union[Sequence[int], int]] = None,
        output_padding: Optional[Union[Sequence[int], int]] = None,
        dimensions: Optional[int] = None,
    ) -> None:
        super().__init__()
        # "dimensions" is the deprecated spelling of "spatial_dims"; honor it if supplied.
        self.dimensions = spatial_dims if dimensions is None else dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_transposed = is_transposed
        if padding is None:
            # Default to "same" padding so stride-1 convolutions preserve spatial size.
            padding = same_padding(kernel_size, dilation)
        conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.dimensions]

        # Annotation fixed: mindspore.nn has no "Layer"; submodules are Cells.
        conv: nn.Cell
        # Both branches shared these arguments; build them once to keep the
        # transposed / non-transposed paths from drifting apart.
        conv_kwargs = {
            "kernel_size": kernel_size,
            "stride": strides,
            "padding": padding,
            "dilation": dilation,
            "groups": groups,
            "bias_attr": bias,
        }
        if is_transposed:
            if output_padding is None:
                # Choose the output padding that makes the transposed conv
                # exactly invert the shape change of the forward conv.
                output_padding = stride_minus_kernel_padding(1, strides)
            conv_kwargs["output_padding"] = output_padding
        conv = conv_type(in_channels, out_channels, **conv_kwargs)

        # NOTE(review): "add_sublayer" and "bias_attr" are PaddlePaddle-style APIs,
        # but the base class is mindspore nn.Cell — confirm which framework the
        # Conv factory actually targets; a plain attribute assignment
        # (self.conv = conv) would be the MindSpore-idiomatic registration.
        self.add_sublayer("conv", conv)

        if not conv_only:
            self.add_sublayer(
                "adn",
                ADN(
                    ordering=adn_ordering,
                    in_channels=out_channels,
                    act=act,
                    norm=norm,
                    norm_dim=self.dimensions,
                    dropout=dropout,
                    dropout_dim=dropout_dim,
                ),
            )
|