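The previous section showed how to define a PyTorch model, and every example there was built directly from layers in torch.nn. That style is easy to understand but not always practical. When a model is very deep, writing out every layer in a Sequential can take hundreds of lines of code, which is tedious to write and hard to maintain.
Look closely at most architectures (ResNet, DenseNet, and so on) and you will find that, although they contain many layers, much of the structure repeats. Every layer has an input and an output, and so does a "block" made of several layers chained together. If we define each repeated group of layers as a block, building the model reduces to adding the corresponding blocks to the network, which makes model construction much easier.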
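The idea can be sketched in a few lines. The example below is only an illustration and is not part of U-Net: the block name ConvBlock, the helper make_stack, and the channel widths are hypothetical choices.

```python
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Hypothetical reusable block: Conv2d -> BatchNorm2d -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.body(x)


def make_stack(widths):
    """Chain one ConvBlock per consecutive pair of channel widths."""
    pairs = zip(widths[:-1], widths[1:])
    return nn.Sequential(*[ConvBlock(c_in, c_out) for c_in, c_out in pairs])


# Four blocks (twelve layers) built from a single list of channel widths.
model = make_stack([3, 16, 32, 64, 64])
x = torch.randn(2, 3, 32, 32)
y = model(x)
print(y.shape)  # height and width are preserved: 3x3 kernel with padding=1
```

Changing the depth or width of this toy network only means editing the list passed to make_stack, which is the convenience the rest of this section relies on.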
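This section uses U-Net as an example. After working through it, you will know how to build model blocks and how to use them to assemble a complex model quickly.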
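U-Net is a landmark segmentation model that is widely used in many fields, medical imaging above all. Its structure is shown in the figure below: a contracting path on the left repeatedly downsamples the input, an expanding path on the right upsamples back toward the input resolution, and skip connections concatenate each left-side feature map with the right-side feature map at the same level, so that spatial detail lost during downsampling can be recovered.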
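As the figure shows, U-Net is highly symmetric. The model is divided from top to bottom into several levels, each consisting of a left-side and a right-side model block. Each block is connected to the blocks above and below it on the same side, and the left and right blocks of a level are also connected to each other by what is called a "skip connection". Input and output processing make up the remaining components. Because the overall shape resembles the letter "U", the model is named "U-Net".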
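U-Net is built mainly from the following model blocks:
1) The two consecutive convolutions inside each block (double convolution)
2) The downsampling connection between left-side blocks, implemented with max pooling
3) The upsampling connection between right-side blocks
4) The output-layer processing
Besides the blocks themselves, the model also needs the lateral connections between blocks and the connections at the input and at the bottom of the U-Net. These individual operations can be implemented in the forward function.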
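It helps to see what each kind of connection does to a tensor's shape. The sketch below uses arbitrary sizes (a batch of 1, 64 channels, a 32x32 feature map) chosen only for illustration, and shows two common ways to upsample.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)  # (N, C, H, W)

# Downsampling: 2x2 max pooling halves height and width, channels unchanged.
down = nn.MaxPool2d(2)(x)

# Upsampling, option 1: a stride-2 transposed convolution doubles height and
# width and, as configured here, halves the channel count.
up_learned = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)(x)

# Upsampling, option 2: bilinear interpolation doubles height and width and
# leaves the channel count unchanged.
up_bilinear = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)(x)

# Skip connection: concatenate two same-sized feature maps along the channel axis.
skip = torch.randn(1, 64, 32, 32)
merged = torch.cat([skip, x], dim=1)

print(down.shape, up_learned.shape, up_bilinear.shape, merged.shape)
```

Note that concatenation requires the two tensors to agree in every dimension except the channel axis, which is why a skip connection needs matching heights and widths.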
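Below we first implement these blocks in PyTorch and then use them to build the U-Net model.
When implementing U-Net in PyTorch, we do not need to write out every layer explicitly in order, which would be tedious and hard to read. A better approach is to define the model blocks first and then define the order in which they are connected and how data flows between them. It is like assembling a machine: we first put together some basic components and then build the whole assembly from these reusable parts.
The basic components here are the four model blocks identified above, which we name by function: DoubleConv, Down, Up, and OutConv. Their PyTorch implementation is given below: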
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # x1: feature map from the level below; x2: skip feature map from the left side
        x1 = self.up(x1)
        # tensors are NCHW: dim 2 is height, dim 3 is width
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # pad x1 so that its height and width match x2 before concatenation
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
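The padding step in Up.forward matters when the height or width of a feature map is odd: max pooling rounds down, so the upsampled map can come out one pixel smaller than the skip feature map it must be concatenated with. Here is a minimal sketch of that step in isolation, with sizes chosen only for illustration; the helper name pad_to_match is hypothetical.

```python
import torch
import torch.nn.functional as F


def pad_to_match(x1, x2):
    """Zero-pad x1 (N, C, H, W) so its height and width equal those of x2."""
    diff_y = x2.size(2) - x1.size(2)
    diff_x = x2.size(3) - x1.size(3)
    # F.pad lists pads for the last dimension first: (left, right, top, bottom).
    return F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                      diff_y // 2, diff_y - diff_y // 2])


skip = torch.randn(1, 8, 25, 25)            # left-side feature map, odd size
pooled = F.max_pool2d(skip, 2)              # 25 -> 12 (rounded down)
upsampled = F.interpolate(pooled, scale_factor=2, mode="bilinear",
                          align_corners=True)   # 12 -> 24, one short of 25
padded = pad_to_match(upsampled, skip)      # 24 -> 25
merged = torch.cat([skip, padded], dim=1)   # concatenation is now valid
print(upsampled.shape, padded.shape, merged.shape)
```

When the input height and width are divisible by 16, the four pooling steps of the U-Net in this section never round, both differences are zero, and the padding is a no-op.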
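With the blocks written, assembling the U-Net model is straightforward. Reusing the blocks markedly reduces the number of lines needed to define the whole model and also makes the code more readable.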
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
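The role of factor is easiest to see by tracing channel counts through the right side of the network. The sketch below is plain Python bookkeeping, not part of the model, and the function name trace_channels is hypothetical. It mirrors the constructor arguments of UNet and checks that, at each Up block, the concatenated tensor has exactly the in_channels that the block's DoubleConv expects.

```python
def trace_channels(bilinear):
    """Return (concat_channels, out_channels) for up1..up4 of the UNet class."""
    factor = 2 if bilinear else 1
    # Output channels of inc, down1..down4, i.e. of the feature maps x1..x5.
    enc = [64, 128, 256, 512, 1024 // factor]
    # (in_channels, out_channels) passed to Up for up1..up4.
    ups = [(1024, 512 // factor), (512, 256 // factor),
           (256, 128 // factor), (128, 64)]

    trace = []
    x = enc[-1]  # channels of x5, the bottom of the U
    for (c_in, c_out), skip in zip(ups, reversed(enc[:-1])):
        if bilinear:
            up = x              # nn.Upsample keeps the channel count
        else:
            assert x == c_in    # ConvTranspose2d expects c_in input channels
            up = c_in // 2      # and maps them to c_in // 2
        concat = skip + up      # torch.cat([x2, x1], dim=1)
        assert concat == c_in, (concat, c_in)
        trace.append((concat, c_out))
        x = c_out
    return trace


print(trace_channels(bilinear=False))
print(trace_channels(bilinear=True))
```

With bilinear upsampling the channel count is not reduced by the upsampling step itself, so each block's output is halved in advance (via factor) to keep the concatenated width equal to in_channels. As a quick check with the classes above in scope, UNet(n_channels=3, n_classes=2) applied to torch.randn(1, 3, 64, 64) should return a tensor of shape (1, 2, 64, 64).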