|
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from utils.arch_util import LayerNorm2d
-
-
- ##########################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2D convolution with 'same'-style padding (kernel_size // 2)."""
    padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=padding, bias=bias)
-
-
- ## Resizing modules
class Down(nn.Module):
    """Downscaling module: 3x3 conv halves channels, PixelUnshuffle(2) halves
    the spatial size.

    Net effect on an (N, C, H, W) input: (N, 2*C, H/2, W/2).
    """

    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2),
        )

    def forward(self, x):
        return self.body(x)
-
-
class Up(nn.Module):
    """Upscaling module: 3x3 conv doubles channels, PixelShuffle(2) doubles
    the spatial size.

    Net effect on an (N, C, H, W) input: (N, C/2, 2*H, 2*W).
    """

    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
        )

    def forward(self, x):
        return self.body(x)
-
-
class SimpleGate(nn.Module):
    """NAFNet-style gate: split channels in half, multiply the halves.

    Maps (N, C, H, W) -> (N, C/2, H, W).
    """

    def forward(self, x):
        first, second = torch.chunk(x, 2, dim=1)
        return first * second
-
-
class SR(nn.Module):
    """Cross-group attention over three equal channel groups ("r", "g", "b").

    Each third of the input channels is refined by a shared 1x1 + 3x3 conv
    stack. A simplified channel attention (global pooling + 1x1 conv) over
    r*g is gated by b through a sigmoid to form a mask, which is tiled over
    all three groups, applied to the input, and projected out by a 1x1 conv.

    NOTE(review): assumes ``n_feat`` is divisible by 3 — callers in this
    file pass multiples of 3.
    """

    def __init__(self, n_feat, bias=False):
        super().__init__()
        third = n_feat // 3
        # Shared refinement stack applied to each channel group.
        self.body = nn.Sequential(conv(third, third, kernel_size=1, bias=bias),
                                  conv(third, third, kernel_size=3, bias=bias))
        self.feat = n_feat

        # Simplified Channel Attention (as in NAFNet).
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=third, out_channels=third, kernel_size=1,
                      padding=0, stride=1, groups=1, bias=True),
        )

        self.project_out = conv(n_feat, n_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        third = self.feat // 3
        r = x[:, :third]
        g = x[:, third:2 * third]
        b = x[:, 2 * third:self.feat]

        r = self.body(r)
        g = self.body(g)
        b = self.body(b)

        # Per-channel attention from r*g, spatially gated by b.
        mask = torch.sigmoid(self.sca(r * g) * b)
        mask = torch.cat((mask, mask, mask), dim=1)
        return self.project_out(x * mask)
-
-
class LKD(nn.Module):
    """Large-kernel decomposition block (DLKCB-style).

    A 5x5 depthwise conv feeds two parallel dilated depthwise 7x7 convs
    (dilations 3 and 2). Their concatenation is projected back to ``dim``
    channels and used as a multiplicative gate on the input, followed by a
    final 1x1 projection. All convs use reflect padding and preserve the
    spatial size, so the block maps (N, dim, H, W) -> (N, dim, H, W).

    ``mlp_ratio`` is accepted for interface compatibility but unused here.

    Fix: for stride-1 'same' output with kernel k and dilation d the padding
    must be d*(k-1)/2. The previous paddings (6 and 2... resp. 4) were sized
    for a 5x5 kernel, which made the two branch outputs H-6 and H-4 — every
    forward pass crashed at torch.cat. With kernel 7 the correct paddings
    are 9 (dilation 3) and 6 (dilation 2).
    """

    def __init__(self, dim, mlp_ratio=4.):
        super().__init__()
        self.DWConv = nn.Conv2d(dim, dim, 5, padding=2, groups=dim,
                                padding_mode='reflect')
        # padding = dilation * (kernel - 1) / 2 keeps the spatial size.
        self.DWDConv1 = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim,
                                  dilation=3, padding_mode='reflect')
        self.DWDConv2 = nn.Conv2d(dim, dim, 7, stride=1, padding=6, groups=dim,
                                  dilation=2, padding_mode='reflect')
        self.Linear2 = nn.Conv2d(dim * 2, dim, 1)
        self.Linear3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        base = self.DWConv(x)
        # Two dilated large-kernel branches over the same base features.
        branches = torch.cat([self.DWDConv1(base), self.DWDConv2(base)], 1)
        # Fuse branches to a gate and modulate the input.
        gated = x * self.Linear2(branches)
        return self.Linear3(gated)
