|
- import numpy as np
- # import os
- import torch #用于torchac
- # import torchvision.models as models
- # from torch.autograd import Variable
- # import torch.nn as nn
- # import torch.nn.functional as F
- # import torch.optim as optim
- # from torch.utils.data import DataLoader
- # import sys
- import math
- # import torch.nn.init as init
- # import logging
- # from torch.nn.parameter import Parameter
- from subnet import *
- import torchac
- import mindspore as ms
- import mindspore as mindspore
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore import Tensor
- import imageio
-
- from vgg import Vgg19
- from discriminator import Discriminator
-
-
# Channel widths of the latent tensors produced by the analysis transforms.
out_channel_N = 64    # hyperprior latent (z) channels; width of BitEstimator_z
out_channel_M = 96    # residual feature channels (see quantization-noise shapes below)
out_channel_mv = 128  # motion-vector feature channels; width of BitEstimator_mv
-
def save_model(model, iter, save_dir="/mnt/cloud_disk/ssk/project/dvc_p/snapshot"):
    """Save *model* as a MindSpore checkpoint named ``iter{iter}.ckpt``.

    Args:
        model: network (``mindspore.nn.Cell``) to serialize.
        iter: training iteration count; embedded in the filename so that
            ``load_model`` can recover it later.
        save_dir: directory for the checkpoint. Defaults to the original
            hard-coded snapshot path for backward compatibility; pass a
            different directory to avoid the machine-specific mount point.
    """
    mindspore.save_checkpoint(model, "{}/iter{}.ckpt".format(save_dir, iter))
-
def load_model(model, model_path):
    """Restore *model* from a MindSpore checkpoint.

    Returns the training iteration parsed from a filename of the form
    ``...iter<N>.ckpt``, or 0 when the path does not follow that pattern.
    """
    param_dict = mindspore.load_checkpoint(model_path)
    not_loaded = mindspore.load_param_into_net(model, param_dict)
    print('Load checkpoint from ' + model_path)
    # Parameters that could not be matched into the network; this list
    # should normally be empty.
    print("param_not_load: ", not_loaded)
    # model_path = str(model_path)
    if 'iter' in model_path and '.ckpt' in model_path:
        start = model_path.find('iter') + 4
        end = model_path.find('.ckpt', start)
        return int(model_path[start:end])
    return 0
