|
- import numpy as np
- import os
- # import torch
- # import torch.nn as nn
-
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore.common.initializer import Uniform, initializer
- import math
- from models import *
-
-
def save_model(model, iter, name):
    """Save ``model`` as ``<name>/iter_<iter>.ckpt``.

    Args:
        model: Object handed to ``ms.save_checkpoint`` as the thing to save.
        iter: Iteration number embedded in the file name. (The parameter
            shadows the ``iter`` builtin; the name is kept so existing
            keyword callers continue to work.)
        name: Directory the checkpoint file is written into.
    """
    # Make sure the target directory exists so the first save into a fresh
    # run directory does not rely on the checkpoint writer creating it.
    # An empty ``name`` means "current directory" and needs no creation.
    if name:
        os.makedirs(name, exist_ok=True)
    ms.save_checkpoint(model, os.path.join(name, f"iter_{iter}.ckpt"))
-
-
def load_model(model, f):
    """Load checkpoint ``f`` into ``model`` and return the iteration it encodes.

    Args:
        model: Network passed to ``ms.load_param_into_net``.
        f: Path of the checkpoint file.

    Returns:
        The integer ``N`` when the file name contains ``iter_<N>.ckpt``
        (the pattern written by ``save_model``), otherwise ``0``.
    """
    # Function-scope import keeps this block self-contained without touching
    # the file-level import list.
    import re

    pretrained_dict = ms.load_checkpoint(f)
    param_not_load = ms.load_param_into_net(model, pretrained_dict)
    print("param not loaded:", param_not_load)

    # Only the file name is inspected, and only digits are accepted: a parent
    # directory containing "iter_" or a non-numeric suffix used to reach
    # int() and raise ValueError.
    match = re.search(r"iter_(\d+)\.ckpt", os.path.basename(f))
    if match is None:
        return 0
    return int(match.group(1))
-
-
class ImageCompressor(nn.Cell):
    """Image codec built from an encoder, a decoder and a bit estimator.

    The forward pass encodes the image, quantises the latent (additive
    uniform noise while training, rounding otherwise), decodes it, and
    reports the reconstruction error together with an estimated bit rate.

    Args:
        out_channel_N: Channel count forwarded to ``Analysis_net_17``,
            ``Synthesis_net_17`` and ``BitEstimator``.
    """

    def __init__(self, out_channel_N=128):
        super(ImageCompressor, self).__init__()
        self.Encoder = Analysis_net_17(out_channel_N=out_channel_N)
        self.Decoder = Synthesis_net_17(out_channel_N=out_channel_N)
        self.bitEstimator = BitEstimator(channel=out_channel_N)
        self.out_channel_N = out_channel_N

        # Operator primitives are created once here instead of on every
        # forward call.
        self.rint_op = ms.ops.Rint()
        self.reduce_mean_op = ms.ops.ReduceMean()
        self.reduce_sum_op = ms.ops.ReduceSum()
        self.log_op = ms.ops.Log()

    def _estimate_bits(self, z):
        """Return ``(total_bits, prob)`` for the quantised latent ``z``.

        ``prob`` is the bit-estimator difference over the unit-width bin
        centred on each element; ``total_bits`` sums ``-log2(prob)`` with
        every element's cost clipped to ``[0, 50]``.
        """
        prob = self.bitEstimator(z + 0.5) - self.bitEstimator(z - 0.5)
        # 1e-10 keeps the log finite when a bin probability reaches zero.
        bits = ms.ops.clip_by_value(
            -1.0 * self.log_op(prob + 1e-10) / math.log(2.0), 0, 50)
        return self.reduce_sum_op(bits), prob

    def construct(self, input_image):
        """Compress and reconstruct ``input_image``.

        Args:
            input_image: Image batch; dimensions 2 and 3 are used as the
                spatial size for the bits-per-pixel normalisation
                (assumes an NCHW layout -- TODO confirm).

        Returns:
            Tuple ``(clipped_recon_image, mse_loss, bpp_feature)``: the
            reconstruction clamped to ``[0, 1]``, the mean squared error of
            the *unclamped* reconstruction against the input, and the
            estimated bits divided by ``batch * H * W``.
        """
        feature = self.Encoder(input_image)
        batch_size = feature.shape[0]

        if self.training:
            # The noise takes the encoder output's own shape, so it matches
            # the latent even when the input size is not a multiple of 16.
            # It is only created in training mode, where it is used.
            # NOTE(review): ``initializer`` is kept from the original code;
            # confirm it yields fresh noise on every step in the execution
            # mode this model is trained in.
            quant_noise_feature = initializer(Uniform(0.5), feature.shape)
            compressed_feature_renorm = feature + quant_noise_feature
        else:
            compressed_feature_renorm = self.rint_op(feature)

        recon_image = self.Decoder(compressed_feature_renorm)
        clipped_recon_image = recon_image.clamp(0., 1.)

        # Distortion is measured before clamping.
        mse_loss = self.reduce_mean_op((recon_image - input_image).pow(2))

        total_bits_feature, _ = self._estimate_bits(compressed_feature_renorm)
        im_shape = input_image.shape
        bpp_feature = total_bits_feature / (batch_size * im_shape[2] * im_shape[3])

        return clipped_recon_image, mse_loss, bpp_feature
-
-
def main():
    """Smoke-test ``ImageCompressor``: one eval forward pass, one training step."""
    model = ImageCompressor()
    zeros = ms.ops.Zeros()
    data = zeros((4, 3, 256, 256), ms.float32)
    print(data.shape)
    out = model(data)
    print(out)

    # mindspore.nn.Adam names its step-size argument ``learning_rate``;
    # the previous ``lr=`` keyword is not part of its signature.
    optimizer = nn.Adam(model.trainable_params(), learning_rate=1e-4)
    loss_fn = nn.MSELoss()

    def forward_fn(inputs, label):
        # Element 0 of the model output is the clipped reconstruction for the
        # whole batch. It must not be indexed again, or only the first sample
        # would be compared against the full label batch.
        logits = model(inputs)[0]
        loss = loss_fn(logits, label)
        return loss, logits

    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    def train_step(inputs, label):
        (loss, _), grads = grad_fn(inputs, label)
        optimizer(grads)
        return loss

    model.set_train()
    loss = train_step(data, data)
    print(loss.asnumpy())


if __name__ == '__main__':
    main()
|