- """
- This script defines the structure of A hybrid nested UNet combined with dilation dense net and nested unet
-
- Author: He Hongliang
-
- 20200306
-
- """

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Unify spatial sizes: pad x so that a map whose odd size was halved by
# downsampling and then doubled by upsampling matches its skip connection again.
def resizeScale(x_copy, x):
    diffY = x_copy.size()[2] - x.size()[2]
    diffX = x_copy.size()[3] - x.size()[3]
    if diffY != 0 or diffX != 0:
        x = F.pad(x, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
    return x
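
# Usage sketch (illustrative shapes): pad a 31x31 map back to 32x32 so it can
# be fused with its skip connection:
#   x_copy = torch.zeros(1, 64, 32, 32)
#   x = torch.zeros(1, 64, 31, 31)
#   resizeScale(x_copy, x).shape  # -> torch.Size([1, 64, 32, 32])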


class sSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        q = self.Conv1x1(U)  # U: [bs, c, h, w] -> q: [bs, 1, h, w]
        q = self.norm(q)
        return U * q  # q broadcasts over the channel dimension


class cSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.Conv_Squeeze = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, bias=False)
        self.Conv_Excitation = nn.Conv2d(in_channels // 2, in_channels, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()

    def forward(self, U):
        z = self.avgpool(U)  # [bs, c, h, w] -> [bs, c, 1, 1]
        z = self.Conv_Squeeze(z)  # -> [bs, c//2, 1, 1]
        z = self.Conv_Excitation(z)  # -> [bs, c, 1, 1]
        z = self.norm(z)
        return U * z.expand_as(U)


class csSE(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, U):
        U_sse = self.sSE(U)
        U_cse = self.cSE(U)
        return U_cse + U_sse


# Series attention module (SAM): channel attention followed by spatial attention. 2020-03-11
class SAM(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, U):
        U_cse = self.cSE(U)
        U_sse = self.sSE(U_cse)
        return U_sse


# Parallel attention module (PAM): channel and spatial attention applied to the
# same input and summed; equivalent to csSE. 2020-03-11
class PAM(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.cSE = cSE(in_channels)
        self.sSE = sSE(in_channels)

    def forward(self, U):
        U_sse = self.sSE(U)
        U_cse = self.cSE(U)
        return U_cse + U_sse
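
# Shape sketch: both attention variants preserve the input shape, e.g.
#   u = torch.randn(2, 32, 16, 16)
#   SAM(32)(u).shape == PAM(32)(u).shape == u.shape  # -> True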


class encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        self.down_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        x = self.down_conv(x)
        x_pooled = self.pool(x)
        return x, x_pooled


# alternative upsampling: nn.Upsample(scale_factor=2)
class decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(decoder, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.up_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_copy, x):
        x = self.up(x)
        # Padding in case the incoming volumes are of different sizes
        diffY = x_copy.size()[2] - x.size()[2]
        diffX = x_copy.size()[3] - x.size()[3]
        x = F.pad(x, (diffX // 2, diffX - diffX // 2,
                      diffY // 2, diffY - diffY // 2))
        # Concatenate with the skip connection
        x = torch.cat([x_copy, x], dim=1)
        x = self.up_conv(x)
        return x
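
# Round-trip sketch (illustrative sizes): the encoder halves the spatial size
# with ceil mode, the decoder doubles it again and crops/pads to the skip tensor:
#   enc = encoder(3, 64)
#   skip, pooled = enc(torch.randn(1, 3, 33, 33))  # skip: 33x33, pooled: 17x17
#   dec = decoder(128, 64)
#   dec(skip, torch.randn(1, 128, 17, 17)).shape  # -> torch.Size([1, 64, 33, 33])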


class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(conv_block, self).__init__()
        self.convblock = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.convblock(x)
        return x


# Channel attention only, applied after the two convs
class attention_conv_block_Channel(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_Channel, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.cSE = cSE(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.cSE(x)
        return x


# Spatial attention only, applied after the two convs
class attention_conv_block_Spatial(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_Spatial, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.sSE = sSE(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.sSE(x)
        return x


# Series attention module only, applied between the two convs
class attention_conv_block_Series_V1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_Series_V1, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.cSE = cSE(out_channels)
        self.sSE = sSE(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.cSE(x)
        x = self.sSE(x)
        x = self.conv_block2(x)
        return x


# Series attention module only, applied after the two convs
class attention_conv_block_Series_V2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_Series_V2, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.cSE = cSE(out_channels)
        self.sSE = sSE(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.cSE(x)
        x = self.sSE(x)
        return x


# Parallel attention module only, applied between the two convs
class attention_conv_block_Parallel_V1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_Parallel_V1, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.PAM = PAM(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.PAM(x)
        x = self.conv_block2(x)
        return x


# Parallel attention module only, applied after the two convs
class attention_conv_block_Parallel_V2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_Parallel_V2, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.PAM = PAM(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.PAM(x)
        return x


# Series attention between the two convs, parallel attention after them
class attention_conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.cSE = cSE(out_channels)
        self.sSE = sSE(out_channels)
        self.PAM = PAM(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.cSE(x)
        x = self.sSE(x)
        x = self.conv_block2(x)
        x = self.PAM(x)
        return x


# Parallel attention between the two convs, series attention after them
class attention_conv_block_V2(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_V2, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.cSE = cSE(out_channels)
        self.sSE = sSE(out_channels)
        self.PAM = PAM(out_channels)

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.PAM(x)
        x = self.conv_block2(x)
        x = self.cSE(x)
        x = self.sSE(x)
        return x


# Series + parallel attention in two branches (2020-03-14, new): a 3x3 conv
# followed by SAM in one branch, a 1x1 conv followed by PAM in the other,
# with the branch outputs averaged.
class attention_conv_block_V3(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(attention_conv_block_V3, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.SAM = SAM(out_channels)
        self.PAM = PAM(out_channels)

    def forward(self, x):
        x1 = self.conv_block1(x)
        x1 = self.SAM(x1)
        x2 = self.conv_block2(x)
        x2 = self.PAM(x2)
        return (x1 + x2) / 2
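
# Shape sketch for the V3 block used throughout HanNet: both branches map
# in_channels -> out_channels at the same resolution, so averaging is valid:
#   block = attention_conv_block_V3(64, 128)
#   block(torch.randn(1, 64, 16, 16)).shape  # -> torch.Size([1, 128, 16, 16])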


class ConvLayer(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1):
        super(ConvLayer, self).__init__()
        self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                          padding=padding, dilation=dilation, bias=False, groups=groups))
        # note: the activation is registered before the batch norm (conv -> LeakyReLU -> BN)
        self.add_module('relu', nn.LeakyReLU(inplace=True))
        self.add_module('bn', nn.BatchNorm2d(out_channels))


# --- different types of layers --- #
class BasicLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate, drop_rate, dilation=1):
        super(BasicLayer, self).__init__()
        self.conv = ConvLayer(in_channels, growth_rate, kernel_size=3, stride=1, padding=dilation,
                              dilation=dilation)
        self.drop_rate = drop_rate

    def forward(self, x):
        out = self.conv(x)
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        # DenseNet-style: concatenate the new features onto the input
        return torch.cat([x, out], 1)


class BottleneckLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate, drop_rate, dilation=1):
        super(BottleneckLayer, self).__init__()
        inter_planes = growth_rate * 4
        self.conv1 = ConvLayer(in_channels, inter_planes, kernel_size=1, padding=0)
        self.conv2 = ConvLayer(inter_planes, growth_rate, kernel_size=3, padding=dilation, dilation=dilation)
        self.drop_rate = drop_rate

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        return torch.cat([x, out], 1)


# --- dense block structure --- #
class DenseBlock(nn.Sequential):
    def __init__(self, in_channels, growth_rate, drop_rate, layer_type, dilations):
        super(DenseBlock, self).__init__()
        for i in range(len(dilations)):
            layer = layer_type(in_channels + i * growth_rate, growth_rate, drop_rate, dilations[i])
            self.add_module('denselayer{:d}'.format(i + 1), layer)
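
# Channel bookkeeping sketch: each dense layer concatenates growth_rate new
# maps, so a block starting with C channels ends with C + len(dilations) * growth_rate:
#   block = DenseBlock(64, 24, 0.1, BasicLayer, dilations=[1, 2, 3, 2])
#   block(torch.randn(1, 64, 8, 8)).shape  # -> torch.Size([1, 160, 8, 8])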


# --- dense block structure + attention --- # 2020-03-11
class DenseAttentionBlock(nn.Sequential):
    def __init__(self, in_channels, growth_rate, drop_rate, layer_type, dilations):
        super(DenseAttentionBlock, self).__init__()
        for i in range(len(dilations)):
            layer = layer_type(in_channels + i * growth_rate, growth_rate, drop_rate, dilations[i])
            self.add_module('denselayer{:d}'.format(i + 1), layer)
            # alternate parallel and series attention after successive dense layers
            out_channels = in_channels + (i + 1) * growth_rate
            attention = PAM(out_channels) if i % 2 == 0 else SAM(out_channels)
            self.add_module('attentionlayer{:d}'.format(i + 1), attention)


def choose_hybrid_dilations(n_layers, dilation_schedule, is_hybrid):
    # key: (dilation, n_layers) -> hybrid dilation pattern
    HD_dict = {(1, 4): [1, 1, 1, 1],
               (2, 4): [1, 2, 3, 2],
               (4, 4): [1, 2, 5, 9],
               (8, 4): [3, 7, 10, 13],
               (16, 4): [13, 15, 17, 19],
               (1, 6): [1, 1, 1, 1, 1, 1],
               (2, 6): [1, 2, 3, 1, 2, 3],
               (4, 6): [1, 2, 3, 5, 6, 7],
               (8, 6): [2, 5, 7, 9, 11, 14],
               (16, 6): [10, 13, 16, 17, 19, 21]}

    dilation_list = np.zeros((len(dilation_schedule), n_layers), dtype=np.int32)

    for i in range(len(dilation_schedule)):
        dilation = dilation_schedule[i]
        if is_hybrid:
            dilation_list[i] = HD_dict[(dilation, n_layers)]
        else:
            dilation_list[i] = [dilation] * n_layers

    return dilation_list
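
# Example: each entry of the schedule expands to one per-layer dilation row:
#   choose_hybrid_dilations(4, (2, 4), True)   # -> [[1 2 3 2], [1 2 5 9]]
#   choose_hybrid_dilations(4, (2, 4), False)  # -> [[2 2 2 2], [4 4 4 4]]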


def debug_view(x):
    x_np = x.detach().cpu().numpy()
    return '%s---%.5f---%.5f' % (str(x_np.shape), float(x_np.min()), float(x_np.max()))
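
# Example: debug_view(torch.ones(1, 3, 4, 4))  # -> '(1, 3, 4, 4)---1.00000---1.00000'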


class HanNet(nn.Module):  # formerly HNUNet_V5_9
    def __init__(self, in_channels,
                 output_channels=2,
                 n_layers=6,
                 growth_rate=24,
                 compress_ratio=0.5,
                 drop_rate=0.1,
                 dilations=(1, 2, 4, 8, 16, 8, 4, 2, 1),
                 is_hybrid=True,
                 layer_type='basic'
                 ):
        super(HanNet, self).__init__()
        if layer_type == 'basic':
            layer_type = BasicLayer
        else:
            layer_type = BottleneckLayer

        # NOTE: this overrides the `dilations` argument; the first run on gpu1
        # used dilations=(2, 8, 16, 4, 1)
        dilations = (2, 2, 2, 2, 2)
        self.blocks = nn.Sequential()
        n_blocks = len(dilations)

        dilation_list = choose_hybrid_dilations(n_layers, dilations, is_hybrid)
        # print('n_blocks = ', n_blocks)
        # print('dilation_list = ')
        # print(dilation_list)
        channel_n1 = 64
        filters = [channel_n1, channel_n1 * 2, channel_n1 * 4, channel_n1 * 8,
                   channel_n1 * 16]  # 64, 128, 256, 512, 1024

        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)
        # self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # first conv before any dense block
        self.conv1 = ConvLayer(in_channels, filters[0], kernel_size=3, padding=1)
        # the first conv of the first level used to live here:
        # self.conv0_0 = attention_conv_block(in_channels, filters[0])  # 3, 64

        # channel_n1 = 64
        self.denseblock0 = DenseBlock(channel_n1 * 1, growth_rate, drop_rate, layer_type, dilation_list[3])
        # self.transblock0 = ConvLayer(channel_n1 * 1 + growth_rate * n_layers, channel_n1 * 2, kernel_size=1, padding=0)
        self.transblock0 = attention_conv_block_V3(channel_n1 * 1 + growth_rate * n_layers, channel_n1 * 2)

        self.denseblock1 = DenseBlock(channel_n1 * 2, growth_rate, drop_rate, layer_type, dilation_list[0])
        # self.transblock1 = ConvLayer(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2, kernel_size=1, padding=0)
        self.transblock1 = attention_conv_block_V3(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2)

        self.denseblock2 = DenseBlock(channel_n1 * 2, growth_rate, drop_rate, layer_type, dilation_list[1])
        # self.transblock2 = ConvLayer(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2, kernel_size=1, padding=0)
        self.transblock2 = attention_conv_block_V3(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2)

        self.denseblock3 = DenseBlock(channel_n1 * 2, growth_rate, drop_rate, layer_type, dilation_list[2])
        # self.transblock3 = ConvLayer(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2, kernel_size=1, padding=0)
        self.transblock3 = attention_conv_block_V3(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2)

        self.denseblock4 = DenseBlock(channel_n1 * 2, growth_rate, drop_rate, layer_type, dilation_list[3])
        # self.transblock4 = ConvLayer(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2, kernel_size=1, padding=0)
        self.transblock4 = attention_conv_block_V3(channel_n1 * 2 + growth_rate * n_layers, channel_n1 * 2)

        # final conv
        # self.conv2 = nn.Conv2d(channel_n1 * 4 + growth_rate * n_layers, output_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv1_0 = attention_conv_block_V3(filters[0] * 2, filters[1])  # 128, 128
        self.conv2_0 = attention_conv_block_V3(filters[1], filters[2])  # 128, 256
        self.conv3_0 = attention_conv_block_V3(filters[2], filters[3])  # 256, 512
        self.conv4_0 = attention_conv_block_V3(filters[3], filters[4])  # 512, 1024

        # self.conv0_1 = attention_conv_block(filters[0] + filters[1], filters[0])
        self.conv1_1 = attention_conv_block_V3(filters[0] * 2 + filters[1] + filters[2], filters[1])
        self.conv2_1 = attention_conv_block_V3(filters[1] + filters[2] + filters[3], filters[2])
        self.conv3_1 = attention_conv_block_V3(filters[2] + filters[3] + filters[4], filters[3])

        # self.conv0_2 = attention_conv_block(filters[0] * 2 + filters[1], filters[0])
        self.conv1_2 = attention_conv_block_V3(filters[0] * 2 + filters[1] * 2 + filters[2], filters[1])
        self.conv2_2 = attention_conv_block_V3(filters[1] + filters[2] * 2 + filters[3], filters[2])

        # self.conv0_3 = attention_conv_block(filters[0] * 3 + filters[1], filters[0])
        self.conv1_3 = attention_conv_block_V3(filters[0] * 2 + filters[1] * 3 + filters[2], filters[1])

        # self.conv0_4 = attention_conv_block(filters[0] * 4 + filters[1], filters[0])

        self.final_conv = nn.Conv2d(filters[0] * 2, output_channels, kernel_size=1)
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        x0_0 = self.conv1(x)  # [4, 64, 208, 208]
        x0_0 = self.denseblock0(x0_0)
        # x0_0 = self.csSE0_0(x0_0)
        x0_0 = self.transblock0(x0_0)  # [4, 128, 208, 208]

        _x0_0 = self.pool(x0_0)
        x1_0 = self.conv1_0(_x0_0)

        _x1_0 = self.Up(x1_0)
        x1_0_up = resizeScale(x0_0, _x1_0)
        # x0_1 = self.denseblock1(torch.cat([x0_0, self.Up(x1_0)], 1))  # [4, 64*1+128+144=336, 208, 208]
        x0_1 = self.denseblock1(x0_0 + x1_0_up)
        # x0_1 = self.csSE0_1(x0_1)
        x0_1 = self.transblock1(x0_1)  # [4, 128, 208, 208]

        x2_0 = self.conv2_0(self.pool(x1_0))
        _x2_0 = self.Up(x2_0)
        x2_0_up = resizeScale(x1_0, _x2_0)
        x1_1 = self.conv1_1(torch.cat([self.pool(x0_1), x1_0, x2_0_up], 1))
        x1_1_up = resizeScale(x0_1, self.Up(x1_1))
        # x0_2 = self.denseblock2(torch.cat([x0_1, self.Up(x1_1)], 1))  # [4, 64*1+128+144=400, 208, 208]
        x0_2 = self.denseblock2(x0_1 + x1_1_up)
        # x0_2 = self.csSE0_2(x0_2)
        x0_2 = self.transblock2(x0_2)  # [4, 128, 208, 208]

        x3_0 = self.conv3_0(self.pool(x2_0))
        x3_0_up = resizeScale(x2_0, self.Up(x3_0))
        x2_1 = self.conv2_1(torch.cat([self.pool(x1_1), x2_0, x3_0_up], 1))
        x2_1_up = resizeScale(x1_1, self.Up(x2_1))
        x1_2 = self.conv1_2(torch.cat([self.pool(x0_2), x1_0, x1_1, x2_1_up], 1))
        x1_2_up = resizeScale(x0_2, self.Up(x1_2))
        # x0_3 = self.denseblock3(torch.cat([x0_2, self.Up(x1_2)], 1))  # [4, 64*1+128+144=464, 208, 208]
        x0_3 = self.denseblock3(x0_2 + x1_2_up)
        # x0_3 = self.csSE0_3(x0_3)
        x0_3 = self.transblock3(x0_3)  # [4, 128, 208, 208]

        x4_0 = self.conv4_0(self.pool(x3_0))
        x4_0_up = resizeScale(x3_0, self.Up(x4_0))
        x3_1 = self.conv3_1(torch.cat([self.pool(x2_1), x3_0, x4_0_up], 1))
        x3_1_up = resizeScale(x2_1, self.Up(x3_1))
        x2_2 = self.conv2_2(torch.cat([self.pool(x1_2), x2_0, x2_1, x3_1_up], 1))
        x2_2_up = resizeScale(x1_2, self.Up(x2_2))
        x1_3 = self.conv1_3(torch.cat([self.pool(x0_3), x1_0, x1_1, x1_2, x2_2_up], 1))
        x1_3_up = resizeScale(x0_3, self.Up(x1_3))
        # x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))
        # x0_4 = self.denseblock4(torch.cat([x0_3, self.Up(x1_3)], 1))  # [4, 64*1+128+144=528, 208, 208]
        x0_4 = self.denseblock4(x0_3 + x1_3_up)
        # x0_4 = self.csSE0_4(x0_4)
        x0_4 = self.transblock4(x0_4)

        x_final = self.final_conv(x0_4)

        return x_final
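

# Minimal smoke test (a sketch; the input size and channel counts here are
# illustrative assumptions, not values from the original training setup):
if __name__ == '__main__':
    net = HanNet(in_channels=3, output_channels=2)
    out = net(torch.randn(1, 3, 96, 96))
    print(out.shape)  # expected: torch.Size([1, 2, 96, 96])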