'''
# @time: 2023/3/1 9:45
# Author: Tuan
# @File: u_r_m_l.py
'''
import torch
import torch.nn as nn
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Self_Attn(nn.Module):
    """Self-attention layer."""

    def __init__(self, in_dim):
        super(Self_Attn, self).__init__()
        self.channel_in = in_dim

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B * C * W * H)
        returns :
            out : self-attention value + input feature
            attention : B * N * N (N is Width * Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B*N*C
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B*C*N
        energy = torch.bmm(proj_query, proj_key)  # batched matmul, B*N*N
        attention = self.softmax(energy)  # B*N*N
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B*C*N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # B*C*N
        out = out.view(m_batchsize, C, width, height)  # B*C*W*H

        out = self.gamma * out + x
        return out
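
# A minimal shape-check sketch (not part of the original file): since the
# layer returns gamma * attention(x) + x, the output shape must equal the
# input shape. The sizes used here are arbitrary illustrative assumptions.
def _demo_self_attn():
    layer = Self_Attn(64).eval()
    x = torch.rand(2, 64, 16, 16)
    with torch.no_grad():
        out = layer(x)
    assert out.shape == x.shape  # (2, 64, 16, 16)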

class Lin_conv(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Lin_conv, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, in_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_channels)
        )
        self.final_conv = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        identity = x
        out = self.layer1(x)
        out = F.relu(out + identity, inplace=True)
        out = self.final_conv(out)

        return out

class ResNet101(nn.Module):
    def __init__(self, classes_num):  # number of output classes
        super(ResNet101, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # --------------------------------------------------------------------
        self.layer1_first = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256)
        )
        self.layer1_next = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256)
        )
        # --------------------------------------------------------------------
        self.layer2_first = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512)
        )
        self.layer2_next = nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512)
        )
        # --------------------------------------------------------------------
        self.layer3_first = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024)
        )
        self.layer3_next = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024)
        )
        # --------------------------------------------------------------------
        self.layer4_first = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 2048, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048)
        )
        self.layer4_next = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 2048, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048)
        )
        # --------------------------------------------------------------------
        # Register the shortcut projections once, so their weights are trained
        # and move with the model across devices (mirroring the
        # layerX_shortcut pattern used by UNet below). Building them inside
        # forward() would re-initialize them on every call.
        self.layer1_shortcut = DownSample(64, 256, 1)
        self.layer2_shortcut = DownSample(256, 512, 2)
        self.layer3_shortcut = DownSample(512, 1024, 2)
        self.layer4_shortcut = DownSample(1024, 2048, 2)
        # --------------------------------------------------------------------
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(2048 * 1 * 1, 1000),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1000, classes_num)
        )

    def forward(self, x):
        out = self.pre(x)
        # --------------------------------------------------------------------
        layer1_identity = self.layer1_shortcut(out)
        out = self.layer1_first(out)
        out = F.relu(out + layer1_identity, inplace=True)

        for i in range(2):
            identity = out
            out = self.layer1_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        layer2_identity = self.layer2_shortcut(out)
        out = self.layer2_first(out)
        out = F.relu(out + layer2_identity, inplace=True)

        for i in range(3):
            identity = out
            out = self.layer2_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        layer3_identity = self.layer3_shortcut(out)
        out = self.layer3_first(out)
        out = F.relu(out + layer3_identity, inplace=True)

        for i in range(22):
            identity = out
            out = self.layer3_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        layer4_identity = self.layer4_shortcut(out)
        out = self.layer4_first(out)
        out = F.relu(out + layer4_identity, inplace=True)

        for i in range(2):
            identity = out
            out = self.layer4_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        out = self.avg_pool(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)

        return out
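
# A minimal smoke-test sketch (not part of the original file); the input size
# is an illustrative assumption. With the shortcuts registered in __init__,
# the model runs on CPU or GPU without any hard-coded device.
def _demo_resnet101():
    model = ResNet101(classes_num=7).eval()
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)
    assert logits.shape == (1, 7)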

class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)

class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        # Skip the pooling branch for batch size 1: AdaptiveAvgPool2d(1) leaves
        # a single value per channel, which BatchNorm2d rejects in training
        # mode. Passing x through unchanged only works because every ASPP
        # instance below uses in_channels == out_channels.
        if x.shape[0] != 1:
            for mod in self:
                x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
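
# A minimal shape sketch (not part of the original file): the five parallel
# branches each emit out_channels maps, their concatenation is projected back
# to out_channels. The sizes below are illustrative assumptions.
def _demo_aspp():
    aspp = ASPP(256, [6, 12, 18], 256).eval()
    x = torch.rand(2, 256, 32, 32)
    with torch.no_grad():
        out = aspp(x)
    assert out.shape == (2, 256, 32, 32)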

class DownSample(nn.Module):
    def __init__(self, in_channel, out_channel, stride):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.down(x)
        return out

class SelfAttentionBlock(nn.Module):
    def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels, share_key_query,
                 query_downsample, key_downsample, key_query_num_convs, value_out_num_convs, key_query_norm,
                 value_out_norm, matmul_norm, with_out_project, norm_cfg=None, act_cfg=None):
        super(SelfAttentionBlock, self).__init__()
        # key project
        self.key_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels,
            num_convs=key_query_num_convs,
            use_norm=key_query_norm,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        # query project
        if share_key_query:
            assert key_in_channels == query_in_channels
            self.query_project = self.key_project
        else:
            self.query_project = self.buildproject(
                in_channels=query_in_channels,
                out_channels=transform_channels,
                num_convs=key_query_num_convs,
                use_norm=key_query_norm,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )
        # value project
        self.value_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels if with_out_project else out_channels,
            num_convs=value_out_num_convs,
            use_norm=value_out_norm,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        # out project
        self.out_project = None
        if with_out_project:
            self.out_project = self.buildproject(
                in_channels=transform_channels,
                out_channels=out_channels,
                num_convs=value_out_num_convs,
                use_norm=value_out_norm,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )
        # downsample
        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm
        self.transform_channels = transform_channels

    '''forward'''
    def forward(self, query_feats, key_feats):
        batch_size = query_feats.size(0)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()
        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.transform_channels ** -0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context

    '''build project'''
    def buildproject(self, in_channels, out_channels, num_convs, use_norm, norm_cfg, act_cfg):
        if use_norm:
            convs = [nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                # BuildNormalization(constructnormcfg(placeholder=out_channels, norm_cfg=norm_cfg)),
                # BuildActivation(act_cfg),
            )]
            for _ in range(num_convs - 1):
                convs.append(nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                    # BuildNormalization(constructnormcfg(placeholder=out_channels, norm_cfg=norm_cfg)),
                    # BuildActivation(act_cfg),
                ))
        else:
            convs = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)]
            for _ in range(num_convs - 1):
                convs.append(
                    nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
                )
        if len(convs) > 1:
            return nn.Sequential(*convs)
        return convs[0]
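
# A minimal shape sketch (not part of the original file), using the same
# arguments that Memory passes below. Query and key share a tensor here
# purely for illustration; the output matches the query's spatial size.
def _demo_self_attention_block():
    block = SelfAttentionBlock(
        key_in_channels=1024, query_in_channels=1024, transform_channels=512,
        out_channels=1024, share_key_query=False, query_downsample=None,
        key_downsample=None, key_query_num_convs=2, value_out_num_convs=1,
        key_query_norm=True, value_out_norm=True, matmul_norm=True,
        with_out_project=True,
    ).eval()
    feats = torch.rand(2, 1024, 16, 16)
    with torch.no_grad():
        context = block(feats, feats)
    assert context.shape == (2, 1024, 16, 16)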

class Memory(nn.Module):
    def __init__(self, num_classes, num_feats_per_cls):
        super(Memory, self).__init__()
        self.num_feats_per_cls = num_feats_per_cls
        self.memory = nn.Parameter(torch.zeros(num_classes, self.num_feats_per_cls, 1024, dtype=torch.float),
                                   requires_grad=False)
        self.feats_channels = 1024
        if self.num_feats_per_cls > 1:
            self.self_attentions = nn.ModuleList()
            for _ in range(self.num_feats_per_cls):
                self_attention = SelfAttentionBlock(
                    key_in_channels=1024,
                    query_in_channels=1024,
                    transform_channels=512,
                    out_channels=1024,
                    share_key_query=False,
                    query_downsample=None,
                    key_downsample=None,
                    key_query_num_convs=2,
                    value_out_num_convs=1,
                    key_query_norm=True,
                    value_out_norm=True,
                    matmul_norm=True,
                    with_out_project=True,
                )
                self.self_attentions.append(self_attention)
            self.fuse_memory_conv = nn.Sequential(
                nn.Conv2d(1024 * self.num_feats_per_cls, 1024, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1024),
                nn.ReLU(),
            )
        else:
            self.self_attention = SelfAttentionBlock(
                key_in_channels=1024,
                query_in_channels=1024,
                transform_channels=512,
                out_channels=1024,
                share_key_query=False,
                query_downsample=None,
                key_downsample=None,
                key_query_num_convs=2,
                value_out_num_convs=1,
                key_query_norm=True,
                value_out_norm=True,
                matmul_norm=True,
                with_out_project=True,
            )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(1024 * 2, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )
        self.up = Up_sample()

    def forward(self, x):
        batch_size, num_channels, h, w = x.size()
        # Note: the reshape/matmul below only lines up with the memory when
        # batch_size == 1.
        x_m, _ = self.up(x)  # class-score map, e.g. [1, 7, 16, 16]
        x_m = x_m.reshape(x_m.shape[0] * x_m.shape[1], x_m.shape[2] * x_m.shape[3])
        x_m = x_m.permute(1, 0)
        selected_memory_list = []
        for idx in range(self.num_feats_per_cls):
            memory = self.memory.data[:, idx, :]
            selected_memory = torch.matmul(x_m, memory)
            selected_memory_list.append(selected_memory.unsqueeze(1))
        if self.num_feats_per_cls > 1:
            relation_selected_memory_list = []
            for idx, selected_memory in enumerate(selected_memory_list):
                # --(B*H*W, C) --> (B, H, W, C)
                selected_memory = selected_memory.view(batch_size, h, w, num_channels)
                # --(B, H, W, C) --> (B, C, H, W)
                selected_memory = selected_memory.permute(0, 3, 1, 2).contiguous()
                # --append
                relation_selected_memory_list.append(self.self_attentions[idx](x, selected_memory))
            # --concat
            selected_memory = torch.cat(relation_selected_memory_list, dim=1)
            selected_memory = self.fuse_memory_conv(selected_memory)
        else:
            assert len(selected_memory_list) == 1
            selected_memory = selected_memory_list[0].squeeze(1)
            # --(B*H*W, C) --> (B, H, W, C)
            selected_memory = selected_memory.view(batch_size, h, w, num_channels)
            # --(B, H, W, C) --> (B, C, H, W)
            selected_memory = selected_memory.permute(0, 3, 1, 2).contiguous()
            # --feed into the self attention module
            selected_memory = self.self_attention(x, selected_memory)
        memory_output = self.bottleneck(torch.cat([x, selected_memory], dim=1))

        return memory_output

    def update(self, features, segmentation, ignore_index=255, momentum=0.9, learning_rate=None):
        # assert strategy in ['mean', 'cosine_similarity']
        batch_size, num_channels, h, w = features.size()
        # momentum = momentum_cfg['base_momentum']
        # if momentum_cfg['adjust_by_learning_rate']:
        #     momentum = momentum_cfg['base_momentum'] / momentum_cfg['base_lr'] * learning_rate
        # use features to update memory
        segmentation = segmentation.long()
        features = features.permute(0, 2, 3, 1).contiguous()
        features = features.view(batch_size * h * w, num_channels)
        clsids = segmentation.unique()
        for clsid in clsids:
            if clsid == ignore_index:
                continue
            # --(B, H, W) --> (B*H*W,)
            seg_cls = segmentation.view(-1)
            # --extract the corresponding feats: (K, C)
            feats_cls = features[seg_cls == clsid]
            # --init memory by using extracted features
            need_update = True
            for idx in range(self.num_feats_per_cls):
                if (self.memory[clsid][idx] == 0).sum() == self.feats_channels:
                    self.memory[clsid][idx].data.copy_(feats_cls.mean(0))
                    need_update = False
                    break
            if not need_update:
                continue
            # --update according to the selected strategy
            if self.num_feats_per_cls == 1:
                # strategy == 'mean'
                feats_cls = feats_cls.mean(0)
                # elif strategy == 'cosine_similarity':
                #     similarity = F.cosine_similarity(feats_cls, self.memory[clsid].data.expand_as(feats_cls))
                #     weight = (1 - similarity) / (1 - similarity).sum()
                #     feats_cls = (feats_cls * weight.unsqueeze(-1)).sum(0)
                feats_cls = (1 - momentum) * self.memory[clsid].data + momentum * feats_cls.unsqueeze(0)
                self.memory[clsid].data.copy_(feats_cls)
                # print("memory updated!")
            # else:
            #     assert strategy in ['cosine_similarity']
            #     # ----(K, C) * (C, num_feats_per_cls) --> (K, num_feats_per_cls)
            #     relation = torch.matmul(
            #         F.normalize(feats_cls, p=2, dim=1),
            #         F.normalize(self.memory[clsid].data.permute(1, 0).contiguous(), p=2, dim=0),
            #     )
            #     argmax = relation.argmax(dim=1)
            #     # ----for saving memory during training
            #     for idx in range(self.num_feats_per_cls):
            #         mask = (argmax == idx)
            #         feats_cls_iter = feats_cls[mask]
            #         memory_cls_iter = self.memory[clsid].data[idx].unsqueeze(0).expand_as(feats_cls_iter)
            #         similarity = F.cosine_similarity(feats_cls_iter, memory_cls_iter)
            #         weight = (1 - similarity) / (1 - similarity).sum()
            #         feats_cls_iter = (feats_cls_iter * weight.unsqueeze(-1)).sum(0)
            #         self.memory[clsid].data[idx].copy_(self.memory[clsid].data[idx] * (1 - momentum) + feats_cls_iter * momentum)

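
# A minimal usage sketch for the memory update (not part of the original
# file); feature/label sizes are illustrative assumptions. Zero-initialized
# class slots are filled with the mean feature of each class present.
def _demo_memory_update():
    bank = Memory(num_classes=7, num_feats_per_cls=1)
    features = torch.rand(1, 1024, 16, 16)
    segmentation = torch.randint(0, 7, (1, 16, 16))
    bank.update(features, segmentation, momentum=0.9)
    assert not torch.all(bank.memory == 0)  # at least one class slot was written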

class Up_sample(nn.Module):
    def __init__(self):
        super(Up_sample, self).__init__()

        self.UP_stage_1 = nn.Sequential(
            nn.Conv2d(1024, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.UP_stage_2 = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.UP_stage_3 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.UP_stage_4 = nn.Sequential(
            nn.Conv2d(128, 7, 3, padding=1),
            nn.BatchNorm2d(7),
            nn.ReLU(),
            # nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

        self.UP = nn.Sequential(
            # nn.Conv2d(1024, 512, 3, padding=1),
            # nn.BatchNorm2d(512),
            # nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

    def forward(self, x):
        x1 = x
        x = self.UP_stage_1(x)
        x = self.UP_stage_2(x)
        x = self.UP_stage_3(x)
        x = self.UP_stage_4(x)
        # x = self.cls_seg(x)
        x1 = self.UP(x1)
        x1 = self.UP(x1)
        x1 = self.UP(x1)
        x1 = self.UP(x1)
        return x, x1

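
# A shape note as a sketch (not part of the original file): the conv stages
# reduce 1024 channels to a 7-channel class map at the same resolution (the
# per-stage Upsample is commented out), while the second output is the raw
# input upsampled 16x. The sizes below are illustrative assumptions.
def _demo_up_sample():
    up = Up_sample().eval()
    x = torch.rand(1, 1024, 16, 16)
    with torch.no_grad():
        class_map, upsampled = up(x)
    assert class_map.shape == (1, 7, 16, 16)
    assert upsampled.shape == (1, 1024, 256, 256)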

class Up_sample_M(nn.Module):
    def __init__(self):
        super(Up_sample_M, self).__init__()

        self.UP_stage_1 = nn.Sequential(
            nn.Conv2d(1024, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.UP_stage_2 = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.UP_stage_3 = nn.Sequential(
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )
        self.UP_stage_4 = nn.Sequential(
            nn.Conv2d(128, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

    def forward(self, x):
        x = self.UP_stage_1(x)
        x = self.UP_stage_2(x)
        x = self.UP_stage_3(x)
        x = self.UP_stage_4(x)

        return x


def x2conv(in_channels, out_channels, inner_channels=None):
    inner_channels = out_channels // 2 if inner_channels is None else inner_channels
    down_conv = nn.Sequential(
        nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(inner_channels),
        nn.GELU(),
        nn.Conv2d(inner_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.GELU()
    )
    return down_conv

class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        self.down_conv = x2conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.down_conv(x)
        x = self.pool(x)
        return x

class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.up_conv = x2conv(in_channels, out_channels)

    def forward(self, x_copy, x, interpolate=True):
        x = self.up(x)

        if x.size(2) != x_copy.size(2) or x.size(3) != x_copy.size(3):
            if interpolate:
                x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)), mode="bilinear", align_corners=True)
            else:
                # Pad width then height so x matches the skip connection;
                # all four pad amounts belong in a single tuple.
                diffy = x_copy.size()[2] - x.size()[2]
                diffx = x_copy.size()[3] - x.size()[3]
                x = F.pad(x, (diffx // 2, diffx - diffx // 2,
                              diffy // 2, diffy - diffy // 2))

        x = torch.cat([x_copy, x], dim=1)
        x = self.up_conv(x)
        return x

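
# A minimal sketch of the size-alignment path (not part of the original
# file): an odd-sized skip tensor forces the interpolate branch before the
# concatenation. Shapes are illustrative assumptions.
def _demo_decoder():
    dec = Decoder(1024, 512).eval()
    x = torch.rand(1, 1024, 16, 16)
    skip = torch.rand(1, 512, 33, 33)
    with torch.no_grad():
        out = dec(skip, x)
    assert out.shape == (1, 512, 33, 33)
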
class UNet(nn.Module):
    def __init__(self, num_classes, freeze_bn=False):  # num_classes: number of output classes
        super(UNet, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # --------------------------------------------------------------------
        self.layer1_first = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128)
        )
        self.layer1_next = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128)
        )
        # --------------------------------------------------------------------
        self.layer2_first = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256)
        )
        self.layer2_next = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256)
        )
        # --------------------------------------------------------------------
        self.layer3_first = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512)
        )
        self.layer3_next = nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512)
        )
        # --------------------------------------------------------------------
        self.layer4_first = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024)
        )
        self.layer4_next = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024)
        )

        self.layer1_shortcut = DownSample(64, 128, 1)
        self.layer2_shortcut = DownSample(128, 256, 2)
        self.layer3_shortcut = DownSample(256, 512, 2)
        self.layer4_shortcut = DownSample(512, 1024, 2)

        # --------------------------------------------------------------------
        self.start_conv = x2conv(4, 64)
        # self.down1 = Encoder(64, 128)
        # self.down2 = Encoder(128, 256)
        # self.down3 = Encoder(256, 512)
        # self.down4 = Encoder(512, 1024)

        # self.middle_conv = x2conv(2048, 2048)

        self.up1 = Decoder(1024, 512)
        self.up2 = Decoder(512, 256)
        self.up3 = Decoder(256, 128)
        self.up4 = Decoder(128, 64)
        self.final_conv = nn.Conv2d(96, num_classes, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

        self.dropout = nn.Dropout(p=0.5)

        # ASPP
        self.aspp1 = ASPP(512, [6, 12, 18], 512)
        self.aspp2 = ASPP(256, [6, 12, 18], 256)
        self.aspp3 = ASPP(128, [6, 12, 18], 128)
        self.aspp4 = ASPP(64, [6, 12, 18], 64)

        # multi-scale auxiliary outputs
        self.outprocess1 = nn.Sequential(
            nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True)
        )
        self.outprocess2 = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        )
        self.outprocess3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

        # attention modules
        self.atten1 = Self_Attn(1024)
        self.atten2 = Self_Attn(512)
        self.atten3 = Self_Attn(256)
        self.atten4 = Self_Attn(128)
        # Memory
        self.memory = Memory(num_classes, num_feats_per_cls=1)
        self.up = Up_sample_M()

        # line branch
        # self.line_up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.line_conv = Lin_conv(7, 7)

        # Initialize weights only after every submodule exists, so the init
        # scheme also reaches the ASPP, attention, and output heads.
        self._initialize_weights()

        if freeze_bn:
            self.freeze_bn()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1)
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def freeze_bn(self):
        # The method was referenced above but never defined; this is the
        # conventional implementation, keeping all BatchNorm layers in eval mode.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def forward(self, x):
        out = self.start_conv(x)
        x1 = out
        out = self.pre(out)  # 2x down (the extra MaxPool is commented out above)

        # Shape comments below assume a 256x256 input.
        # --------------------------------------------------------------------
        layer1_identity = self.layer1_shortcut(out)  # 128, 128, 128
        out = self.layer1_first(out)  # channel -> 128
        out = F.relu(out + layer1_identity, inplace=True)

        for i in range(2):
            identity = out
            out = self.layer1_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        x2 = out
        layer2_identity = self.layer2_shortcut(out)  # 2x down, 256, 64, 64
        out = self.layer2_first(out)
        out = F.relu(out + layer2_identity, inplace=True)

        for i in range(3):
            identity = out
            out = self.layer2_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        x3 = out

        layer3_identity = self.layer3_shortcut(out)  # 512, 32, 32
        out = self.layer3_first(out)  # 2x down
        out = F.relu(out + layer3_identity, inplace=True)

        for i in range(22):
            identity = out
            out = self.layer3_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        x4 = out

        layer4_identity = self.layer4_shortcut(out)  # 1024, 16, 16
        out = self.layer4_first(out)  # 2x down
        out = F.relu(out + layer4_identity, inplace=True)

        for i in range(2):
            identity = out
            out = self.layer4_next(out)
            out = F.relu(out + identity, inplace=True)
        # --------------------------------------------------------------------
        x5 = out

        memory_input = x5
        x_m = self.memory(memory_input)
        x_m = self.up(x_m)

        # x = self.middle_conv(x4)
        out_seg = []
        x5 = self.atten1(x5)
        x = self.up1(self.aspp1(x4), x5)  # 512, 32, 32
        out1 = self.outprocess1(x)
        out_seg.append(out1)
        x = self.atten2(x)
        x = self.up2(self.aspp2(x3), x)  # 256, 64, 64
        out2 = self.outprocess2(x)
        out_seg.append(out2)
        x = self.atten3(x)
        x = self.up3(self.aspp3(x2), x)  # 128, 128, 128
        out3 = self.outprocess3(x)
        out_seg.append(out3)
        x = self.atten4(x)
        x = self.up4(self.aspp4(x1), x)  # 64, 256, 256

        x = torch.cat([x, x_m], dim=1)

        x = self.final_conv(x)
        out_seg.append(x)

        # # sea-land
        # x_sea = x[:, 6:7, :, :]
        # x_sl = torch.cat([x_sea, torch.zeros(x_sea.shape).to(device)], dim=1)

        # # line
        # # x_line = self.line_up(x)
        # x_line = self.line_conv(x)

        return out_seg, memory_input

    def update(self, features, lab, m):
        self.memory.update(features, lab, momentum=m)

if __name__ == "__main__":
    # Quick smoke test: UNet.forward returns (out_seg, memory_input), and the
    # memory read assumes batch size 1.
    input = torch.rand(1, 4, 256, 256)

    model = UNet(num_classes=7)
    model.eval()

    with torch.no_grad():
        out_seg, memory_input = model(input)

    for o in out_seg:
        print(o.shape)  # each output is (1, 7, 256, 256)
    print(memory_input.shape)  # (1, 1024, 16, 16)