- from collections import OrderedDict
- from functools import partial
- from typing import Callable, Optional
-
- import torch.nn as nn
- import torch
- from torch import Tensor
- import torch.utils.checkpoint as checkpoint
- import sys
- import os
-
- """
- 将上采样,拼接,写到一个模块里边
- """
-
-
- def drop_path(x, drop_prob: float = 0., training: bool = False):
- """
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
-
- This function is taken from the rwightman.
- It can be seen here:
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L140
- """
- if drop_prob == 0. or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
- random_tensor.floor_() # binarize
- output = x.div(keep_prob) * random_tensor
- return output
-
-
- class DropPath(nn.Module):
- """
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
- "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
- """
-
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
-
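- # A minimal sketch (illustration only, not part of the original model): with drop_prob = p, DropPath
- # zeroes the residual branch for roughly a fraction p of the samples in a batch and rescales the
- # survivors by 1 / (1 - p), so the expected activation is unchanged during training.
- def _drop_path_demo(drop_prob: float = 0.5) -> None:
-     dp = DropPath(drop_prob)
-     dp.train()                      # stochastic depth is only active in training mode
-     x = torch.ones(1000, 8)         # 1000 "samples" with constant activations
-     y = dp(x)
-     kept = (y.abs().sum(dim=1) > 0).float().mean().item()
-     print(f"kept fraction ~ {kept:.2f} (expected ~ {1 - drop_prob:.2f})")
-     print(f"mean activation ~ {y.mean().item():.2f} (expected ~ 1.00)")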
-
- class ConvBNAct(nn.Module):
- def __init__(self,
- in_planes: int,
- out_planes: int,
- kernel_size: int = 3,
- stride: int = 1,
- groups: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- activation_layer: Optional[Callable[..., nn.Module]] = None):
- super(ConvBNAct, self).__init__()
-
- padding = (kernel_size - 1) // 2
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- if activation_layer is None:
- activation_layer = nn.SiLU # alias Swish (torch>=1.7)
-
- self.conv = nn.Conv2d(in_channels=in_planes,
- out_channels=out_planes,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- groups=groups,
- bias=False)
-
- self.bn = norm_layer(out_planes)
- self.act = activation_layer()
-
- def forward(self, x):
- result = self.conv(x)
- result = self.bn(result)
- result = self.act(result)
-
- return result
-
-
- class SqueezeExcite(nn.Module):
- def __init__(self,
- input_c: int, # block input channel
- expand_c: int, # block expand channel
- se_ratio: float = 0.25):
- super(SqueezeExcite, self).__init__()
- squeeze_c = int(input_c * se_ratio)
- self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1)
- self.act1 = nn.SiLU() # alias Swish
- self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1)
- self.act2 = nn.Sigmoid()
-
- def forward(self, x: Tensor) -> Tensor:
- scale = x.mean((2, 3), keepdim=True)
- scale = self.conv_reduce(scale)
- scale = self.act1(scale)
- scale = self.conv_expand(scale)
- scale = self.act2(scale)
- return scale * x
-
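- # A minimal sketch (illustration only): squeeze-and-excitation pools the feature map globally, squeezes
- # it to input_c * se_ratio channels, and produces one sigmoid gate per channel to rescale the input.
- def _squeeze_excite_demo() -> None:
-     se = SqueezeExcite(input_c=32, expand_c=128, se_ratio=0.25)   # squeeze to 8 channels
-     x = torch.randn(2, 128, 16, 16)   # SE acts on the expanded features inside MBConv
-     print(se(x).shape)                # torch.Size([2, 128, 16, 16])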
-
- class MBConv(nn.Module):
- def __init__(self,
- kernel_size: int,
- input_c: int,
- out_c: int,
- expand_ratio: int,
- stride: int,
- se_ratio: float,
- drop_rate: float,
- norm_layer: Callable[..., nn.Module]):
- super(MBConv, self).__init__()
-
- if stride not in [1, 2]:
- raise ValueError("illegal stride value.")
-
- self.has_shortcut = (stride == 1 and input_c == out_c)
-
- activation_layer = nn.SiLU # alias Swish
- expanded_c = input_c * expand_ratio
-
- # In EfficientNetV2 the MBConv blocks never use expand_ratio == 1, so the point-wise expansion conv is always present.
- assert expand_ratio != 1
- # Point-wise expansion
- self.expand_conv = ConvBNAct(input_c,
- expanded_c,
- kernel_size=1,
- norm_layer=norm_layer,
- activation_layer=activation_layer)
-
- # Depth-wise convolution
- self.dwconv = ConvBNAct(expanded_c,
- expanded_c,
- kernel_size=kernel_size,
- stride=stride,
- groups=expanded_c,
- norm_layer=norm_layer,
- activation_layer=activation_layer)
-
- self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else nn.Identity()
-
- # Point-wise linear projection
- self.project_conv = ConvBNAct(expanded_c,
- out_planes=out_c,
- kernel_size=1,
- norm_layer=norm_layer,
- activation_layer=nn.Identity)  # note: no activation after the projection, so Identity is passed
-
- self.out_channels = out_c
-
- # The drop-path (stochastic depth) layer is only used when the shortcut connection is present.
- self.drop_rate = drop_rate
- if self.has_shortcut and drop_rate > 0:
- self.dropout = DropPath(drop_rate)
-
- def forward(self, x: Tensor) -> Tensor:
- result = self.expand_conv(x)
- result = self.dwconv(result)
- result = self.se(result)
- result = self.project_conv(result)
-
- if self.has_shortcut:
- if self.drop_rate > 0:
- result = self.dropout(result)
- result += x
-
- return result
-
-
- class FusedMBConv(nn.Module):
- def __init__(self,
- kernel_size: int,
- input_c: int,
- out_c: int,
- expand_ratio: int,
- stride: int,
- se_ratio: float,
- drop_rate: float,
- norm_layer: Callable[..., nn.Module]):
- super(FusedMBConv, self).__init__()
-
- assert stride in [1, 2]
- assert se_ratio == 0
-
- self.has_shortcut = stride == 1 and input_c == out_c
- self.drop_rate = drop_rate
-
- self.has_expansion = expand_ratio != 1
-
- activation_layer = nn.SiLU # alias Swish
- expanded_c = input_c * expand_ratio
-
- # The expansion conv only exists when expand_ratio != 1.
- if self.has_expansion:
- # Expansion convolution
- self.expand_conv = ConvBNAct(input_c,
- expanded_c,
- kernel_size=kernel_size,
- stride=stride,
- norm_layer=norm_layer,
- activation_layer=activation_layer)
-
- self.project_conv = ConvBNAct(expanded_c,
- out_c,
- kernel_size=1,
- norm_layer=norm_layer,
- activation_layer=nn.Identity)  # note: no activation after the projection
- else:
- # Case without expansion: a single fused conv handles both the spatial convolution and the projection.
- self.project_conv = ConvBNAct(input_c,
- out_c,
- kernel_size=kernel_size,
- stride=stride,
- norm_layer=norm_layer,
- activation_layer=activation_layer)  # note: with activation, since this fused conv is the only one in the block
-
- self.out_channels = out_c
-
- # The drop-path (stochastic depth) layer is only used when the shortcut connection is present.
- self.drop_rate = drop_rate
- if self.has_shortcut and drop_rate > 0:
- self.dropout = DropPath(drop_rate)
-
- def forward(self, x: Tensor) -> Tensor:
- if self.has_expansion:
- result = self.expand_conv(x)
- result = self.project_conv(result)
- else:
- result = self.project_conv(x)
-
- if self.has_shortcut:
- if self.drop_rate > 0:
- result = self.dropout(result)
-
- result += x
-
- return result
-
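- # A minimal sketch (illustration only): both block types keep the spatial size at stride 1 and halve it
- # at stride 2, and the residual shortcut is only active when stride == 1 and input_c == out_c. The
- # configuration values below are arbitrary examples, not taken from any EfficientNetV2 variant.
- def _mbconv_shape_demo() -> None:
-     norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)
-     mb = MBConv(kernel_size=3, input_c=48, out_c=48, expand_ratio=4, stride=1,
-                 se_ratio=0.25, drop_rate=0.1, norm_layer=norm_layer)
-     fused = FusedMBConv(kernel_size=3, input_c=24, out_c=48, expand_ratio=4, stride=2,
-                         se_ratio=0, drop_rate=0.0, norm_layer=norm_layer)
-     print(mb(torch.randn(1, 48, 64, 64)).shape)       # torch.Size([1, 48, 64, 64]), with shortcut
-     print(fused(torch.randn(1, 24, 64, 64)).shape)    # torch.Size([1, 48, 32, 32]), no shortcut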
-
- class EfficientNetV2(nn.Module):
- def __init__(self,
- model_cnf: list,
- num_classes: int = 1000,
- num_features: int = 1280,
- dropout_rate: float = 0.2,
- drop_connect_rate: float = 0.2):
- super(EfficientNetV2, self).__init__()
-
- for cnf in model_cnf:
- assert len(cnf) == 8
-
- norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)
-
- stem_filter_num = model_cnf[0][4]
-
- self.stem = ConvBNAct(3,
- stem_filter_num,
- kernel_size=3,
- stride=1,
- norm_layer=norm_layer)  # activation defaults to SiLU
-
- total_blocks = sum([i[0] for i in model_cnf])
- block_id = 0
- blocks = []
- for cnf in model_cnf:
- repeats = cnf[0]
- op = FusedMBConv if cnf[-2] == 0 else MBConv
- for i in range(repeats):
- blocks.append(op(kernel_size=cnf[1],
- input_c=cnf[4] if i == 0 else cnf[5],
- out_c=cnf[5],
- expand_ratio=cnf[3],
- stride=cnf[2] if i == 0 else 1,
- se_ratio=cnf[-1],
- drop_rate=drop_connect_rate * block_id / total_blocks,
- norm_layer=norm_layer))
- block_id += 1
- self.blocks = nn.Sequential(*blocks)
-
- # Configuration of the RepLKNet large-kernel stages built below (their use in forward() is currently commented out).
- large_kernel_sizes = [31, 29, 27, 13]
- layers = [2, 2, 18, 2]
- # channels = [24, 48, 64, 160]
- channels = [48, 64, 160, 256]
- drop_path_rate = 0.3  # maximum drop-path (stochastic depth) rate across the RepLKNet stages
- small_kernel = 3
- use_checkpoint = True
- small_kernel_merged = False
- norm_intermediate_features = False
-
- self.num_stages = len(layers)
- stages = []
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
- for stage_idx in range(self.num_stages):
- layer = RepLKNetStage(channels=channels[stage_idx], num_blocks=layers[stage_idx],
- stage_lk_size=large_kernel_sizes[stage_idx],
- drop_path=dpr[sum(layers[:stage_idx]):sum(layers[:stage_idx + 1])],
- small_kernel=small_kernel, dw_ratio=1, ffn_ratio=4,
- use_checkpoint=use_checkpoint, small_kernel_merged=small_kernel_merged,
- norm_intermediate_features=norm_intermediate_features)
- stages.append(layer)
- self.stages = nn.Sequential(*stages)
- head_input_c = model_cnf[-1][-3]
- head = OrderedDict()
-
- head.update({"project_conv": ConvBNAct(head_input_c,
- num_features,
- kernel_size=1,
- norm_layer=norm_layer)})  # activation defaults to SiLU
-
- head.update({"avgpool": nn.AdaptiveAvgPool2d(1)})
- head.update({"flatten": nn.Flatten()})
-
- if dropout_rate > 0:
- head.update({"dropout": nn.Dropout(p=dropout_rate, inplace=True)})
- head.update({"classifier": nn.Linear(num_features, num_classes)})
-
- self.head = nn.Sequential(head)
-
- # initial weights
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode="fan_out")
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.ones_(m.weight)
- nn.init.zeros_(m.bias)
- elif isinstance(m, nn.Linear):
- nn.init.normal_(m.weight, 0, 0.01)
- nn.init.zeros_(m.bias)
-
- def forward(self, x: Tensor) -> list:
-
- """
- efficientnetv2_l+Unet
- """
- x = self.stem(x)
- # (1 24 512 512)
- feat1 = self.blocks[:2](x)
- # (1 48 256 256)
- feat2 = self.blocks[2:6](feat1)
- # (1 64 128 128)
- feat3 = self.blocks[6:10](feat2)
- # (1 160 64 64)
- feat4 = self.blocks[10:25](feat3)
- # (1 256 32 32)
- feat5 = self.blocks[25:-1](feat4)
-
- # feat2 = self.stages[0](feat2)
- # feat3 = self.stages[1](feat3)
- # feat4 = self.stages[2](feat4)
- # feat5 = self.stages[3](feat5)
- return [feat1, feat2, feat3, feat4, feat5]
-
-
- def fuse_bn(conv, bn):
- kernel = conv.weight
- running_mean = bn.running_mean
- running_var = bn.running_var
- gamma = bn.weight
- beta = bn.bias
- eps = bn.eps
- std = (running_var + eps).sqrt()
- t = (gamma / std).reshape(-1, 1, 1, 1)
- return kernel * t, beta - running_mean * gamma / std
-
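- # A minimal sketch (illustration only) of what fuse_bn computes: in eval mode, a bias-free Conv2d
- # followed by BatchNorm2d equals a single Conv2d whose weight is scaled by gamma / std and whose bias
- # is beta - running_mean * gamma / std.
- def _fuse_bn_demo() -> None:
-     conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
-     bn = nn.BatchNorm2d(8)
-     bn.running_mean.uniform_(-1.0, 1.0)   # give the BN non-trivial running statistics
-     bn.running_var.uniform_(0.5, 2.0)
-     conv.eval()
-     bn.eval()
-     k, b = fuse_bn(conv, bn)
-     fused = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=True)
-     fused.weight.data, fused.bias.data = k, b
-     x = torch.randn(2, 8, 16, 16)
-     print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))   # True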
-
- def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
- if padding is None:
- padding = kernel_size // 2
- result = nn.Sequential()
- result.add_module('conv', get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
- stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False))
- result.add_module('bn', get_bn(out_channels))
- return result
-
-
- def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
- if type(kernel_size) is int:
- use_large_impl = kernel_size > 5
- else:
- assert len(kernel_size) == 2 and kernel_size[0] == kernel_size[1]
- use_large_impl = kernel_size[0] > 5
- has_large_impl = 'LARGE_KERNEL_CONV_IMPL' in os.environ
- if has_large_impl and in_channels == out_channels and out_channels == groups and use_large_impl and stride == 1 and padding == kernel_size // 2 and dilation == 1:
- sys.path.append(os.environ['LARGE_KERNEL_CONV_IMPL'])
- # Please follow the instructions https://github.com/DingXiaoH/RepLKNet-pytorch/blob/main/README.md
- # export LARGE_KERNEL_CONV_IMPL=absolute_path_to_where_you_cloned_the_example (i.e., depthwise_conv2d_implicit_gemm.py)
- # TODO more efficient PyTorch implementations of large-kernel convolutions. Pull requests are welcomed.
- # Or you may try MegEngine. We have integrated an efficient implementation into MegEngine and it will automatically use it.
- from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
- return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
- else:
- return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
- padding=padding, dilation=dilation, groups=groups, bias=bias)
-
-
- def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups, dilation=1):
- if padding is None:
- padding = kernel_size // 2
- result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
- stride=stride, padding=padding, groups=groups, dilation=dilation)
- result.add_module('nonlinear', nn.ReLU())
- return result
-
-
- class ReparamLargeKernelConv(nn.Module):
-
- def __init__(self, in_channels, out_channels, kernel_size,
- stride, groups,
- small_kernel,
- small_kernel_merged=False):
- super(ReparamLargeKernelConv, self).__init__()
- self.kernel_size = kernel_size
- self.small_kernel = small_kernel
- # We assume the conv does not change the feature map size, so padding = k//2. Otherwise, you may configure padding as you wish, and change the padding of small_conv accordingly.
- padding = kernel_size // 2
- if small_kernel_merged:
- self.lkb_reparam = get_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
- stride=stride, padding=padding, dilation=1, groups=groups, bias=True)
- else:
- self.lkb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
- stride=stride, padding=padding, dilation=1, groups=groups)
- if small_kernel is not None:
- assert small_kernel <= kernel_size, 'The kernel size for re-param cannot be larger than the large kernel!'
- self.small_conv = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=small_kernel,
- stride=stride, padding=small_kernel // 2, groups=groups, dilation=1)
-
- def forward(self, inputs):
- if hasattr(self, 'lkb_reparam'):
- out = self.lkb_reparam(inputs)
- else:
- out = self.lkb_origin(inputs)
- if hasattr(self, 'small_conv'):
- out += self.small_conv(inputs)
- return out
-
- def get_equivalent_kernel_bias(self):
- eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
- if hasattr(self, 'small_conv'):
- small_k, small_b = fuse_bn(self.small_conv.conv, self.small_conv.bn)
- eq_b += small_b
- # add to the central part
- eq_k += nn.functional.pad(small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
- return eq_k, eq_b
-
- def merge_kernel(self):
- eq_k, eq_b = self.get_equivalent_kernel_bias()
- self.lkb_reparam = get_conv2d(in_channels=self.lkb_origin.conv.in_channels,
- out_channels=self.lkb_origin.conv.out_channels,
- kernel_size=self.lkb_origin.conv.kernel_size, stride=self.lkb_origin.conv.stride,
- padding=self.lkb_origin.conv.padding, dilation=self.lkb_origin.conv.dilation,
- groups=self.lkb_origin.conv.groups, bias=True)
- self.lkb_reparam.weight.data = eq_k
- self.lkb_reparam.bias.data = eq_b
- self.__delattr__('lkb_origin')
- if hasattr(self, 'small_conv'):
- self.__delattr__('small_conv')
-
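- # A minimal sketch (illustration only) of the structural re-parameterization: in eval mode the output
- # of the training-time branches (large-kernel conv+BN plus small-kernel conv+BN) matches the single
- # merged convolution produced by merge_kernel().
- def _reparam_equivalence_demo() -> None:
-     m = ReparamLargeKernelConv(in_channels=16, out_channels=16, kernel_size=13,
-                                stride=1, groups=16, small_kernel=3, small_kernel_merged=False)
-     m.train()
-     m(torch.randn(2, 16, 32, 32))            # populate the BN running statistics
-     m.eval()
-     x = torch.randn(2, 16, 32, 32)
-     with torch.no_grad():
-         before = m(x)
-         m.merge_kernel()
-         after = m(x)
-     print(torch.allclose(before, after, atol=1e-4))   # True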
-
- def get_bn(channels):
- use_sync_bn = False
- if use_sync_bn:
- return nn.SyncBatchNorm(channels)
- else:
- return nn.BatchNorm2d(channels)
-
-
- class ConvFFN(nn.Module):
-
- def __init__(self, in_channels, internal_channels, out_channels, drop_path):
- super().__init__()
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.preffn_bn = get_bn(in_channels)
- self.pw1 = conv_bn(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, padding=0,
- groups=1)
- self.pw2 = conv_bn(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0,
- groups=1)
- self.nonlinear = nn.GELU()
-
- def forward(self, x):
- out = self.preffn_bn(x)
- out = self.pw1(out)
- out = self.nonlinear(out)
- out = self.pw2(out)
- return x + self.drop_path(out)
-
-
- class RepLKBlock(nn.Module):
-
- def __init__(self, in_channels, dw_channels, block_lk_size, small_kernel, drop_path, small_kernel_merged=False):
- super().__init__()
- self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
- self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
- self.large_kernel = ReparamLargeKernelConv(in_channels=dw_channels, out_channels=dw_channels,
- kernel_size=block_lk_size,
- stride=1, groups=dw_channels, small_kernel=small_kernel,
- small_kernel_merged=small_kernel_merged)
- self.lk_nonlinear = nn.ReLU()
- self.prelkb_bn = get_bn(in_channels)
- # Stochastic depth is disabled here; uncomment the next line to re-enable it.
- # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.drop_path = nn.Identity()
-
- def forward(self, x):
- out = self.prelkb_bn(x)
- out = self.pw1(out)
- out = self.large_kernel(out)
- out = self.lk_nonlinear(out)
- out = self.pw2(out)
- return x + self.drop_path(out)
-
-
- class RepLKNetStage(nn.Module):
-
- def __init__(self, channels, num_blocks, stage_lk_size, drop_path,
- small_kernel, dw_ratio=1, ffn_ratio=4,
- use_checkpoint=False, # train with torch.utils.checkpoint to save memory
- small_kernel_merged=False,
- norm_intermediate_features=False):
- super().__init__()
- self.use_checkpoint = use_checkpoint
- blks = []
- for i in range(num_blocks):
- block_drop_path = drop_path[i] if isinstance(drop_path, list) else drop_path
- # Assume all RepLK Blocks within a stage share the same lk_size. You may tune it on your own model.
- replk_block = RepLKBlock(in_channels=channels, dw_channels=int(channels * dw_ratio),
- block_lk_size=stage_lk_size,
- small_kernel=small_kernel, drop_path=block_drop_path,
- small_kernel_merged=small_kernel_merged)
- convffn_block = ConvFFN(in_channels=channels, internal_channels=int(channels * ffn_ratio),
- out_channels=channels,
- drop_path=block_drop_path)
- blks.append(replk_block)
- blks.append(convffn_block)
- self.blocks = nn.ModuleList(blks)
- if norm_intermediate_features:
- self.norm = get_bn(channels) # Only use this with RepLKNet-XL on downstream tasks
- else:
- self.norm = nn.Identity()
-
- def forward(self, x):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x) # Save training memory
- else:
- x = blk(x)
- return x
-
-
- class RepLKNet(nn.Module):
-
- def __init__(self, large_kernel_sizes, layers, channels, drop_path_rate, small_kernel,
- dw_ratio=1, ffn_ratio=4, in_channels=3, num_classes=1000, out_indices=None,
- use_checkpoint=False,
- small_kernel_merged=False,
- use_sync_bn=True,
- norm_intermediate_features=False
- # for RepLKNet-XL on COCO and ADE20K, use an extra BN to normalize the intermediate feature maps then feed them into the heads
- ):
- super().__init__()
-
-
- def efficientnetv2_s(num_classes: int = 1000):
- """
- EfficientNetV2
- https://arxiv.org/abs/2104.00298
- """
- # train_size: 300, eval_size: 384
-
- # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
- model_config = [[2, 3, 1, 1, 24, 24, 0, 0],
- [4, 3, 2, 4, 24, 48, 0, 0],
- [4, 3, 2, 4, 48, 64, 0, 0],
- [6, 3, 2, 4, 64, 128, 1, 0.25],
- [9, 3, 1, 6, 128, 160, 1, 0.25],
- [15, 3, 2, 6, 160, 256, 1, 0.25]]
-
- model = EfficientNetV2(model_cnf=model_config,
- num_classes=num_classes,
- dropout_rate=0.2)
- return model
-
-
- def efficientnetv2_m(num_classes: int = 1000):
- """
- EfficientNetV2
- https://arxiv.org/abs/2104.00298
- """
- # train_size: 384, eval_size: 480
-
- # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
- model_config = [[3, 3, 1, 1, 24, 24, 0, 0],
- [5, 3, 2, 4, 24, 48, 0, 0],
- [5, 3, 2, 4, 48, 80, 0, 0],
- [7, 3, 2, 4, 80, 160, 1, 0.25],
- [14, 3, 1, 6, 160, 176, 1, 0.25],
- [18, 3, 2, 6, 176, 304, 1, 0.25],
- [5, 3, 1, 6, 304, 512, 1, 0.25]]
-
- model = EfficientNetV2(model_cnf=model_config,
- num_classes=num_classes,
- dropout_rate=0.3)
- return model
-
-
- def efficientnetv2_l(num_classes: int = 1000):
- """
- EfficientNetV2
- https://arxiv.org/abs/2104.00298
- """
- # train_size: 384, eval_size: 480
-
- # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
- model_config = [[4, 3, 1, 1, 32, 32, 0, 0],
- [7, 3, 2, 4, 32, 64, 0, 0],
- [7, 3, 2, 4, 64, 96, 0, 0],
- [10, 3, 2, 4, 96, 192, 1, 0.25],
- [19, 3, 1, 6, 192, 224, 1, 0.25],
- [25, 3, 2, 6, 224, 384, 1, 0.25],
- [7, 3, 1, 6, 384, 640, 1, 0.25]]
-
- model = EfficientNetV2(model_cnf=model_config,
- num_classes=num_classes,
- dropout_rate=0.4)
- return model
-
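- # A minimal sketch (illustration only): the backbone returns a five-level feature pyramid that the
- # change-detection decoder below consumes. Shapes shown for the efficientnetv2_s configuration and a
- # 3x512x512 input.
- def _backbone_features_demo() -> None:
-     backbone = efficientnetv2_s(num_classes=2).eval()
-     with torch.no_grad():
-         feats = backbone(torch.randn(1, 3, 512, 512))
-     for name, f in zip(["feat1", "feat2", "feat3", "feat4", "feat5"], feats):
-         print(name, tuple(f.shape))
-     # feat1 (1, 24, 512, 512)
-     # feat2 (1, 48, 256, 256)
-     # feat3 (1, 64, 128, 128)
-     # feat4 (1, 160, 64, 64)
-     # feat5 (1, 256, 32, 32)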
-
- class Convdown(nn.Module):
- def __init__(self, in_size, out_size):
- super(Convdown, self).__init__()
- self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
- # self.bn = nn.BatchNorm2d
- self.SiLU = nn.SiLU()
-
- def forward(self, x):
- x = self.conv1(x)
- # x = self.bn(x)
- x = self.SiLU(x)
- return x
-
-
- class unetUp(nn.Module):
- def __init__(self, in_size, out_size):
- super(unetUp, self).__init__()
- self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
- self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
- self.up = nn.UpsamplingBilinear2d(scale_factor=2)
- self.relu = nn.SiLU()
-
- def forward(self, inputs1, inputs2):
- outputs = torch.cat([inputs1, self.up(inputs2)], 1)
- outputs = self.conv1(outputs)
- outputs = self.relu(outputs)
- outputs = self.conv2(outputs)
- outputs = self.relu(outputs)
- return outputs
-
-
- class unetUp2(nn.Module):
- def __init__(self, in_size, out_size):
- super(unetUp2, self).__init__()
- self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
- self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
- self.up = nn.UpsamplingBilinear2d(scale_factor=2)
- self.relu = nn.SiLU()
-
- def forward(self, inputs1, inputs2):
- outputs = torch.cat([inputs1, self.up(inputs2)], 1)
- outputs = self.conv1(outputs)
- outputs = self.relu(outputs)
- outputs = self.conv2(outputs)
- outputs = self.relu(outputs)
- return outputs
-
-
- class AttentionModule(nn.Module):
- def __init__(self, dim):
- super().__init__()
- self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
- # self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
- self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=3, groups=dim)
- self.conv1 = nn.Conv2d(dim, dim, 1)
-
- def forward(self, x):
- u = x.clone()
- attn = self.conv0(x)
- attn = self.conv_spatial(attn)
- attn = self.conv1(attn)
-
- return u * attn
-
-
- class SpatialAttention(nn.Module):
- def __init__(self, d_model):
- super().__init__()
-
- self.proj_1 = nn.Conv2d(d_model, d_model, 1)
- self.activation = nn.GELU()
- self.spatial_gating_unit = AttentionModule(d_model)
- self.proj_2 = nn.Conv2d(d_model, d_model, 1)
-
- def forward(self, x):
- shortcut = x.clone()
- x = self.proj_1(x)
- x = self.activation(x)
- x = self.spatial_gating_unit(x)
- x = self.proj_2(x)
- x = x + shortcut
- return x
-
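- # A minimal sketch (illustration only): SpatialAttention is a shape-preserving, depth-wise-convolutional
- # gating module; it is defined here but not wired into the decoder below.
- def _spatial_attention_demo() -> None:
-     sa = SpatialAttention(d_model=64)
-     x = torch.randn(1, 64, 128, 128)
-     print(sa(x).shape)   # torch.Size([1, 64, 128, 128])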
-
- class EfficientNet_cd(nn.Module):
- def __init__(self, num_classes=21, pretrained=False, backbone='efficient'):
- super(EfficientNet_cd, self).__init__()
- self.efficient = efficientnetv2_s()
- self.up_concat4 = unetUp(416, 160)
-
- self.up = nn.UpsamplingBilinear2d(scale_factor=2)
- self.down1 = Convdown(160, 64)
- self.down2 = Convdown(128, 48)
- self.down3 = Convdown(96, 24)
- self.final = nn.Conv2d(24, num_classes, 1)
- self.up2 = unetUp2(128, 48)
- self.up3 = unetUp2(96, 24)
- self.backbone = backbone
-
- # RepLKNet
-
- def forward(self, inputs):
- inputs1 = inputs[:, 0:3, :, :]
- inputs2 = inputs[:, 3:, :, :]
- if self.backbone == "efficient":
- [feat1, feat2, feat3, feat4, feat5] = self.efficient(inputs1)
- [feat21, feat22, feat32, feat42, feat52] = self.efficient(inputs2)
-
- up4 = self.up_concat4(feat4, feat5)
-
- up42 = self.up_concat4(feat42, feat52)
-
- dis1 = torch.abs(up42 - up4)
- dis1 = self.down1(dis1)
-
- dis2 = torch.abs(feat32 - feat3)
- final1 = self.up2(dis2, dis1)
-
- dis3 = torch.abs(feat22 - feat2)
-
- """
- 将dis1,dis2,dis3进行拼接
- """
- final = self.up3(dis3, final1)
- final = self.up(final)
-
- final = self.final(final)
- """
- 将dis1,dis2,dis3 缩放到同一维度后进行相加
- """
- return final
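-
-
- # A minimal usage sketch (illustration only): the change-detection input is the channel-wise
- # concatenation of two co-registered 3-channel images, and the output is a per-pixel class map at the
- # input resolution.
- if __name__ == "__main__":
-     model = EfficientNet_cd(num_classes=2).eval()
-     pair = torch.randn(1, 6, 512, 512)    # two RGB images stacked along the channel axis
-     with torch.no_grad():
-         logits = model(pair)
-     print(logits.shape)                   # torch.Size([1, 2, 512, 512])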