|
|
@@ -0,0 +1,140 @@ |
|
|
|
# -*- coding:utf-8 _*- |
|
|
|
""" |
|
|
|
@Author : Cui Baoyi |
|
|
|
@Time : 2021/03/14 16:00 |
|
|
|
""" |
|
|
|
''' 该文件定义了图像分割使用的Unet ''' |
|
|
|
from collections import OrderedDict |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
class UNet(nn.Module): |
|
|
|
def __init__(self, in_channels=3, out_channels=1, init_features=32, WithActivateLast = True, ActivateFunLast = None): |
|
|
|
super(UNet, self).__init__() |
|
|
|
features = init_features |
|
|
|
self.WithActivateLast = WithActivateLast # True:则最后一层输出增加激活,False:最后一层输出不加激活 |
|
|
|
self.ActivateFunLast = ActivateFunLast # 如果需要激活层,设置最后激活层函数 |
|
|
|
self.encoder1 = UNet._block(in_channels, features, name="enc1") |
|
|
|
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
self.encoder2 = UNet._block(features, features * 2, name="enc2") |
|
|
|
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
self.encoder3 = UNet._block(features * 2, features * 4, name="enc3") |
|
|
|
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
self.encoder4 = UNet._block(features * 4, features * 8, name="enc4") |
|
|
|
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
|
|
|
|
self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck") |
|
|
|
|
|
|
|
self.upconv4 = nn.ConvTranspose2d( |
|
|
|
features * 16, features * 8, kernel_size=2, stride=2 |
|
|
|
) |
|
|
|
self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4") |
|
|
|
self.upconv3 = nn.ConvTranspose2d( |
|
|
|
features * 8, features * 4, kernel_size=2, stride=2 |
|
|
|
) |
|
|
|
self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3") |
|
|
|
self.upconv2 = nn.ConvTranspose2d( |
|
|
|
features * 4, features * 2, kernel_size=2, stride=2 |
|
|
|
) |
|
|
|
self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2") |
|
|
|
self.upconv1 = nn.ConvTranspose2d( |
|
|
|
features * 2, features, kernel_size=2, stride=2 |
|
|
|
) |
|
|
|
self.decoder1 = UNet._block(features * 2, features, name="dec1") |
|
|
|
self.conv = nn.Conv2d( |
|
|
|
in_channels=features, out_channels=out_channels, kernel_size=1 |
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
enc1 = self.encoder1(x) |
|
|
|
enc2 = self.encoder2(self.pool1(enc1)) |
|
|
|
enc3 = self.encoder3(self.pool2(enc2)) |
|
|
|
enc4 = self.encoder4(self.pool3(enc3)) |
|
|
|
|
|
|
|
bottleneck = self.bottleneck(self.pool4(enc4)) |
|
|
|
|
|
|
|
dec4 = self.upconv4(bottleneck) |
|
|
|
dec4 = torch.cat((dec4, enc4), dim=1) |
|
|
|
dec4 = self.decoder4(dec4) |
|
|
|
dec3 = self.upconv3(dec4) |
|
|
|
dec3 = torch.cat((dec3, enc3), dim=1) |
|
|
|
dec3 = self.decoder3(dec3) |
|
|
|
dec2 = self.upconv2(dec3) |
|
|
|
dec2 = torch.cat((dec2, enc2), dim=1) |
|
|
|
dec2 = self.decoder2(dec2) |
|
|
|
dec1 = self.upconv1(dec2) |
|
|
|
dec1 = torch.cat((dec1, enc1), dim=1) |
|
|
|
dec1 = self.decoder1(dec1) # 2*32*256*256 |
|
|
|
if self.WithActivateLast: |
|
|
|
# return torch.sigmoid(self.conv(dec1)) # BS*1*256*256 |
|
|
|
return self.ActivateFunLast(self.conv(dec1)) |
|
|
|
else: |
|
|
|
return self.conv(dec1) # BS*1*256*256 |
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
|
staticmethod用于修饰类中的方法,使其可以在不创建类实例的情况下调用方法,这样做的好处是执行效率比较高。 |
|
|
|
当然,也可以像一般的方法一样用实例调用该方法。该方法一般被称为静态方法。静态方法不可以引用类中的属性或方法, |
|
|
|
其参数列表也不需要约定的默认参数self。我个人觉得,静态方法就是类对外部函数的封装,有助于优化代码结构和提高程序的可读性。 |
|
|
|
当然了,被封装的方法应该尽可能的和封装它的类的功能相匹配。 |
|
|
|
''' |
|
|
|
@staticmethod |
|
|
|
def _block(in_channels, features, name): |
|
|
|
return nn.Sequential( |
|
|
|
OrderedDict( # 用字典的形式进行网络定义,字典key即为网络每一层的名称 |
|
|
|
[ |
|
|
|
( |
|
|
|
name + "conv1", |
|
|
|
nn.Conv2d( |
|
|
|
in_channels=in_channels, |
|
|
|
out_channels=features, |
|
|
|
kernel_size=3, |
|
|
|
padding=1, |
|
|
|
bias=False, |
|
|
|
), |
|
|
|
), |
|
|
|
(name + "norm1", nn.BatchNorm2d(num_features=features)), |
|
|
|
(name + "relu1", nn.ReLU(inplace=True)), |
|
|
|
( |
|
|
|
name + "conv2", |
|
|
|
nn.Conv2d( |
|
|
|
in_channels=features, |
|
|
|
out_channels=features, |
|
|
|
kernel_size=3, |
|
|
|
padding=1, |
|
|
|
bias=False, |
|
|
|
), |
|
|
|
), |
|
|
|
(name + "norm2", nn.BatchNorm2d(num_features=features)), |
|
|
|
(name + "relu2", nn.ReLU(inplace=True)), |
|
|
|
] |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
Input = torch.randn((2, 1, 256, 256)) # 任意生成样本 |
|
|
|
Target = torch.empty((2, 1, 256, 256), dtype=torch.long).random_(2) # 任意生成标签 |
|
|
|
|
|
|
|
Unet = UNet(in_channels=1, out_channels=2) # 2分类输出为两个通道 |
|
|
|
LossFun = nn.CrossEntropyLoss() |
|
|
|
Output = Unet(Input) |
|
|
|
print(Output.shape) |
|
|
|
print(Target.shape) |
|
|
|
|
|
|
|
# %% 官方函数计算CrossEntropyLoss |
|
|
|
BatchLoss = LossFun(Output, Target[:, 0, :, :]) # CrossEntropyLoss的Target必须没有通道的维度,即(BatchSize, W, H) |
|
|
|
print(BatchLoss) |
|
|
|
|
|
|
|
# %% 手写验证CrossEntropyLoss |
|
|
|
Errs = [] |
|
|
|
for i, Sample in enumerate(Output): # 遍历每个样本,每个像素点 |
|
|
|
for j in range(256): |
|
|
|
for k in range(256): |
|
|
|
temppredict = Output[i, :, j, k] # 每个像素的5个预测概率 |
|
|
|
temptarget = Target[i, 0, j, k] # 每个像素的真实归类 |
|
|
|
err = -temppredict[temptarget] + torch.log(torch.sum(np.e ** temppredict)) # 计算每个样本交叉熵 |
|
|
|
Errs.append(err.detach().numpy()) |
|
|
|
print(np.mean(Errs)) |