import torch
import torch.nn as nn
import torch.nn.functional as F
# from ...backbones import BuildActivation, BuildNormalization, constructnormcfg


'''SelfAttentionBlock'''
class SelfAttentionBlock(nn.Module):
    def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels, share_key_query,
                 query_downsample, key_downsample, key_query_num_convs, value_out_num_convs, key_query_norm,
                 value_out_norm, matmul_norm, with_out_project, norm_cfg=None, act_cfg=None):
        super(SelfAttentionBlock, self).__init__()
        # key project
        self.key_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels,
            num_convs=key_query_num_convs,
            use_norm=key_query_norm,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        # query project
        if share_key_query:
            assert key_in_channels == query_in_channels
            self.query_project = self.key_project
        else:
            self.query_project = self.buildproject(
                in_channels=query_in_channels,
                out_channels=transform_channels,
                num_convs=key_query_num_convs,
                use_norm=key_query_norm,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )
        # value project
        self.value_project = self.buildproject(
            in_channels=key_in_channels,
            out_channels=transform_channels if with_out_project else out_channels,
            num_convs=value_out_num_convs,
            use_norm=value_out_norm,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        # out project
        self.out_project = None
        if with_out_project:
            self.out_project = self.buildproject(
                in_channels=transform_channels,
                out_channels=out_channels,
                num_convs=value_out_num_convs,
                use_norm=value_out_norm,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )
        # downsample
        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm
        self.transform_channels = transform_channels
    '''forward'''
    def forward(self, query_feats, key_feats):
        batch_size = query_feats.size(0)
        # project and flatten the query: (B, C_t, H, W) -> (B, H*W, C_t)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()
        # project key and value from the same key features; both must share
        # the same (possibly downsampled) spatial size
        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()
        # similarity map: (B, H*W, H*W), optionally scaled by 1/sqrt(C_t)
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.transform_channels ** -0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)
        # aggregate the values and restore the query's spatial layout
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
    '''build project'''
    def buildproject(self, in_channels, out_channels, num_convs, use_norm, norm_cfg, act_cfg):
        # norm_cfg and act_cfg are accepted for interface compatibility but are
        # not consumed here: BatchNorm2d and ReLU are hard-coded in place of
        # the configurable builders left commented out below.
        if use_norm:
            convs = [nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                # BuildNormalization(constructnormcfg(placeholder=out_channels, norm_cfg=norm_cfg)),
                # BuildActivation(act_cfg),
            )]
            for _ in range(num_convs - 1):
                convs.append(nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                    # BuildNormalization(constructnormcfg(placeholder=out_channels, norm_cfg=norm_cfg)),
                    # BuildActivation(act_cfg),
                ))
        else:
            convs = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)]
            for _ in range(num_convs - 1):
                convs.append(
                    nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
                )
        if len(convs) > 1:
            return nn.Sequential(*convs)
        return convs[0]
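

# Minimal smoke test for SelfAttentionBlock. This is an illustrative sketch:
# the shapes and hyperparameters below are arbitrary assumptions, not values
# taken from any particular model configuration.
def _demo_self_attention_block():
    attn = SelfAttentionBlock(
        key_in_channels=64, query_in_channels=64, transform_channels=32,
        out_channels=64, share_key_query=False, query_downsample=None,
        key_downsample=None, key_query_num_convs=2, value_out_num_convs=1,
        key_query_norm=True, value_out_norm=True, matmul_norm=True,
        with_out_project=True,
    )
    feats = torch.randn(2, 64, 16, 16)
    # self-attention: queries and keys come from the same feature map
    print(attn(feats, feats).shape)  # torch.Size([2, 64, 16, 16])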


'''DownSample: 1x1 conv + BN + ReLU projection for residual shortcut paths'''
class DownSample(nn.Module):
    def __init__(self, in_channel, out_channel, stride):
        super(DownSample, self).__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.down(x)
        return out


class ResNet101(nn.Module):
    def __init__(self, classes_num):  # number of output classes
        super(ResNet101, self).__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # --------------------------------------------------------------------
        # Each stage: one bottleneck with a projection shortcut ("first"),
        # followed by identity bottlenecks held in a ModuleList so that every
        # repeated block has its own weights. The projection shortcuts are
        # registered as submodules so they are trained and moved to the right
        # device together with the rest of the model.
        self.layer1_shortcut = DownSample(64, 256, 1)
        self.layer1_first = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256)
        )
        self.layer1_next = nn.ModuleList([nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256)
        ) for _ in range(2)])
        # --------------------------------------------------------------------
        self.layer2_shortcut = DownSample(256, 512, 2)
        self.layer2_first = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512)
        )
        self.layer2_next = nn.ModuleList([nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512)
        ) for _ in range(3)])
        # --------------------------------------------------------------------
        self.layer3_shortcut = DownSample(512, 1024, 2)
        self.layer3_first = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024)
        )
        self.layer3_next = nn.ModuleList([nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1024, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024)
        ) for _ in range(22)])
        # --------------------------------------------------------------------
        self.layer4_shortcut = DownSample(1024, 2048, 2)
        self.layer4_first = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 2048, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048)
        )
        self.layer4_next = nn.ModuleList([nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 2048, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048)
        ) for _ in range(2)])
        # --------------------------------------------------------------------
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(2048 * 1 * 1, 1000),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1000, classes_num)
        )

    def forward(self, x):
        out = self.pre(x)
        # --------------------------------------------------------------------
        # stage 1: projection-shortcut block, then 2 identity blocks
        layer1_identity = self.layer1_shortcut(out)
        out = self.layer1_first(out)
        out = F.relu(out + layer1_identity, inplace=True)
        for block in self.layer1_next:
            identity = out
            out = F.relu(block(out) + identity, inplace=True)
        # --------------------------------------------------------------------
        # stage 2: projection-shortcut block, then 3 identity blocks
        layer2_identity = self.layer2_shortcut(out)
        out = self.layer2_first(out)
        out = F.relu(out + layer2_identity, inplace=True)
        for block in self.layer2_next:
            identity = out
            out = F.relu(block(out) + identity, inplace=True)
        # --------------------------------------------------------------------
        # stage 3: projection-shortcut block, then 22 identity blocks
        layer3_identity = self.layer3_shortcut(out)
        out = self.layer3_first(out)
        out = F.relu(out + layer3_identity, inplace=True)
        for block in self.layer3_next:
            identity = out
            out = F.relu(block(out) + identity, inplace=True)
        # --------------------------------------------------------------------
        # stage 4: projection-shortcut block, then 2 identity blocks
        layer4_identity = self.layer4_shortcut(out)
        out = self.layer4_first(out)
        out = F.relu(out + layer4_identity, inplace=True)
        for block in self.layer4_next:
            identity = out
            out = F.relu(block(out) + identity, inplace=True)
        # --------------------------------------------------------------------
        out = self.avg_pool(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)

        return out
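

# Quick shape check for ResNet101. Illustrative only: the 224x224 input and
# the 10-class head are assumptions for the demo, not requirements.
def _demo_resnet101():
    model = ResNet101(classes_num=10).eval()  # eval() so BatchNorm accepts batch size 1
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 10])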


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        # Global-average-pool to 1x1, project the channels, then upsample back
        # to the input resolution. Note that BatchNorm over the 1x1 pooled map
        # requires batch size > 1 in training mode; run in eval() mode (or swap
        # in a batch-size-agnostic norm) for single-image inputs.
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels):
        super(ASPP, self).__init__()
        modules = []
        # 1x1 conv branch
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        # one 3x3 atrous conv branch per dilation rate
        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        # global-average-pooling branch
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # fuse the concatenated branch outputs back to out_channels
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
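

# Usage sketch for ASPP with the dilation rates used by DeepLabV3 (6, 12, 18);
# the channel sizes and the 32x32 input are arbitrary assumptions for the demo.
def _demo_aspp():
    aspp = ASPP(in_channels=2048, atrous_rates=[6, 12, 18], out_channels=256).eval()
    with torch.no_grad():
        out = aspp(torch.randn(1, 2048, 32, 32))
    print(out.shape)  # torch.Size([1, 256, 32, 32])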


class SpatialTransformer(nn.Module):
    def __init__(self, spatial_dims):
        super(SpatialTransformer, self).__init__()
        self._h, self._w = spatial_dims
        self.fc1 = nn.Linear(32 * 4 * 4, 1024)  # set to the flattened size of your feature map
        self.fc2 = nn.Linear(1024, 6)
        # Start from the identity transform, the usual initialization for the
        # localization head of a spatial transformer.
        nn.init.zeros_(self.fc2.weight)
        self.fc2.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        batch_images = x  # keep the original feature map for sampling
        x = x.view(-1, 32 * 4 * 4)
        # regress the 6 affine parameters with the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)
        theta = x.view(-1, 2, 3)  # one 2x3 affine matrix per sample
        # build the sampling grid from the affine parameters; the channel
        # count is taken from the input feature map
        affine_grid_points = F.affine_grid(
            theta, torch.Size((theta.size(0), batch_images.size(1), self._h, self._w)),
            align_corners=False)
        # sample the original feature map at the grid locations
        rois = F.grid_sample(batch_images, affine_grid_points, align_corners=False)
        return rois, affine_grid_points
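

# Usage sketch for SpatialTransformer, assuming 32-channel 4x4 feature maps to
# match the hard-coded fc1 input size above (an assumption of this demo).
def _demo_spatial_transformer():
    stn = SpatialTransformer(spatial_dims=(4, 4))
    feats = torch.randn(2, 32, 4, 4)
    rois, grid = stn(feats)
    print(rois.shape, grid.shape)  # torch.Size([2, 32, 4, 4]) torch.Size([2, 4, 4, 2])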