-
class Laplace():
    """Laplace distribution parameterized by location and scale.

    Only the CDF is needed here (probabilities of quantization bins for
    entropy estimation), so no sampling/pdf methods are provided.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def cdf(self, value):
        # F(v) = 0.5 - 0.5 * sign(v - loc) * expm1(-|v - loc| / scale)
        centered = value - self.loc
        return 0.5 - 0.5 * centered.sign() * ops.expm1(-centered.abs() / self.scale)
-
class VideoCompressor(nn.Cell):
    """DVC-style end-to-end learned video compressor (MindSpore port).

    For an (input_image, referframe) pair the network:
      1. estimates optical flow (``opticFlow``),
      2. compresses/decompresses the motion vectors (``mvEncoder``/``mvDecoder``),
      3. motion-compensates the reference frame and refines it (``warpnet``),
      4. compresses the residual with a hyperprior model
         (``resEncoder``/``resDecoder`` + ``resprior*`` networks),
    and estimates each latent's bitrate with factorized (``BitEstimator``)
    or sigma-conditioned Laplace entropy models.
    """

    def __init__(self, is_train=True, args=None):
        """Build all sub-networks.

        Args:
            is_train: initial training/eval mode flag.
            args: object providing ``image_size`` (height, width) for the
                discriminator. NOTE(review): defaults to None yet
                ``args.image_size`` is read unconditionally below, so
                constructing this class without ``args`` raises
                AttributeError — confirm intended usage.
        """
        super(VideoCompressor, self).__init__()
        # self.imageCompressor = ImageCompressor()
        self.opticFlow = ME_Spynet()                  # optical-flow (motion) estimator
        self.mvEncoder = Analysis_mv_net()            # motion-vector analysis transform
        self.Q = None
        self.mvDecoder = Synthesis_mv_net()           # motion-vector synthesis transform
        self.warpnet = Warp_net()                     # motion-compensation refinement net
        self.resEncoder = Analysis_net()              # residual analysis transform
        self.resDecoder = Synthesis_net()             # residual synthesis transform
        self.respriorEncoder = Analysis_prior_net()   # hyperprior encoder (feature -> z)
        self.respriorDecoder = Synthesis_prior_net()  # hyperprior decoder (z -> sigma)
        self.bitEstimator_z = BitEstimator(out_channel_N)    # factorized prior for z
        self.bitEstimator_mv = BitEstimator(out_channel_mv)  # factorized prior for mv
        # self.flow_warp = Resample2d()
        # self.bitEstimator_feature = BitEstimator(out_channel_M)
        self.warp_weight = 0
        # Symbols are assumed to fall in [-mxrange, mxrange) when building the
        # per-symbol CDF tables handed to torchac for real arithmetic coding.
        self.mxrange = 150
        # When True (and in eval mode), bit counts come from an actual torchac
        # encode/decode round-trip instead of the differentiable estimate.
        self.calrealbits = False
        self.training = is_train
        self.vgg = Vgg19()                            # perceptual feature extractor
        self.discriminator = Discriminator(args.image_size[0], args.image_size[1])

    # def forwardFirstFrame(self, x):
    #     output, bittrans = self.imageCompressor(x)
    #     cost = self.bitEstimator(bittrans)
    #     return output, cost

    def motioncompensation(self, ref, mv):
        """Warp *ref* by flow *mv* and refine with warpnet.

        Returns (prediction, warpframe): the refined prediction (residual
        connection around warpnet) and the raw warped frame.
        """
        warpframe = flow_warp(ref, mv)  # from subnet/endecoder
        inputfeature = ops.cat((warpframe, ref), 1)
        prediction = self.warpnet(inputfeature) + warpframe
        return prediction, warpframe

    def construct(self, input_image, referframe, quant_noise_feature=None, quant_noise_z=None,
                  quant_noise_mv=None, is_train=True, global_step=1000000):
        """Compress *input_image* relative to *referframe*.

        Args:
            input_image: current frame to encode (NCHW, values in [0, 1] —
                presumably; confirm against the data loader).
            referframe: previously reconstructed reference frame.
            quant_noise_feature / quant_noise_z / quant_noise_mv: additive
                uniform noise tensors used to simulate quantization during
                training; required when ``is_train`` is True.
            is_train: training-mode flag (overrides ``self.training``).
            global_step: current iteration; GAN/VGG losses only activate
                once it reaches 400000.

        Returns an 11-tuple: (clipped_recon_image, mse_loss, warploss,
        interloss, bpp_feature, bpp_z, bpp_mv, bpp, vgg_loss, g_loss,
        d_loss); the last three are None before step 400000.
        """
        # NOTE(review): mutating self.training inside construct() bypasses
        # nn.Cell.set_train(); verify this interacts correctly with graph mode.
        self.training = is_train
        # --- Motion estimation and motion-vector compression ---
        estmv = self.opticFlow(input_image, referframe)
        mvfeature = self.mvEncoder(estmv)
        if self.training:
            # Additive uniform noise approximates rounding during training.
            quant_mv = mvfeature + quant_noise_mv
        else:
            quant_mv = ms.ops.Rint()(mvfeature)
        quant_mv_upsample = self.mvDecoder(quant_mv)

        # --- Motion compensation ---
        prediction, warpframe = self.motioncompensation(referframe, quant_mv_upsample)

        input_residual = input_image - prediction

        # --- Residual compression with hyperprior ---
        feature = self.resEncoder(input_residual)
        batch_size = feature.shape[0]

        z = self.respriorEncoder(feature)

        if self.training:
            compressed_z = z + quant_noise_z
        else:
            compressed_z = ms.ops.Rint()(z)

        # Scale parameters of the conditional entropy model for the feature.
        recon_sigma = self.respriorDecoder(compressed_z)

        feature_renorm = feature

        if self.training:
            compressed_feature_renorm = feature_renorm + quant_noise_feature
        else:
            compressed_feature_renorm = ms.ops.Rint()(feature_renorm)

        recon_res = self.resDecoder(compressed_feature_renorm)
        recon_image = prediction + recon_res

        clipped_recon_image = recon_image.clamp(0., 1.)


        # --- Distortion terms ---
        mse_loss = ops.mean(ops.pow((recon_image - input_image), 2))

        # psnr = tf.cond(
        #     tf.equal(mse_loss, 0), lambda: tf.constant(100, dtype=tf.float32),
        #     lambda: 10 * (tf.log(1 * 1 / mse_loss) / np.log(10)))

        warploss = ops.mean(ops.pow((warpframe - input_image), 2))
        interloss = ops.mean(ops.pow((prediction - input_image), 2))


        # --- Bit-per-pixel estimation ---

        def feature_probs_based_sigma(feature, sigma):
            """Estimate bits for *feature* under a Laplace(0, sigma) model."""

            def getrealbitsg(x, gaussian):
                # Real arithmetic coding: build per-symbol CDF tables, shift
                # symbols into [0, 2*mxrange), round-trip through torchac
                # (via numpy, since torchac is a PyTorch extension).
                # print("NIPS18noc : mn : ", torch.min(x), " - mx : ", torch.max(x), " range : ", self.mxrange)
                cdfs = []
                x = x + self.mxrange
                n,c,h,w = x.shape
                for i in range(-self.mxrange, self.mxrange):
                    cdfs.append(gaussian.cdf(i - 0.5).view(n,c,h,w,1))
                cdfs = ops.stop_gradient(ops.cat(cdfs, 4))

                cdfs = torch.from_numpy(cdfs.asnumpy())
                x = torch.from_numpy(x.asnumpy()).to(torch.int16)
                byte_stream = torchac.encode_float_cdf(cdfs, x, check_input_bounds=True)

                real_bits = Tensor.from_numpy(np.array([len(byte_stream) * 8], dtype=np.float32))

                sym_out = torchac.decode_float_cdf(cdfs, byte_stream)
                sym_out = Tensor.from_numpy(sym_out.float().numpy())  # convert back to a MindSpore tensor

                return sym_out - self.mxrange, real_bits

            # Despite the variable name, the entropy model is a zero-mean
            # Laplace with the decoded sigma (clamped for stability).
            mu = ops.ZerosLike()(sigma)
            sigma = sigma.clamp(1e-5, 1e10)
            gaussian = Laplace(mu, sigma)
            # Probability mass of each quantization bin [v-0.5, v+0.5).
            probs = gaussian.cdf(feature + 0.5) - gaussian.cdf(feature - 0.5)
            total_bits = ops.sum(ops.clamp(-1.0 * ops.log(probs + 1e-5) / math.log(2.0), 0, 50))

            if self.calrealbits and not self.training:
                decodedx, real_bits = getrealbitsg(feature, gaussian)
                total_bits = real_bits

            return total_bits, probs

        def iclr18_estrate_bits_z(z):
            """Estimate bits for z with the factorized BitEstimator prior."""

            def getrealbits(x):
                # Same torchac round-trip as above, but the CDF comes from the
                # (spatially shared) factorized model, tiled over H x W.
                cdfs = []
                x = x + self.mxrange
                n,c,h,w = x.shape
                for i in range(-self.mxrange, self.mxrange):
                    cdfs.append(ms.numpy.tile(self.bitEstimator_z(i - 0.5).view(1, c, 1, 1, 1), (1, 1, h, w, 1)))
                cdfs = ops.stop_gradient(ops.cat(cdfs, 4))

                cdfs = torch.from_numpy(cdfs.asnumpy())
                x = torch.from_numpy(x.asnumpy()).to(torch.int16)
                byte_stream = torchac.encode_float_cdf(cdfs, x, check_input_bounds=True)

                real_bits = ops.sum(Tensor.from_numpy(np.array([len(byte_stream) * 8], dtype=np.float32)))

                sym_out = torchac.decode_float_cdf(cdfs, byte_stream)
                sym_out = Tensor.from_numpy(sym_out.float().numpy())  # convert back to a MindSpore tensor

                return sym_out - self.mxrange, real_bits

            prob = self.bitEstimator_z(z + 0.5) - self.bitEstimator_z(z - 0.5)
            total_bits = ops.sum(ops.clamp(-1.0 * ops.log(prob + 1e-5) / math.log(2.0), 0, 50))


            if self.calrealbits and not self.training:
                decodedx, real_bits = getrealbits(z)
                total_bits = real_bits

            return total_bits, prob


        def iclr18_estrate_bits_mv(mv):
            """Estimate bits for the motion latent with its factorized prior."""

            def getrealbits(x):
                # Identical structure to the z path, using bitEstimator_mv.
                cdfs = []
                x = x + self.mxrange
                n,c,h,w = x.shape
                for i in range(-self.mxrange, self.mxrange):
                    cdfs.append(ms.numpy.tile(self.bitEstimator_mv(i - 0.5).view(1, c, 1, 1, 1), (1, 1, h, w, 1)))
                cdfs = ops.stop_gradient(ops.cat(cdfs, 4))

                cdfs = torch.from_numpy(cdfs.asnumpy())
                x = torch.from_numpy(x.asnumpy()).to(torch.int16)
                byte_stream = torchac.encode_float_cdf(cdfs, x, check_input_bounds=True)

                real_bits = ops.sum(Tensor.from_numpy(np.array([len(byte_stream) * 8], dtype=np.float32)))

                sym_out = torchac.decode_float_cdf(cdfs, byte_stream)
                sym_out = Tensor.from_numpy(sym_out.float().numpy())  # convert back to a MindSpore tensor
                return sym_out - self.mxrange, real_bits

            prob = self.bitEstimator_mv(mv + 0.5) - self.bitEstimator_mv(mv - 0.5)
            total_bits = ops.sum(ops.clamp(-1.0 * ops.log(prob + 1e-5) / math.log(2.0), 0, 50))

            if self.calrealbits and not self.training:
                decodedx, real_bits = getrealbits(mv)
                total_bits = real_bits

            return total_bits, prob

        total_bits_feature, _ = feature_probs_based_sigma(compressed_feature_renorm, recon_sigma)
        # entropy_context = entropy_context_from_sigma(compressed_feature_renorm, recon_sigma)
        total_bits_z, _ = iclr18_estrate_bits_z(compressed_z)
        total_bits_mv, _ = iclr18_estrate_bits_mv(quant_mv)

        # --- Perceptual / adversarial losses, enabled late in training ---
        vgg_loss, d_loss, g_loss = None, None, None
        if global_step >= 400000:
            feat1, feat2 = self.vgg.extract_feature(input_image), self.vgg.extract_feature(recon_image)
            D_real, D_fake = self.discriminator(input_image), self.discriminator(recon_image)
            B, C, H, W = input_image.shape
            vgg_loss = ((feat1 - feat2).square().sum().sqrt() / (C * H * W)).mean()
            # LSGAN-style generator and discriminator objectives.
            g_loss = 0.5 * ((D_fake - 1)**2).mean()
            d_loss = 0.5 * (((D_real - 1)**2).mean() + (D_fake**2).mean())


        # Bits-per-pixel: total bits / (batch * H * W) of the input frame.
        im_shape = input_image.shape
        bpp_feature = total_bits_feature / (batch_size * im_shape[2] * im_shape[3])
        bpp_z = total_bits_z / (batch_size * im_shape[2] * im_shape[3])
        bpp_mv = total_bits_mv / (batch_size * im_shape[2] * im_shape[3])
        bpp = bpp_feature + bpp_z + bpp_mv

        return clipped_recon_image, mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp, vgg_loss, g_loss, d_loss
-
if __name__ == '__main__':
    # Smoke test: run one forward pass on a pair of Vimeo frames.
    from types import SimpleNamespace

    im_height = 256
    im_width = 256

    # The discriminator needs the input's spatial size. The original call
    # omitted `args`, which crashed on `args.image_size` in __init__.
    args = SimpleNamespace(image_size=(im_height, im_width))
    model = VideoCompressor(is_train=True, args=args)
    model.set_train(True)

    def _load_frame(path):
        # HWC uint8 image -> NCHW float tensor in [0, 1], cropped to 256 cols.
        img = imageio.imread(path) / 255.0
        img = np.transpose(np.expand_dims(img, axis=0), [0, 3, 1, 2])
        return ms.Tensor.from_numpy(img).float()[..., :im_width]

    im1 = _load_frame('/userhome/DVC/PyTorch/data/vimeo_septuplet/sequences/00001/0001/im4.png')
    im2 = _load_frame('/userhome/DVC/PyTorch/data/vimeo_septuplet/sequences/00001/0001/im2.png')
    print("im1.shape:", im1.shape)
    print("im2.shape:", im2.shape)

    def _uniform_noise(shape):
        # Training-time quantization is simulated by adding U(-0.5, 0.5) noise
        # (equivalent to the original ZerosLike + Uniform(scale=0.5) dance).
        noise = np.random.uniform(-0.5, 0.5, size=(1,) + shape).astype(np.float32)
        return Tensor.from_numpy(noise)

    quant_noise_feature = _uniform_noise((out_channel_M, im_height // 16, im_width // 16))
    quant_noise_z = _uniform_noise((out_channel_N, im_height // 64, im_width // 64))
    quant_noise_mv = _uniform_noise((out_channel_mv, im_height // 16, im_width // 16))

    # construct() returns 11 values (including the GAN/VGG losses); the
    # original unpacked only 8 of them, which raised ValueError at runtime.
    (clipped_recon_image, mse_loss, warploss, interloss,
     bpp_feature, bpp_z, bpp_mv, bpp, vgg_loss, g_loss, d_loss) = model(
        im1, im2, quant_noise_feature, quant_noise_z, quant_noise_mv, is_train=True)
    print("mse_loss:", mse_loss)
|