- """
- Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
- Licensed under the CC BY-NC-SA 4.0 license
- (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
- """
- import torch
- import torch.nn.functional as F
- from torch import nn
-
-
class ResBlocks(nn.Module):
    """A stack of ``num_blocks`` residual blocks operating at a fixed channel width ``dim``."""

    def __init__(self, num_blocks, dim, norm, activation, pad_type):
        super(ResBlocks, self).__init__()
        self.model = []
        for _ in range(num_blocks):
            self.model += [ResBlock(dim,
                                    norm=norm,
                                    activation=activation,
                                    pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)
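

# Usage sketch (not part of the original file): stacking four residual blocks
# over a 64-channel feature map. Channel count and spatial size are preserved
# end to end; the hyper-parameter values below are illustrative assumptions.
def _resblocks_example():
    blocks = ResBlocks(num_blocks=4, dim=64, norm='in',
                       activation='relu', pad_type='reflect')
    feats = torch.randn(2, 64, 32, 32)
    out = blocks(feats)
    assert out.shape == feats.shape
    return out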


class ResBlock(nn.Module):
    """A residual block of two 3x3 convolutions with an identity skip connection.

    The second convolution has no activation, so the non-linearity lives only
    inside the residual branch; input and output share the channel count ``dim``.
    """

    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()
        model = []
        model += [Conv2dBlock(dim, dim, 3, 1, 1,
                              norm=norm,
                              activation=activation,
                              pad_type=pad_type)]
        model += [Conv2dBlock(dim, dim, 3, 1, 1,
                              norm=norm,
                              activation='none',
                              pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out


class ActFirstResBlock(nn.Module):
    """A pre-activation residual block: the activation is applied before each convolution.

    When the input and output channel counts differ, a learned 1x1 convolution
    projects the shortcut; otherwise the identity is used.
    """

    def __init__(self, fin, fout, fhid=None,
                 activation='lrelu', norm='none'):
        super().__init__()
        self.learned_shortcut = (fin != fout)
        self.fin = fin
        self.fout = fout
        self.fhid = min(fin, fout) if fhid is None else fhid
        self.conv_0 = Conv2dBlock(self.fin, self.fhid, 3, 1,
                                  padding=1, pad_type='reflect', norm=norm,
                                  activation=activation, activation_first=True)
        self.conv_1 = Conv2dBlock(self.fhid, self.fout, 3, 1,
                                  padding=1, pad_type='reflect', norm=norm,
                                  activation=activation, activation_first=True)
        if self.learned_shortcut:
            self.conv_s = Conv2dBlock(self.fin, self.fout, 1, 1,
                                      activation='none', use_bias=False)

    def forward(self, x):
        x_s = self.conv_s(x) if self.learned_shortcut else x
        dx = self.conv_0(x)
        dx = self.conv_1(dx)
        out = x_s + dx
        return out
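

# Usage sketch (not part of the original file): an activation-first block that
# changes the channel count, so the learned 1x1 shortcut is instantiated.
# The sizes below are illustrative assumptions.
def _act_first_resblock_example():
    block = ActFirstResBlock(fin=64, fout=128, activation='lrelu', norm='none')
    x = torch.randn(2, 64, 16, 16)
    out = block(x)                      # learned_shortcut is True here
    assert out.shape == (2, 128, 16, 16)
    return out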


class LinearBlock(nn.Module):
    """A fully connected layer followed by optional normalization and activation."""

    def __init__(self, in_dim, out_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        self.fc = nn.Linear(in_dim, out_dim, bias=use_bias)

        # initialize normalization
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            # note: nn.InstanceNorm1d expects an (N, C, L) input, so this
            # option may fail on the 2D (N, C) output of self.fc with
            # recent PyTorch versions
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
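

# Usage sketch (not part of the original file): a small MLP assembled from
# LinearBlock layers, with the activation disabled on the output layer.
# The layer widths are illustrative assumptions.
def _linear_block_example():
    mlp = nn.Sequential(LinearBlock(8, 256, norm='none', activation='relu'),
                        LinearBlock(256, 256, norm='none', activation='relu'),
                        LinearBlock(256, 64, norm='none', activation='none'))
    z = torch.randn(4, 8)
    return mlp(z)                       # -> shape (4, 64)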


class Conv2dBlock(nn.Module):
    """A convolution with explicit padding and optional normalization and activation.

    Padding is applied as a separate layer so that reflection or replication
    padding can be used instead of the convolution's built-in zero padding.
    When ``activation_first`` is True the activation is applied before the
    convolution (pre-activation ordering); otherwise it follows normalization.
    """

    def __init__(self, in_dim, out_dim, ks, st, padding=0,
                 norm='none', activation='relu', pad_type='zero',
                 use_bias=True, activation_first=False):
        super(Conv2dBlock, self).__init__()
        self.use_bias = use_bias
        self.activation_first = activation_first
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = out_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # padding is handled by self.pad above, so the convolution itself is unpadded
        self.conv = nn.Conv2d(in_dim, out_dim, ks, st, bias=self.use_bias)

    def forward(self, x):
        if self.activation_first:
            if self.activation:
                x = self.activation(x)
            x = self.conv(self.pad(x))
            if self.norm:
                x = self.norm(x)
        else:
            x = self.conv(self.pad(x))
            if self.norm:
                x = self.norm(x)
            if self.activation:
                x = self.activation(x)
        return x
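

# Usage sketch (not part of the original file): a reflection-padded 7x7
# convolution in the default conv -> norm -> activation order, and a 3x3
# layer with pre-activation ordering. The sizes are illustrative assumptions.
def _conv2d_block_example():
    post_act = Conv2dBlock(3, 64, 7, 1, padding=3,
                           norm='in', activation='relu', pad_type='reflect')
    pre_act = Conv2dBlock(3, 64, 3, 1, padding=1,
                          norm='none', activation='lrelu',
                          pad_type='reflect', activation_first=True)
    x = torch.randn(1, 3, 64, 64)
    y1 = post_act(x)
    y2 = pre_act(x.clone())             # clone: the pre-activation path is in-place
    return y1, y2                       # both -> shape (1, 64, 64, 64)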


class AdaptiveInstanceNorm2d(nn.Module):
    """Adaptive instance normalization (AdaIN).

    The affine parameters ``weight`` and ``bias`` are not learned here; they
    must be assigned externally (e.g. from the output of a style MLP) before
    the forward pass, which is what the assertion below enforces.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = None
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and \
            self.bias is not None, "Please assign AdaIN weight first"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)
        # fold the batch dimension into the channel dimension so that
        # F.batch_norm normalizes each (sample, channel) pair independently
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
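

# Usage sketch (not part of the original file): AdaIN layers expect their
# weight and bias to be written in from outside before forward is called.
# The helper below is an assumption modeled on how such parameters are
# commonly distributed from a style code; the real project may organize
# this differently.
def _assign_adain_params_example(adain_params, model):
    # adain_params: tensor of shape (batch, 2 * total AdaIN features),
    # holding a bias followed by a weight for every AdaIN feature in `model`.
    for m in model.modules():
        if isinstance(m, AdaptiveInstanceNorm2d):
            mean = adain_params[:, :m.num_features]
            std = adain_params[:, m.num_features:2 * m.num_features]
            # flatten (batch, features) to match the b * c channels used
            # inside AdaptiveInstanceNorm2d.forward
            m.bias = mean.contiguous().view(-1)
            m.weight = std.contiguous().view(-1)
            if adain_params.size(1) > 2 * m.num_features:
                adain_params = adain_params[:, 2 * m.num_features:]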