-
-
class NAFBlock(nn.Module):
    """NAFNet-style residual block: (SR + LKD) mixing, then a gated FFN.

    Both stages are pre-normalized with LayerNorm2d and added back to the
    residual stream through learnable, zero-initialized per-channel scales
    (``beta1`` and ``gamma``), so the block starts as an identity mapping.

    ``DW_Expand`` is accepted for interface compatibility but unused here.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        # SimpleGate halves channels, so conv5 consumes ffn_channel // 2.
        self.sg = SimpleGate()

        self.lkd = LKD(c)

        hidden = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=hidden,
                               kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv5 = nn.Conv2d(in_channels=hidden // 2, out_channels=c,
                               kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.sr = SR(c)

        use_dropout = drop_out_rate > 0.
        self.dropout1 = nn.Dropout(drop_out_rate) if use_dropout else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if use_dropout else nn.Identity()

        # Zero init: each branch contributes nothing until trained.
        self.beta1 = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # Spatial mixing branch: attention (SR) plus large-kernel (LKD).
        normed = self.norm1(inp)
        mixed = self.dropout1(self.sr(normed) + self.lkd(normed))
        y = inp + mixed * self.beta1

        # Gated point-wise feed-forward branch.
        ffn = self.conv5(self.sg(self.conv4(self.norm2(y))))
        return y + self.dropout2(ffn) * self.gamma
-
class Encoder(nn.Module):
    """Four-level encoder over per-channel (r, g, b) feature streams.

    At each level the three streams are concatenated and refined by a stack
    of NAFBlocks; between levels each stream is downscaled by a shared
    ``Down`` module (spatial /2, per-stream channels x2), so level L works
    at ``n_feat * 2**(L-1)`` total channels. Returns the four level
    outputs shallowest-first.
    """

    def __init__(self, n_feat):
        super().__init__()
        self.encoder_level1 = nn.Sequential(*(NAFBlock(n_feat) for _ in range(4)))
        self.encoder_level2 = nn.Sequential(*(NAFBlock(n_feat * 2) for _ in range(4)))
        self.encoder_level3 = nn.Sequential(*(NAFBlock(n_feat * 4) for _ in range(6)))
        self.encoder_level4 = nn.Sequential(*(NAFBlock(n_feat * 8) for _ in range(8)))

        # One shared downscaler per level transition, applied to each stream.
        self.down1 = Down(n_feat // 3)
        self.down2 = Down(n_feat * 2 // 3)
        self.down3 = Down(n_feat * 4 // 3)

    def forward(self, r, g, b):
        enc1 = self.encoder_level1(torch.cat((r, g, b), dim=1))

        r, g, b = self.down1(r), self.down1(g), self.down1(b)
        enc2 = self.encoder_level2(torch.cat((r, g, b), dim=1))

        r, g, b = self.down2(r), self.down2(g), self.down2(b)
        enc3 = self.encoder_level3(torch.cat((r, g, b), dim=1))

        r, g, b = self.down3(r), self.down3(g), self.down3(b)
        enc4 = self.encoder_level4(torch.cat((r, g, b), dim=1))

        return [enc1, enc2, enc3, enc4]
-
-
class Decoder(nn.Module):
    """Three-stage decoder with skip connections from the encoder.

    Each stage upsamples the deeper feature map (``Up``: spatial x2,
    channels /2), concatenates the matching encoder skip, fuses the pair
    with a 1x1 conv that halves the channel count, then refines with a
    stack of NAFBlocks. Takes the encoder's ``[enc1, enc2, enc3, enc4]``
    and returns the full-resolution decoded features.
    """

    def __init__(self, n_feat, bias):
        super().__init__()
        self.decoder_level1 = nn.Sequential(*(NAFBlock(n_feat) for _ in range(4)))
        self.decoder_level2 = nn.Sequential(*(NAFBlock(n_feat * 2) for _ in range(4)))
        self.decoder_level3 = nn.Sequential(*(NAFBlock(n_feat * 4) for _ in range(6)))

        # 1x1 fusion convs after each skip concatenation.
        self.leve4 = conv(n_feat * 8, n_feat * 4, kernel_size=1, bias=bias)
        self.leve3 = conv(n_feat * 4, n_feat * 2, kernel_size=1, bias=bias)
        self.leve2 = conv(n_feat * 2, n_feat, kernel_size=1, bias=bias)

        self.up21 = Up(n_feat * 2)
        self.up32 = Up(n_feat * 4)
        self.up43 = Up(n_feat * 8)

    def forward(self, outs):
        enc1, enc2, enc3, enc4 = outs

        fused = self.leve4(torch.cat([self.up43(enc4), enc3], dim=1))
        dec3 = self.decoder_level3(fused)

        fused = self.leve3(torch.cat([self.up32(dec3), enc2], dim=1))
        dec2 = self.decoder_level2(fused)

        fused = self.leve2(torch.cat([self.up21(dec2), enc1], dim=1))
        return self.decoder_level1(fused)
-
-
- ##########################################################################
class MPRNet(nn.Module):
    """Residual restoration network over per-channel R/G/B streams.

    Each input channel is lifted to ``n_feat // 3`` features by a shared
    shallow conv, the three streams are processed by the encoder/decoder
    pair, and the decoder output is projected to a residual image that is
    subtracted from the input.

    Returns ``[restored_image, predicted_residual]``.

    ``in_c`` and ``num_cab`` are accepted for interface compatibility but
    unused here.
    """

    def __init__(self, in_c=3, out_c=3, n_feat=48, num_cab=8, kernel_size=3,
                 bias=False):
        super().__init__()
        # Shared shallow extractor, applied separately to r, g and b.
        self.shallow_feat = conv(1, n_feat // 3, kernel_size, bias=bias)

        self.encoder = Encoder(n_feat)
        self.decoder = Decoder(n_feat, bias)

        self.out = conv(n_feat, out_c, kernel_size=3)

    def forward(self, x_img):
        r, g, b = x_img.split([1, 1, 1], dim=1)

        feats = self.encoder(self.shallow_feat(r),
                             self.shallow_feat(g),
                             self.shallow_feat(b))

        residual = self.out(self.decoder(feats))
        restored = x_img - residual
        return [restored, residual]
-
- ################################################################
- # dim=24
- # 17468.28 M, params: 3.36 M
-
- # dim =48
- #flops: 66482.01 M, params: 13.28 M
-
- # from thop import profile
-
- # print('==> Building model..')
- # model = MPRNet()
-
- # input = torch.randn(1, 3, 256, 256)
- # flops, params = profile(model, (input,))
- # print('flops: %.2f M, params: %.2f M' % (flops / 1e6, params / 1e6))
-
-
- # flops: 17070.87 M, params: 4.27 M --5*5
- # flops: 46391.17 M, params: 6.39 M
-
- # ===============================================================================================
- # Total params: 6,414,112
- # Trainable params: 6,414,112
- # Non-trainable params: 0
- # Total mult-adds (G): 17.50
- # ===============================================================================================
- # Input size (MB): 0.75
- # Forward/backward pass size (MB): 1287.50
- # Params size (MB): 24.47
- # Estimated Total Size (MB): 1312.72
- # ===============================================================================================
|