|
- import math
- import io
- import torch
- from torchvision import transforms
- import numpy as np
-
- from PIL import Image
-
- import matplotlib.pyplot as plt
- from compressai.zoo import bmshj2018_hyperprior
- import mindspore as ms
- from mindspore import save_checkpoint, Tensor
- from mindspore import load_checkpoint, load_param_into_net
-
# Build the bmshj2018 hyperprior model (quality level 2). pretrained=False
# because weights are loaded below from a locally converted MindSpore checkpoint.
net = bmshj2018_hyperprior(quality=2, pretrained=False)

# trans from pth to ckpt
# One-off conversion kept for reference: walks the PyTorch state dict and the
# MindSpore parameter list in lockstep (matching on the last name component)
# and saves the result as a .ckpt.
# dict_route = 'compressai/models/bmshj2018-hyperprior-2.pth'
# torch_param_dict=torch.load(dict_route)
# print(torch_param_dict.keys())
# print(len(torch_param_dict.keys()))

# ms_name_list = []
# for m in net.get_parameters():
#     if not 'offset' in m.name and not 'cdf' in m.name:
#         print(m.name)
#         ms_name_list.append(m.name)
# ms_name_list.pop(-1)
# ms_name_list.pop(-1)
# ms_name_list.append('stub')
# param_list = []
# i = 0
# for key in torch_param_dict.keys():
#     tmp = ms_name_list[i]
#     if key.split('.')[-1] == ms_name_list[i].split('.')[-1]:
#         value = torch_param_dict[key]
#         print(value.shape)
#         value = Tensor(value.numpy())
#         # expand dims for ConvTranspose weights (insert an extra axis at dim 2)
#         if 'deconv' in ms_name_list[i] and 'weight' in ms_name_list[i]:
#             value = ms.ops.ExpandDims()(value, 2)
#         param_list.append({"name": ms_name_list[i], "data": value})
#         i += 1
# save_checkpoint(param_list, 'compressai/models/convert_bmshj2018-hyperprior-2.ckpt')

# load pretrain
# Load the converted checkpoint into the network (path assumed to exist — produced
# by the conversion block above).
param_dir = 'compressai/models/convert_bmshj2018-hyperprior-2.ckpt'
param_dict = load_checkpoint(param_dir)
load_param_into_net(net, param_dict)

# input image
# Read the test image as RGB, convert to a float CHW tensor, add a batch axis.
img = Image.open('assets/sea.jpg').convert('RGB')
totensor=ms.dataset.vision.ToTensor()
ms_x = Tensor(totensor(img))
ms_x = ms_x.unsqueeze(0)
print(ms_x.shape)
-
# pad input so each spatial dimension is a multiple of 64
def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1):
    """Return a pair of (pad, unpad) 4-tuples in (left, right, top, bottom) order.

    When out_h / out_w are omitted, each is the input size rounded up to the
    next multiple of min_div. Padding is split as evenly as possible, with any
    odd remainder going to the right/bottom side.

    Args:
        in_h: Input height.
        in_w: Input width.
        out_h: Output height (defaults to in_h rounded up to min_div).
        out_w: Output width (defaults to in_w rounded up to min_div).
        min_div: Length that output dimensions should be divisible by.

    Raises:
        ValueError: if the resulting output sizes are not divisible by min_div.
    """
    def _round_up(v: int) -> int:
        # Smallest multiple of min_div that is >= v.
        return (v + min_div - 1) // min_div * min_div

    out_h = _round_up(in_h) if out_h is None else out_h
    out_w = _round_up(in_w) if out_w is None else out_w

    if out_h % min_div != 0 or out_w % min_div != 0:
        raise ValueError(
            f"Padded output height and width are not divisible by min_div={min_div}."
        )

    pad_left = (out_w - in_w) // 2
    pad_right = out_w - in_w - pad_left
    pad_top = (out_h - in_h) // 2
    pad_bottom = out_h - in_h - pad_top

    pad = (pad_left, pad_right, pad_top, pad_bottom)
    unpad = tuple(-p for p in pad)
    return pad, unpad
-
def pad(x, p=2**6):
    """Zero-pad an NCHW tensor so that H and W become multiples of p."""
    h, w = x.shape[2], x.shape[3]
    padding, _ = compute_padding(h, w, min_div=p)
    return ms.ops.pad(x, padding, mode="constant", value=0)
-
# Pad spatial dims up to the next multiple of 64 before inference (the model's
# down/up-sampling chain presumably requires 64-divisible inputs — confirm).
p = 64
ms_x = pad(ms_x, p)

# forward
# Full reconstruction pass; ms_out is indexed as [0] (reconstruction) and
# [1][0] (likelihoods) below — structure assumed from usage, verify against the model.
ms_out=net.construct(ms_x)
-
- # evaluate
def ms_compute_psnr(a, b):
    """Peak signal-to-noise ratio between two image tensors via mindspore's nn.PSNR."""
    return ms.nn.PSNR()(a, b)
-
def ms_compute_msssim(a, b):
    """Structural similarity between two image tensors.

    NOTE(review): this uses ms.nn.SSIM (single-scale) although the name and
    the printout say MS-SSIM — confirm whether nn.MSSSIM was intended.
    """
    return ms.nn.SSIM()(a, b)
-
def ms_compute_bpp(out_net, size=(1, 3, 480, 720)):
    """Estimate the bit-rate in bits-per-pixel from the network's likelihoods.

    Fix: the image size was hard-coded inside the function, which silently
    produces a wrong bpp for any input other than 480x720. It is now a
    parameter whose default preserves the previous behaviour.

    Args:
        out_net: network output; out_net[1][0] is expected to iterate the
            likelihood tensors (presumably (y, z) likelihoods — TODO confirm
            against the model's construct()).
        size: NCHW shape of the evaluated image; defaults to the original
            hard-coded 480x720 size for backward compatibility.

    Returns:
        Scalar tensor: sum over likelihood sets of -log2(likelihood) / num_pixels.
    """
    # bpp is per spatial pixel (batch * H * W), not per channel value.
    num_pixels = size[0] * size[2] * size[3]
    ms_log = ms.ops.Log()
    return sum(
        ms_log(likelihoods).sum() / (-math.log(2) * num_pixels)
        for likelihoods in out_net[1][0]
    )
-
- print(f'mindspore PSNR: {ms_compute_psnr(ms_x, ms_out[0])[0]}dB')
- print(f'mindspore MS-SSIM: {ms_compute_msssim(ms_x, ms_out[0])[0]}')
- print(f'mindspore Bit-rate: {ms_compute_bpp(list(ms_out))} bpp')
|