|
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import numpy as np
- from torch.nn.parameter import Parameter
- from torchvision.transforms.functional import scale
- from torchvision.utils import make_grid
- import matplotlib
- import matplotlib.cm as cm
- from PIL import Image
- from torch.utils.tensorboard import SummaryWriter
-
# Run on GPU when available, else CPU; used by all .to(device) calls below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
def colorize(tensor, vmin=0, vmax=0.4, cmap="turbo"):
    """Map a 2-D array to an RGB image (H, W, 3) via a matplotlib colormap."""
    assert tensor.ndim == 2
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    rgba = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba(tensor)
    # drop the alpha channel, keep RGB
    return rgba[..., :3]
-
# save images to tensorboard
def save_img(writer, tensor, tag, step, vmin=0, vmax=0.4, color=True):
    """Write a single-column image grid to TensorBoard, optionally colorized."""
    grid = make_grid(tensor.detach(), nrow=1).cpu().numpy()  # CHW
    if color:
        # take the first channel only (HW) and map it through the colormap
        grid = colorize(grid[0], vmin, vmax).transpose(2, 0, 1)  # back to CHW
    writer.add_image(tag, grid, step)
-
-
class LowerBound(nn.Module):
    """
    Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    The derivative is replaced by the identity function when `x` is moved
    towards the `bound`, otherwise the gradient is kept to zero.
    """

    def __init__(self, bound):
        super(LowerBound,self).__init__()
        # Buffer, not a parameter: moves with .to(device) but is not trained.
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        # Custom-gradient path; excluded from TorchScript compilation.
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        if torch.jit.is_scripting():
            # Scripted graphs cannot call autograd.Function; fall back to plain max.
            return torch.max(x, self.bound)
        return self.lower_bound(x)
