import math
import warnings
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, UpsamplingBilinear2d
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from mmcv.cnn import ConvModule, constant_init, kaiming_init

warnings.filterwarnings('ignore')


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Spatial-reduction attention: downsample K/V by sr_ratio with a strided conv.
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
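

# A minimal shape sketch (illustrative, not part of the original model): with
# sr_ratio > 1 the key/value sequence is downsampled by the strided conv, so
# attention at high resolution stays affordable while queries keep full length.
def _demo_attention_shapes():
    attn = Attention(dim=64, num_heads=1, sr_ratio=8)
    x = torch.randn(2, 56 * 56, 64)  # (B, N, C) tokens for a 224x224 image at stride 4
    out = attn(x, 56, 56)            # K/V are reduced to 7*7 tokens internally
    assert out.shape == (2, 56 * 56, 64)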


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # Stochastic depth: identity when drop_path == 0, otherwise randomly drops
        # the residual branch during training (regularization).
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding with overlapping patches. """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W
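

# Shape sketch (illustrative): a 7x7 conv with stride 4 and padding 3 maps a
# 224x224 image to a 56x56 grid of overlapping patch tokens.
def _demo_patch_embed_shapes():
    embed = OverlapPatchEmbed(img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=64)
    tokens, H, W = embed(torch.randn(1, 3, 224, 224))
    assert (H, W) == (56, 56) and tokens.shape == (1, 56 * 56, 64)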


class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # patch embeddings for the four stages (strides 4, 8, 16, 32)
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        return self.forward_features(x)


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)  # token sequence -> feature map
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)  # feature map -> token sequence

        return x
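

# Round-trip sketch (illustrative): DWConv folds the token sequence back into a
# (B, C, H, W) map so a depth-wise 3x3 conv can inject positional information.
def _demo_dwconv_shapes():
    dw = DWConv(dim=64)
    x = torch.randn(2, 28 * 28, 64)
    assert dw(x, 28, 28).shape == x.shape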


# MiT (MixVisionTransformer) variants, matching the SegFormer B0-B5 configurations.
class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
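

# Shape sketch (illustrative): each variant returns a 4-level feature pyramid at
# strides 4, 8, 16 and 32 of the input resolution.
def _demo_backbone_pyramid():
    model = mit_b0()
    outs = model(torch.randn(1, 3, 224, 224))
    assert [o.shape[2] for o in outs] == [56, 28, 14, 7]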


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class MLP(nn.Module):
    """
    Linear Embedding
    """

    def __init__(self, input_dim=512, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class conv(nn.Module):
    """
    Convolutional embedding: two 3x3 conv + ReLU layers, then flatten to tokens.
    """

    def __init__(self, input_dim=512, embed_dim=768, k_s=3):
        super().__init__()

        self.proj = nn.Sequential(nn.Conv2d(input_dim, embed_dim, 3, padding=1, bias=False), nn.ReLU(),
                                  nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=False), nn.ReLU())

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()

        self.conv2 = nn.Conv2d(in_planes, out_planes,
                               kernel_size=kernel_size, stride=stride,
                               padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv2(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


def Upsample(x, size, align_corners=False):
    """
    Wrapper around the bilinear interpolate call.
    """
    return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=align_corners)


def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


class ContextBlock(nn.Module):
    """GCNet-style global context block."""

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='avg',
                 fusion_types=('channel_mul', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out + out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out
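

# Usage sketch (illustrative): the block returns a tensor of the input's shape,
# modulated channel-wise by a pooled global context vector.
def _demo_context_block():
    gc = ContextBlock(inplanes=32, ratio=2)
    x = torch.randn(2, 32, 16, 16)
    assert gc(x).shape == x.shape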


class ConvBranch(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 1, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 1, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.SiLU(inplace=True)
        )
        self.conv6 = nn.Sequential(
            nn.Conv2d(hidden_features, hidden_features, 3, padding=1, groups=hidden_features, bias=False),
            nn.BatchNorm2d(hidden_features),
            nn.ReLU(inplace=True)
        )
        self.conv7 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, bias=False),
            nn.ReLU(inplace=True)
        )
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        res1 = x
        res2 = x
        x = self.conv1(x)
        x = x + self.conv2(x)
        x = self.conv3(x)
        x = x + self.conv4(x)
        x = self.conv5(x)
        x = x + self.conv6(x)
        x = self.conv7(x)
        x_mask = self.sigmoid_spatial(x)
        res1 = res1 * x_mask
        return res2 + res1


class GLSA(nn.Module):
    """Global-to-Local Spatial Aggregation: a local conv branch on one half of
    the channels, a global context branch on the other, fused by a 1x1 conv."""

    def __init__(self, input_dim=512, embed_dim=32, k_s=3):
        super().__init__()

        self.conv1_1 = BasicConv2d(embed_dim * 2, embed_dim, 1)
        self.conv1_1_1 = BasicConv2d(input_dim // 2, embed_dim, 1)
        self.local_11conv = nn.Conv2d(input_dim // 2, embed_dim, 1)
        self.global_11conv = nn.Conv2d(input_dim // 2, embed_dim, 1)
        self.GlobalBlock = ContextBlock(inplanes=embed_dim, ratio=2)
        self.local = ConvBranch(in_features=embed_dim, hidden_features=embed_dim, out_features=embed_dim)

    def forward(self, x):
        x_0, x_1 = x.chunk(2, dim=1)

        # local branch
        local_feat = self.local(self.local_11conv(x_0))

        # global branch
        global_feat = self.GlobalBlock(self.global_11conv(x_1))

        # concatenate global + local and fuse back to embed_dim channels
        x = torch.cat([local_feat, global_feat], dim=1)
        x = self.conv1_1(x)

        return x
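

# Usage sketch (illustrative): GLSA splits the channels in half, runs a local
# conv branch on one half and a global context branch on the other, then fuses.
def _demo_glsa():
    glsa = GLSA(input_dim=512, embed_dim=32)
    x = torch.randn(1, 512, 11, 11)
    assert glsa(x).shape == (1, 32, 11, 11)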


class SBA(nn.Module):

    def __init__(self, input_dim=256):
        super().__init__()

        self.input_dim = input_dim

        self.d_in1 = BasicConv2d(input_dim // 2, input_dim // 2, 1)
        self.d_in2 = BasicConv2d(input_dim // 2, input_dim // 2, 1)

        self.conv1 = nn.Sequential(BasicConv2d(input_dim, input_dim, 3, 1, 1),
                                   nn.Conv2d(input_dim, 16, kernel_size=1, bias=False))
        self.fc1 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=1, bias=False)

        self.Sigmoid = nn.Sigmoid()

    def forward(self, H_feature, L_feature):
        L_feature = self.fc1(L_feature)
        H_feature = self.fc2(H_feature)

        g_L_feature = self.Sigmoid(L_feature)
        g_H_feature = self.Sigmoid(H_feature)

        L_feature = self.d_in1(L_feature)
        H_feature = self.d_in2(H_feature)

        # gated cross-injection between the two levels
        L_feature = L_feature + L_feature * g_L_feature + (1 - g_L_feature) * Upsample(g_H_feature * H_feature, size=L_feature.size()[2:], align_corners=False)
        H_feature = H_feature + H_feature * g_H_feature + (1 - g_H_feature) * Upsample(g_L_feature * L_feature, size=H_feature.size()[2:], align_corners=False)

        H_feature = Upsample(H_feature, size=L_feature.size()[2:])
        out = self.conv1(torch.cat([H_feature, L_feature], dim=1))
        return out
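

# Usage sketch (illustrative): SBA gates and cross-injects a high-level (coarse)
# and a low-level (fine) feature map, then fuses them at the finer resolution.
def _demo_sba():
    sba = SBA(input_dim=64)
    high = torch.randn(1, 64, 11, 11)  # coarse, semantically rich
    low = torch.randn(1, 64, 44, 44)   # fine, boundary rich
    assert sba(high, low).shape == (1, 16, 44, 44)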


class Decoder(Module):
    """
    Decode head combining a DuAT-style GLSA/SBA branch with the all-MLP decoder
    from "SegFormer: Simple and Efficient Design for Semantic Segmentation with
    Transformers".
    """

    def __init__(self, dims, dim, class_num=2):
        super(Decoder, self).__init__()
        self.num_classes = class_num

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = dims[0], dims[1], dims[2], dims[3]
        embedding_dim = dim

        # GLSA/SBA branch
        self.GLSA_c4 = GLSA(input_dim=c4_in_channels, embed_dim=dim)
        self.GLSA_c3 = GLSA(input_dim=c3_in_channels, embed_dim=dim)
        self.GLSA_c2 = GLSA(input_dim=c2_in_channels, embed_dim=dim)
        self.L_feature = BasicConv2d(c1_in_channels, dim, 3, 1, 1)
        self.SBA = SBA(input_dim=dim)
        self.fuse = BasicConv2d(dim * 2, dim, 1)
        self.fuse2 = nn.Sequential(BasicConv2d(dim * 3, dim, 1, 1), nn.Conv2d(dim, 1, kernel_size=1, bias=False))

        # SegFormer-style MLP decoder: flatten each stage's feature map to tokens and project
        self.linear_c4 = conv(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = conv(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = conv(input_dim=c2_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = ConvModule(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse34 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse2 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuse1 = ConvModule(in_channels=embedding_dim * 2, out_channels=embedding_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))
        self.linear_fuses = ConvModule(in_channels=2, out_channels=1, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True))

        self.linear_pred = Conv2d(embedding_dim, self.num_classes, kernel_size=1)  # 1x1 conv in place of a linear classifier
        self.dropout = nn.Dropout(0.1)

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs

        # GLSA/SBA branch
        n, _, h, w = c4.shape
        _c4d = self.GLSA_c4(c4)
        _c4d = Upsample(_c4d, c3.size()[2:])  # upsample c4 features to c3's spatial size
        _c3d = self.GLSA_c3(c3)

        L_feature = self.L_feature(c1)  # low-level (boundary-rich) features from c1

        H_feature = self.fuse(torch.cat([_c4d, _c3d], dim=1))
        H_feature = Upsample(H_feature, c2.size()[2:])

        output2 = self.SBA(H_feature, L_feature)
        output2 = F.interpolate(output2, scale_factor=4, mode='bilinear')

        # SegFormer-style MLP decoder on c2-c4
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4_1 = resize(_c4, size=c3.size()[2:], mode='bilinear', align_corners=False)
        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])

        L34 = self.linear_fuse34(torch.cat([_c4_1, _c3], dim=1))
        L34_1 = resize(L34, size=c2.size()[2:], mode='bilinear', align_corners=False)
        L2 = self.linear_fuse2(torch.cat([L34_1, _c2], dim=1))
        L2_1 = resize(L2, size=c1.size()[2:], mode='bilinear', align_corners=False)
        # fold the 16-channel SBA map (at 4x the c1 resolution) back to c1 size,
        # giving 16 * 4 * 4 channels so it can be concatenated with L2_1
        output2 = output2.reshape(n, -1, L2_1.shape[2], L2_1.shape[3])

        _c = self.linear_fuse1(torch.cat([L2_1, output2], dim=1))

        x = self.dropout(_c)
        x = self.linear_pred(x)
        x = self.linear_fuses(x)

        return x


class mit_PLD_b2(nn.Module):
    def __init__(self, class_num=2, **kwargs):
        super(mit_PLD_b2, self).__init__()
        self.class_num = class_num
        self.backbone = mit_b2()
        self.decode_head = Decoder(dims=[64, 128, 320, 512], dim=256, class_num=class_num)
        self._init_weights()  # load pretrained backbone weights

    def forward(self, x):
        features = self.backbone(x)
        features = self.decode_head(features)
        up = UpsamplingBilinear2d(scale_factor=4)
        features = up(features)
        return features

    def _init_weights(self):
        pretrained_dict = torch.load('/dataset/mit_pretrained/mit_pretrained/mit_b2.pth')
        model_dict = self.backbone.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.backbone.load_state_dict(model_dict)
        print("Pretrained mit_b2 backbone weights loaded.")
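

# Smoke-test sketch (illustrative). mit_PLD_b2._init_weights loads backbone
# weights from a hard-coded checkpoint path, so this wires the backbone and
# decode head together directly to check shapes without any pretrained file.
if __name__ == '__main__':
    backbone = mit_b2()
    decode_head = Decoder(dims=[64, 128, 320, 512], dim=256, class_num=2)
    feats = backbone(torch.randn(1, 3, 352, 352))
    out = decode_head(feats)
    print([tuple(f.shape) for f in feats], tuple(out.shape))  # 1-channel map at 1/4 input resolution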