|
- #from base.base_model import BaseModel
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from itertools import chain
-
-
class encoder(nn.Module):
    """Contracting-path stage of the U-Net.

    Applies two (3x3 conv, padding 1) -> BatchNorm2d -> ReLU layers, then a
    2x2 max-pool with ``ceil_mode=True`` (odd spatial sizes round up, so the
    trailing row/column is pooled rather than dropped).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(encoder, self).__init__()
        stack = []
        # First conv maps in->out, second out->out; each is followed by BN + ReLU.
        for conv_in in (in_channels, out_channels):
            stack.extend([
                nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute names and layer order define the state_dict keys
        # (down_conv.0.*, down_conv.1.*, ...); keep them stable for checkpoints.
        self.down_conv = nn.Sequential(*stack)
        self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the pre-pool activation (used as the skip connection);
        ``pooled`` is its max-pooled version (fed to the next, deeper stage).
        """
        features = self.down_conv(x)
        pooled = self.pool(features)
        return features, pooled
-
- #nn.Upsample(scale_factor=2)
- class decoder(nn.Module):
- def __init__(self, in_channels, out_channels):
- super(decoder, self).__init__()
- self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
- self.up_conv = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True),
- )
-
- def forward(self, x_copy, x):
- x = self.up(x)
- # Padding in case the incomping volumes are of different sizes
- diffY = x_copy.size()[2] - x.size()[2]
- diffX = x_copy.size()[3] - x.size()[3]
- x = F.pad(x, (diffX // 2, diffX - diffX // 2,
- diffY // 2, diffY - diffY // 2))
- # Concatenate
- x = torch.cat([x_copy, x], dim=1)
- x = self.up_conv(x)
- return x
-
-
class UNet(nn.Module):  # NOTE(review): the file hints at a BaseModel base class; plain nn.Module here
    """Four-level U-Net producing per-pixel class scores.

    Encoder widths are 64/128/256/512, the bottleneck is 1024 wide, and the
    decoder mirrors the encoder. The output has ``num_classes`` channels and
    the input's spatial size; no activation is applied to it.

    Args:
        num_classes: output channels of the final 1x1 convolution.
        in_channels: channels of the input image.
        freeze_bn: if True, every BatchNorm2d layer is kept in eval mode, so
            its running statistics are used but never updated. This persists
            across later ``model.train()`` calls. BN affine parameters still
            receive gradients.
        **_: extra keyword arguments are accepted and ignored.
    """

    def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
        super(UNet, self).__init__()
        # Registration order below fixes both the state_dict key order and
        # the RNG sequence used during init; do not reorder.
        # Contracting path.
        self.down1 = encoder(in_channels, 64)
        self.down2 = encoder(64, 128)
        self.down3 = encoder(128, 256)
        self.down4 = encoder(256, 512)
        # Bottleneck.
        self.middle_conv = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )
        # Expanding path.
        self.up1 = decoder(1024, 512)
        self.up2 = decoder(512, 256)
        self.up3 = decoder(256, 128)
        self.up4 = decoder(128, 64)
        # NOTE: `up` and `beforefinal2_conv` are not used by forward(). They
        # appear to be leftovers of a disabled auxiliary output. They are
        # kept so state_dict keys and init order stay compatible with
        # existing checkpoints.
        self.up = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.beforefinal2_conv = nn.Conv2d(128, num_classes, kernel_size=1)

        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        # Remembered so train() can re-apply the freeze (see train()).
        self._bn_frozen = bool(freeze_bn)
        self._initialize_weights()
        if self._bn_frozen:
            self.freeze_bn()

    def _initialize_weights(self):
        """Initialise Conv2d/Linear and BatchNorm2d parameters in place.

        Conv2d/Linear: Kaiming-normal weights (default arguments), zero bias.
        BatchNorm2d: weight 1, bias 0. ConvTranspose2d is not a Conv2d
        subclass, so those layers keep PyTorch's default initialisation.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def train(self, mode=True):
        """Set training mode, keeping BatchNorm frozen if requested at init.

        Without this override, the ``freeze_bn()`` call made in ``__init__``
        would be undone by the first ``model.train()``. Returns ``self``,
        like ``nn.Module.train``.
        """
        super(UNet, self).train(mode)
        # getattr default: instances unpickled from older versions of this
        # class have no _bn_frozen attribute.
        if mode and getattr(self, "_bn_frozen", False):
            self.freeze_bn()
        return self

    def forward(self, x):
        """Return per-pixel class scores with the input's spatial size."""
        x1, x = self.down1(x)
        x2, x = self.down2(x)
        x3, x = self.down3(x)
        x4, x = self.down4(x)
        x = self.middle_conv(x)
        x = self.up1(x4, x)
        x = self.up2(x3, x)
        x = self.up3(x2, x)
        x = self.up4(x1, x)
        return self.final_conv(x)

    def get_backbone_params(self):
        """Return an empty list: U-Net has no pretrained backbone."""
        return []

    def get_decoder_params(self):
        """Return all parameters; everything is trained from scratch."""
        return self.parameters()

    def freeze_bn(self):
        """Put every BatchNorm2d submodule in eval mode.

        Running statistics are then used but not updated; ``requires_grad``
        is left untouched.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
|