-
class Low_bound(torch.autograd.Function):
    """
    Clamp values to a 1e-6 lower bound with a pass-through-style gradient.

    Forward: `torch.clamp(x, min=1e-6)`.
    Backward: the gradient passes through wherever the input is above the
    bound, OR wherever the upstream gradient is negative (i.e. the update
    would push a clamped value back above the bound); elsewhere it is zero.
    This matches the semantics of `LowerBoundFunction` below.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.clamp(x, min=1e-6)

    @staticmethod
    def backward(ctx, g):
        x, = ctx.saved_tensors
        # Fixed: the previous implementation zeroed the gradient everywhere
        # x < 1e-6 *before* multiplying by the pass-through mask, which
        # cancelled the `g < 0` pass-through case entirely. It also
        # round-tripped through numpy/CPU; this stays on the input's device.
        pass_through = (x >= 1e-6) | (g < 0.0)
        return g * pass_through.to(g.dtype)
-
class Entropy_bottleneck(nn.Module):
    """
    Fully factorized entropy model.

    Learns a per-channel cumulative distribution with a small stack of
    matrix/bias/factor layers and returns the quantized input together with
    the likelihood of each quantized value under that distribution.
    """
    def __init__(self,channel, qstep=1,
                 init_scale=10,filters = (3,3,3),likelihood_bound=1e-6,
                 tail_mass=1e-9,optimize_integer_offset=True):
        super(Entropy_bottleneck,self).__init__()

        self.qstep=qstep  # quantization step size

        self.filters = tuple(int(t) for t in filters)
        self.init_scale = float(init_scale)
        self.likelihood_bound = float(likelihood_bound)
        self.tail_mass = float(tail_mass)

        self.optimize_integer_offset = bool(optimize_integer_offset)

        if not 0 < self.tail_mass < 1:
            raise ValueError(
                "`tail_mass` must be between 0 and 1")
        filters = (1,) + self.filters + (1,)
        scale = self.init_scale ** (1.0 / (len(self.filters) + 1))
        self._matrices = nn.ParameterList([])
        self._bias = nn.ParameterList([])
        self._factor = nn.ParameterList([])

        # One (matrix, bias[, factor]) layer per stage of the cumulative
        # model; parameter shapes are (channel, fan_out, fan_in).
        for i in range(len(self.filters) + 1):

            # chosen so that softplus(matrix) starts near 1/(scale*fan_out)
            init = np.log(np.expm1(1.0 / scale / filters[i + 1]))
            self.matrix = Parameter(torch.FloatTensor(channel, filters[i + 1], filters[i]))
            self.matrix.data.fill_(init)
            self._matrices.append(self.matrix)
            self.bias = Parameter(torch.FloatTensor(channel, filters[i + 1], 1))
            # biases start at random offsets within one quantization bin
            noise = np.random.uniform(-self.qstep/2, self.qstep/2, self.bias.size())
            noise = torch.FloatTensor(noise)
            self.bias.data.copy_(noise)
            self._bias.append(self.bias)

            # the last stage has no tanh-gated factor
            if i < len(self.filters):
                self.factor = Parameter(torch.FloatTensor(channel, filters[i + 1], 1))
                self.factor.data.fill_(0.0)
                self._factor.append(self.factor)

    def _logits_cumulative(self,logits,stop_gradient):
        """
        Evaluate the logits of the learned cumulative distribution at
        `logits` (shape (channel, 1, n)). With `stop_gradient=True` the
        model parameters are detached so only the inputs get gradients.
        """
        for i in range(len(self.filters) + 1):

            # softplus keeps matrix entries positive -> monotone cumulative
            matrix = F.softplus(self._matrices[i])
            if stop_gradient:
                matrix = matrix.detach()
            logits = torch.matmul(matrix, logits)

            bias = self._bias[i]
            if stop_gradient:
                bias = bias.detach()
            logits += bias

            if i < len(self._factor):
                # bounded tanh-gated non-linearity
                factor = torch.tanh(self._factor[i])
                if stop_gradient:
                    factor = factor.detach()
                logits += factor * torch.tanh(logits)
        return logits

    def add_noise(self, x):
        """Additive uniform noise in [-qstep/2, qstep/2): differentiable quantization proxy."""
        noise = np.random.uniform(-self.qstep/2, self.qstep/2, x.size())
        noise = torch.Tensor(noise).to(device)

        return x + noise


    def forward(self, x,training):
        """
        Quantize `x` (NCHW) and return `(x_hat, likelihoods)`, both NCHW.
        """
        # reshape to (C, 1, N*H*W): the model is factorized per channel
        x = x.permute(1,0,2,3).contiguous()
        shape = x.size()
        x = x.view(shape[0],1,-1)
        # noise during training, hard rounding at inference
        if training:
            x = self.add_noise(x)
        else:
            x = torch.round(x/self.qstep)*self.qstep
        # probability mass of one bin: c(x + q/2) - c(x - q/2)
        lower = self._logits_cumulative(x - self.qstep/2, stop_gradient=False)
        upper = self._logits_cumulative(x + self.qstep/2, stop_gradient=False)

        # evaluate the sigmoids on the numerically stabler side (sign trick)
        sign = -torch.sign(torch.add(lower, upper))
        sign = sign.detach()
        likelihood = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower))/self.qstep

        if self.likelihood_bound > 0:
            # floor the likelihoods to avoid log(0) in the rate loss
            likelihood = Low_bound.apply(likelihood)

        # restore NCHW layout
        likelihood = likelihood.view(shape)
        likelihood = likelihood.permute(1, 0, 2, 3)
        x = x.view(shape)
        x = x.permute(1, 0, 2, 3)
        return x, likelihood
-
-
class NonNegativeParametrizer(nn.Module):
    """
    Non negative reparametrization.
    Used for stability during training.
    """

    def __init__(self, minimum=0, reparam_offset=2 ** -18):
        super().__init__()
        self.minimum = float(minimum)
        self.reparam_offset = float(reparam_offset)

        # the pedestal keeps the reparametrized value strictly positive
        self.register_buffer("pedestal", torch.Tensor([self.reparam_offset ** 2]))
        self.lower_bound = LowerBound((self.minimum + self.reparam_offset ** 2) ** 0.5)

    def init(self, x):
        """Map an initial non-negative value into reparametrized space."""
        shifted = x + self.pedestal
        return torch.sqrt(torch.max(shifted, self.pedestal))

    def forward(self, x):
        bounded = self.lower_bound(x)
        return bounded ** 2 - self.pedestal
-
class LowerBoundFunction(torch.autograd.Function):
    """
    Autograd function for the `LowerBound` operator.

    Forward clamps from below; backward passes the gradient through where
    the input is above the bound or where the update would raise it.
    """

    @staticmethod
    def forward(ctx, input_, bound):
        ctx.save_for_backward(input_, bound)
        return torch.max(input_, bound)

    @staticmethod
    def backward(ctx, grad_output):
        input_, bound = ctx.saved_tensors
        keep = (input_ >= bound) | (grad_output < 0)
        return grad_output * keep.type(grad_output.dtype), None
-
-
class RoundNoGradient(torch.autograd.Function):
    """
    Round `x` to the nearest multiple of `qstep` in the forward pass;
    straight-through (identity) gradient in the backward pass.
    """

    @staticmethod
    def forward(ctx, x, qstep):
        return (x/qstep).round()*qstep

    @staticmethod
    def backward(ctx, g):
        # Fixed: backward must return one gradient per forward input.
        # Returning a single tensor raised "expected 2 gradients, got 1"
        # as soon as backward ran. `qstep` is non-differentiable -> None.
        return g, None
-
class GDN(nn.Module):
    r"""
    Generalized Divisive Normalization layer.

    Introduced in `"Density Modeling of Images Using a Generalized Normalization
    Transformation" <https://arxiv.org/abs/1511.06281>`_,
    by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016).

    .. math::

       y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}}
    """

    def __init__(self, in_channels, inverse=False, beta_min=1e-6, gamma_init=0.1):
        super().__init__()
        self.inverse = bool(inverse)

        # beta and gamma are stored in reparametrized space so they stay
        # non-negative during training
        self.beta_reparam = NonNegativeParametrizer(minimum=float(beta_min))
        self.beta = nn.Parameter(self.beta_reparam.init(torch.ones(in_channels)))

        self.gamma_reparam = NonNegativeParametrizer()
        self.gamma = nn.Parameter(
            self.gamma_reparam.init(float(gamma_init) * torch.eye(in_channels)))

    def forward(self, x):
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma).reshape(C, C, 1, 1)
        # per-pixel weighted sum of squared channels, plus beta bias
        norm = torch.conv2d(x ** 2, gamma, beta)

        # IGDN multiplies by sqrt(norm); GDN divides (multiplies by rsqrt)
        norm = torch.sqrt(norm) if self.inverse else torch.rsqrt(norm)
        return x * norm
-
-
def conv3x3(in_ch, out_ch, stride=1):
    """3x3 convolution; padding=1 preserves spatial size at stride 1."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride)
-
def subpel_conv3x3(in_ch, out_ch, r=1):
    """3x3 sub-pixel convolution for up-sampling by a factor of ``r``."""
    # conv produces r^2 copies of each output channel, PixelShuffle
    # rearranges them into an r-times larger spatial grid
    conv = nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1)
    return nn.Sequential(conv, nn.PixelShuffle(r))
-
def conv1x1(in_ch, out_ch, stride=1):
    """1x1 (pointwise) convolution."""
    return nn.Conv2d(in_ch, out_ch, stride=stride, kernel_size=1)
-
class ResidualBlockWithStride(nn.Module):
    """
    Residual block with a stride on the first convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.gdn = GDN(out_ch)
        # project the identity branch whenever the output shape changes
        needs_projection = stride != 1 or in_ch != out_ch
        self.skip = conv1x1(in_ch, out_ch, stride=stride) if needs_projection else None

    def forward(self, x):
        out = self.gdn(self.conv2(self.leaky_relu(self.conv1(x))))
        shortcut = x if self.skip is None else self.skip(x)
        return out + shortcut
-
class ResidualBlockUpsample(nn.Module):
    """
    Residual block with sub-pixel upsampling on the last convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch, out_ch, upsample=2):
        super().__init__()
        self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv = conv3x3(out_ch, out_ch)
        self.igdn = GDN(out_ch, inverse=True)
        # separately-parameterized upsampling path for the identity branch
        self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)

    def forward(self, x):
        out = self.igdn(self.conv(self.leaky_relu(self.subpel_conv(x))))
        return out + self.upsample(x)
-
-
class ResidualBlock(nn.Module):
    """
    Simple residual block with two 3x3 convolutions.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        # 1x1 projection on the identity branch when channel counts differ
        self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else None

    def forward(self, x):
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu(self.conv2(out))
        shortcut = x if self.skip is None else self.skip(x)
        return out + shortcut
-
class ResidualBlockWithGDN(nn.Module):
    """
    Residual block with two 3x3 convolutions and a GDN after the first one.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        # 1x1 projection on the identity branch when channel counts differ
        self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else None
        self.gdn = GDN(out_ch)

    def forward(self, x):
        out = self.leaky_relu(self.gdn(self.conv1(x)))
        out = self.leaky_relu(self.conv2(out))
        shortcut = x if self.skip is None else self.skip(x)
        return out + shortcut
-
class ResidualUnit(nn.Module):
    """
    Simple residual unit: 1x1 reduce -> 3x3 -> 1x1 expand, ReLU throughout.
    """

    def __init__(self, N):
        super().__init__()
        mid = N // 2
        self.conv = nn.Sequential(
            conv1x1(N, mid),
            nn.ReLU(inplace=True),
            conv3x3(mid, mid),
            nn.ReLU(inplace=True),
            conv1x1(mid, N),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv(x) + x)
-
class AttentionBlock(nn.Module):
    """
    Self attention block.

    Simplified variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N):
        super().__init__()
        # trunk branch: three residual units
        self.conv_a = nn.Sequential(*[ResidualUnit(N) for _ in range(3)])
        # mask branch: three residual units plus a 1x1 conv before the gate
        self.conv_b = nn.Sequential(
            *([ResidualUnit(N) for _ in range(3)] + [conv1x1(N, N)])
        )

    def forward(self, x):
        gate = torch.sigmoid(self.conv_b(x))
        return x + self.conv_a(x) * gate
-
-
class Distribution_for_GMM_entropy(nn.Module):
    """
    Mixture of three Gaussian distributions.

    Computes the likelihood of quantized values `x` under a 3-component GMM
    whose parameters are packed along the channel dim of `gmm_params` as
    (prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2),
    each slice with `N` channels.
    """
    def __init__(self,N, qstep=1):
        super(Distribution_for_GMM_entropy,self).__init__()
        self.N=N          # channels per parameter map
        self.qstep=qstep  # quantization step size

    def forward(self, x, gmm_params):
        prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2=\
            torch.split(gmm_params, self.N, dim=1)

        # scales are predicted in log-space
        scale0 = torch.exp(scale0)
        scale1 = torch.exp(scale1)
        scale2 = torch.exp(scale2)

        # floor the scales for numerical stability
        scale0 = torch.clamp(scale0, min=1e-6)
        scale1 = torch.clamp(scale1, min=1e-6)
        scale2 = torch.clamp(scale2, min=1e-6)

        # NOTE: PyTorch dataloaders yield NCHW batches (TensorFlow uses NHWC).
        # Mixture weights normalized across the new component dim=1.
        probs = torch.stack([prob0, prob1, prob2], dim=1)
        probs = torch.softmax(probs, dim=1)

        m0 = torch.distributions.normal.Normal(mean0,scale0)
        m1 = torch.distributions.normal.Normal(mean1,scale1)
        m2 = torch.distributions.normal.Normal(mean2,scale2)

        # GMM: per-component probability mass of one quantization bin
        likelihood0 = torch.abs(m0.cdf(x+self.qstep/2)-m0.cdf(x-self.qstep/2))/self.qstep
        likelihood1 = torch.abs(m1.cdf(x+self.qstep/2)-m1.cdf(x-self.qstep/2))/self.qstep
        likelihood2 = torch.abs(m2.cdf(x+self.qstep/2)-m2.cdf(x-self.qstep/2))/self.qstep

        likelihoods = probs[:,0,:,:,:]*likelihood0 + probs[:,1,:,:,:]*likelihood1 \
                      + probs[:,2,:,:,:]*likelihood2


        # Numerically stable GMM variant at the extremes: fold the whole
        # tail mass into the edge bins.
        edge_min = probs[:,0,:,:,:]*m0.cdf(x+self.qstep/2) + probs[:,1,:,:,:]*m1.cdf(x+self.qstep/2) \
                   + probs[:,2,:,:,:]*m2.cdf(x+self.qstep/2)
        edge_max = probs[:,0,:,:,:]*( 1- m0.cdf(x-self.qstep/2)) + probs[:,1,:,:,:]*(1-m1.cdf(x-self.qstep/2)) \
                   + probs[:,2,:,:,:]*(1-m2.cdf(x-self.qstep/2))
        # NOTE(review): thresholds are asymmetric (-254.5 vs 255.5) — confirm
        # this matches the intended value range.
        likelihoods = torch.where(x < -254.5, edge_min, torch.where(x > 255.5, edge_max, likelihoods))


        # floor to avoid log(0) in the rate loss
        likelihoods = Low_bound.apply(likelihoods)

        return likelihoods
-
class Distribution_for_Residual_entropy(nn.Module):
    """
    Distribution of the depth residual.

    Likelihood of quantized residual `x` under a 3-component Laplace mixture
    whose per-pixel parameters are passed in as separate tensors.
    """
    def __init__(self, qstep=1):
        super(Distribution_for_Residual_entropy,self).__init__()
        self.qstep=qstep  # quantization step size

    def forward(self, x, prob0,prob1,prob2,mean0,mean1, mean2,scale0, scale1, scale2):

        # unlike the GMM model above, scales arrive in linear space here
        # (the log-space exp() is deliberately disabled)
        # scale0 = torch.exp(scale0)
        # scale1 = torch.exp(scale1)
        # scale2 = torch.exp(scale2)

        # floor the scales for numerical stability
        scale0 = torch.clamp(scale0, min=1e-6)
        scale1 = torch.clamp(scale1, min=1e-6)
        scale2 = torch.clamp(scale2, min=1e-6)

        # NOTE: PyTorch dataloaders yield NCHW batches (TensorFlow uses NHWC).
        # Mixture weights normalized across the new component dim=1.
        probs = torch.stack([prob0, prob1, prob2], dim=1)
        probs = torch.softmax(probs, dim=1)


        m0 = torch.distributions.laplace.Laplace(mean0,scale0)
        m1 = torch.distributions.laplace.Laplace(mean1,scale1)
        m2 = torch.distributions.laplace.Laplace(mean2,scale2)
        """
        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y ~ f(X) ~ Logistic(a, b)

        base_distribution = torch.distributions.uniform.Uniform(torch.zeros_like(mean0), torch.ones_like(scale0))

        transforms_1 = [torch.distributions.transforms.SigmoidTransform().inv, torch.distributions.transforms.AffineTransform(loc=mean0, scale=scale0)]
        m0 = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms_1)

        transforms_2 = [torch.distributions.transforms.SigmoidTransform().inv, torch.distributions.transforms.AffineTransform(loc=mean1, scale=scale1)]
        m1 = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms_2)

        transforms_3 = [torch.distributions.transforms.SigmoidTransform().inv, torch.distributions.transforms.AffineTransform(loc=mean2, scale=scale2)]
        m2 = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms_3)
        """

        # per-component probability mass of one quantization bin
        likelihood0 = torch.abs(m0.cdf(x+self.qstep/2)-m0.cdf(x-self.qstep/2))/self.qstep
        likelihood1 = torch.abs(m1.cdf(x+self.qstep/2)-m1.cdf(x-self.qstep/2))/self.qstep
        likelihood2 = torch.abs(m2.cdf(x+self.qstep/2)-m2.cdf(x-self.qstep/2))/self.qstep

        likelihoods = probs[:,0,:,:,:]*likelihood0 + probs[:,1,:,:,:]*likelihood1 \
                      + probs[:,2,:,:,:]*likelihood2


        # Numerically stable variant at the extremes: fold the whole tail
        # mass into the edge bins (±32768 assumed out of the residual range).
        edge_min = probs[:,0,:,:,:]*m0.cdf(x+self.qstep/2) + probs[:,1,:,:,:]*m1.cdf(x+self.qstep/2) \
                   + probs[:,2,:,:,:]*m2.cdf(x+self.qstep/2)
        edge_max = probs[:,0,:,:,:]*( 1- m0.cdf(x-self.qstep/2)) + probs[:,1,:,:,:]*(1-m1.cdf(x-self.qstep/2)) \
                   + probs[:,2,:,:,:]*(1-m2.cdf(x-self.qstep/2))
        likelihoods = torch.where(x < -32768, edge_min, torch.where(x > 32768, edge_max, likelihoods))


        # floor to avoid log(0) in the rate loss
        likelihoods = Low_bound.apply(likelihoods)

        return likelihoods
-
class Cheng2020Attention(nn.Module):
    """
    Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
    and sub-pixel convolutions for up-sampling.

    Args:
        N (int): Number of channels,low_rate:128 high_rate:192
    """

    def __init__(self, N=128, init_weights=True, qstep=1):
        super(Cheng2020Attention,self).__init__()

        self.qstep = qstep  # quantization step for the latent y
        # factorized entropy model for the hyper-latent z
        self.factorized_entropy_func = Entropy_bottleneck(N, qstep=qstep)
        self.encoder = Cheng2020AttentionEncoder(N=N, init_weights=init_weights)
        self.decoder = Cheng2020AttentionDecoder(N=N, init_weights=init_weights)

        self.GMM = Cheng2020AttentionGMM(N=N, init_weights=init_weights)
        self.hyperDecoder = Cheng2020AttentionHyperDecoder(N=N, init_weights=init_weights)
        self.hyperEncoder = Cheng2020AttentionHyperEncoder(N=N, init_weights=init_weights)


    def add_noise(self, x):
        """Additive uniform noise in [-qstep/2, qstep/2): training-time quantization proxy."""
        noise = np.random.uniform(-self.qstep/2, self.qstep/2, x.size())
        noise = torch.Tensor(noise).to(device)
        return x + noise


    def forward(self, x, if_training):
        """
        Returns [reconstruction, xp1, xp2, xq1] where:
        xq2 is z_hat, xp2 is the z likelihoods,
        xq1 is y_hat, xp1 is the y likelihoods.
        """

        x1 = self.encoder(x, if_training)        # latent y
        x2 = self.hyperEncoder(x1, if_training)  # hyper-latent z

        xq2,xp2 = self.factorized_entropy_func(x2,if_training)
        params = self.hyperDecoder(xq2, if_training)  # GMM parameters for y

        # quantize y: noise during training, hard rounding at inference
        if if_training:
            xq1 = self.add_noise(x1)
        else:
            xq1 = RoundNoGradient.apply(x1, self.qstep)

        xp1 = self.GMM(xq1, params, if_training)
        output = self.decoder(xq1, if_training)


        return [output,xp1,xp2,xq1]
-
-
class Cheng2020AttentionEncoder(nn.Module):
    """
    Analysis transform g_a: 2-channel input -> N-channel latent,
    downsampled by four stride-2 blocks.
    """
    def __init__(self, N=128, init_weights=True):
        super().__init__()
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(2, N, stride=2),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            AttentionBlock(N),
            ResidualBlockWithStride(N, N, stride=2),
            AttentionBlock(N),
        )

    def forward(self, x, if_training):
        # `if_training` is unused here; kept for a uniform sub-module interface
        return self.g_a(x)
-
class Cheng2020AttentionHyperEncoder(nn.Module):
    """
    Hyper-analysis transform h_a: latent -> hyper-latent,
    downsampled by two stride-2 convolutions.
    """
    def __init__(self, N=128, init_weights=True):
        super().__init__()
        self.h_a = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
        )

    def forward(self, x1, if_training):
        # `if_training` is unused here; kept for a uniform sub-module interface
        return self.h_a(x1)
-
class Cheng2020AttentionDecoder(nn.Module):
    """
    Synthesis transform g_s: quantized latent -> 2-channel reconstruction,
    upsampled by four sub-pixel blocks.
    """
    def __init__(self, N=128, init_weights=True):
        super().__init__()
        self.g_s = nn.Sequential(
            AttentionBlock(N),
            ResidualBlockUpsample(N, N, 2),
            AttentionBlock(N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlockUpsample(N, 2, 2),
        )

    def forward(self, xq1, if_training):
        # `if_training` is unused here; kept for a uniform sub-module interface
        return self.g_s(xq1)
-
class Cheng2020AttentionHyperDecoder(nn.Module):
    """
    Hyper-synthesis transform h_s: z_hat -> 9N channels of mixture
    parameters, upsampled back to the latent's resolution.
    """
    def __init__(self, N=128, init_weights=True):
        super().__init__()
        self.h_s = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N, N, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N * 3 // 2),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N * 3 // 2, N * 3, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N * 3, N * 9),
        )

    def forward(self, xq2, if_training):
        # `if_training` is unused here; kept for a uniform sub-module interface
        return self.h_s(xq2)
-
class Cheng2020AttentionGMM(nn.Module):
    """
    Thin wrapper evaluating latent likelihoods under the GMM entropy model.
    """

    def __init__(self, N=128, init_weights=True, qstep=1):
        super().__init__()
        self.gaussin_entropy_func = Distribution_for_GMM_entropy(N, qstep=qstep)

    def forward(self, xq1, params, if_training):
        # `if_training` is unused here; kept for a uniform sub-module interface
        return self.gaussin_entropy_func(xq1, params)
-
-
class FusionModel(nn.Module):
    """
    Fusion network: combines the lossy depth features with the first-pass
    residual estimate as `residual_esti_1 * f1(lossyDepth) + f2(lossyDepth)`.
    """
    def __init__(self, N=64, init_weights=True, qstep=1):
        super().__init__()
        self.N = N
        self.qstep = qstep

        def _branch():
            # six-block residual tower with GDN in blocks 2 and 5
            return nn.Sequential(
                ResidualBlock(N, N),
                ResidualBlockWithGDN(N, N),
                ResidualBlock(N, N),
                ResidualBlock(N, N),
                ResidualBlockWithGDN(N, N),
                ResidualBlock(N, N),
            )

        self.paramsModel_1 = _branch()
        self.paramsModel_2 = _branch()

    def forward(self, lossyDepth, residual_esti_1):
        scale_branch = self.paramsModel_1(lossyDepth)
        shift_branch = self.paramsModel_2(lossyDepth)
        return residual_esti_1 * scale_branch + shift_branch
-
class LossyDepthParamsModel(nn.Module):
    """
    Learned feature extractor for the lossy depth map.
    """
    def __init__(self, N=64, init_weights=True, qstep=1):
        super().__init__()
        self.N = N
        self.qstep = qstep
        # stage 1: lift the 2-channel input to N features
        self.paramsModel_1 = nn.Sequential(
            ResidualBlock(2, N),
            ResidualBlockWithGDN(N, N),
            ResidualBlock(N, N),
        )
        # stage 2: down/up-sampling attention path over those features
        self.paramsModel_2 = nn.Sequential(
            conv3x3(N, N, stride=2),
            ResidualBlock(N, N),
            AttentionBlock(N),
            ResidualBlock(N, N),
            subpel_conv3x3(N, N, r=2),
        )
        # stage 3: fuse the two feature maps back to N channels
        self.paramsModel_3 = nn.Sequential(
            ResidualBlock(2 * N, N),
            ResidualBlockWithGDN(N, N),
            ResidualBlock(N, N),
        )

    def forward(self, input):
        features = self.paramsModel_1(input)
        refined = self.paramsModel_2(features)
        fused = torch.cat((refined, features), dim=1)
        return self.paramsModel_3(fused)
-
class ResidualModel(nn.Module):
    """
    Residual entropy model: predicts a 3-component Laplace mixture for the
    quantized depth residual, conditioned on the lossy depth map and a
    first-pass residual estimate.
    """
    def __init__(self, N=64, init_weights=True, qstep=1):
        super(ResidualModel,self).__init__()
        # self.ParamsModel_1 = ResidualParamsModel(N=N, init_weights=init_weights, qstep=qstep)

        self.residual_entropy = Distribution_for_Residual_entropy()
        self.lossyDepthParamsBlock= LossyDepthParamsModel(N)


        # features extracted from the 2-channel first-pass residual estimate
        self.paramBlock_2= nn.Sequential(
            ResidualBlock(2, N),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
        )

        # three heads, each emitting 6 channels = 3 components x 2 channels
        self.probBlock= nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, 6),
        )
        self.meanBlock= nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, 6),
        )
        self.scaleBlock= nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, N),
            ResidualBlock(N, 6),
        )

        self.fusionModel = FusionModel(N)



    def forward(self, lossyDepth, residual_esti_1 , x):
        """
        x: the quantized residual to be entropy-coded
        lossyDepth: lossy depth map used to predict the distribution
        residual_esti_1: first-pass residual estimate
        """
        lossyDepth = self.lossyDepthParamsBlock(lossyDepth)

        residual_esti_1 = self.paramBlock_2(residual_esti_1)

        input= self.fusionModel(lossyDepth, residual_esti_1)

        probs=self.probBlock(input)
        means=self.meanBlock(input)
        scales=self.scaleBlock(input)

        # split each 6-channel head into three 2-channel component maps
        prob0,prob1,prob2 = torch.split(probs, 2, dim=1)
        mean0,mean1, mean2 = torch.split(means, 2, dim=1)
        scale0, scale1, scale2 =torch.split(scales, 2, dim=1)

        xp = self.residual_entropy(x,prob0,prob1,prob2,mean0,mean1, mean2,scale0, scale1, scale2)
        return xp
-
class CompressionModel(nn.Module):
    """
    End-to-end lossy + lossless depth compression.

    Pipeline: lossy depth codec (Cheng2020Attention) -> scaled residual ->
    residual entropy model conditioned on the lossy reconstruction.
    """
    def __init__(self, depth_N=128, residual_N=64, scale=100, qstep_residual=1000, init_weights=True, qstep=1):
        super(CompressionModel,self).__init__()
        self.depthModel = Cheng2020Attention(N=depth_N)

        self.qstep_residual = qstep_residual  # scaling applied to the residual before quantization
        self.residualModel = ResidualModel(N=residual_N)

        self.qstep = qstep  # quantization step for latents / rounding
        self.scale = scale  # factor restoring the residual to its original precision

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal init for all conv layers; zero the biases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def add_noise(self, x):
        """Additive uniform noise in [-qstep/2, qstep/2): training-time quantization proxy."""
        noise = np.random.uniform(-self.qstep/2, self.qstep/2, x.size())
        noise = torch.Tensor(noise).to(device)
        return x + noise

    def forward(self, x, if_training=False):
        lossyDepth, xp1, xp2, _ = self.depthModel(x, if_training)
        residual = x - lossyDepth

        # restore the depth residual to its original precision
        residual_real = residual*self.scale

        if if_training:
            residual_real = self.add_noise(residual_real*self.qstep_residual)
        else:
            # NOTE(review): rounding uses self.qstep here rather than a step
            # tied to qstep_residual — confirm this matches the training-time
            # noise width.
            residual_real = RoundNoGradient.apply(residual_real*self.qstep_residual, self.qstep)

        # second pass to estimate the residual; the hyperprior part is not needed
        # lossyDepth_tmp, _, _, _ = self.depthModel(lossyDepth, if_training)
        lossyDepth_x1_tmp = self.depthModel.encoder(lossyDepth, if_training)
        if if_training:
            lossyDepth_xq1_tmp = self.add_noise(lossyDepth_x1_tmp)
        else:
            lossyDepth_xq1_tmp = RoundNoGradient.apply(lossyDepth_x1_tmp, self.qstep)
        lossyDepth_tmp = self.depthModel.decoder(lossyDepth_xq1_tmp, if_training)
        residual_esti_1 = lossyDepth-lossyDepth_tmp


        xp_residual_1 = self.residualModel(lossyDepth, residual_esti_1 ,residual_real)

        return [lossyDepth, residual, residual_esti_1, xp1, xp2, xp_residual_1]
-
-
if __name__ == "__main__":
    # Smoke test: run a random 2-channel 64x64 input through the full model.
    from torchstat import stat
    model = CompressionModel(depth_N=64, residual_N=64, qstep_residual=1000).to(device)
    print(model)
    inputs = torch.abs(torch.randn(1,2,64, 64)).to(device)
    lossyDepth, residual, residual_esti_1, xp1, xp2, xp_residual_1 = model(inputs, if_training=False)
    print(lossyDepth.shape)
    print(xp1.shape)
    print(xp2.shape)
    print(xp_residual_1.shape)
    print(residual.shape)

    # torchstat parameter/FLOP summary (takes the CHW input size)
    stat(model, (2, 64, 64))
|