|
- import math
- import io
- import torch
- from torchvision import transforms
- import numpy as np
- import torch.nn.functional as F
- from math import exp
-
- from PIL import Image
-
- import matplotlib.pyplot as plt
- #from pytorch_msssim import ms_ssim
- from torch_compressai.zoo import bmshj2018_hyperprior, bmshj2018_factorized
-
# Pad spatial dims up to the next multiple of min_div (e.g. 64 for these models).
def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1):
    """Return ``(pad, unpad)`` tuples usable with ``F.pad``.

    Args:
        in_h: Input height.
        in_w: Input width.
        out_h: Output height (default: ``in_h`` rounded up to a multiple of ``min_div``).
        out_w: Output width (default: ``in_w`` rounded up to a multiple of ``min_div``).
        min_div: Length that output dimensions should be divisible by.

    Raises:
        ValueError: if the (given) output size is not divisible by ``min_div``.
    """
    if out_h is None:
        out_h = -(-in_h // min_div) * min_div  # ceil to next multiple
    if out_w is None:
        out_w = -(-in_w // min_div) * min_div

    if out_h % min_div != 0 or out_w % min_div != 0:
        raise ValueError(
            f"Padded output height and width are not divisible by min_div={min_div}."
        )

    extra_w = out_w - in_w
    extra_h = out_h - in_h
    left = extra_w // 2
    top = extra_h // 2
    right = extra_w - left
    bottom = extra_h - top

    # ``unpad`` is the negated padding, so slicing with it undoes the pad.
    return (left, right, top, bottom), (-left, -right, -top, -bottom)
-
def pad(x, p=2**6):
    """Zero-pad a NCHW tensor so H and W become multiples of ``p``."""
    height, width = x.size(2), x.size(3)
    padding, _ = compute_padding(height, width, min_div=p)
    return F.pad(x, padding, mode="constant", value=0)
-
-
# evaluate
def compute_psnr(img1, img2):
    """PSNR in dB between two image tensors assumed to lie in [0, 1].

    Returns a length-1 numpy array so callers can ``np.concatenate`` the
    per-image results directly.
    """
    img1 = img1.cpu().numpy()
    img2 = img2.cpu().numpy()

    mse = np.mean((img1 - img2) ** 2)
    if mse < 1.0e-10:
        # (Near-)identical images: cap PSNR at 100 dB.  Return an array,
        # not a bare int, so the type matches the normal path below —
        # the caller concatenates this result into a numpy accumulator.
        return np.array([100.0])
    PIXEL_MAX = 1
    return np.array([20 * math.log10(PIXEL_MAX / math.sqrt(mse))])
-
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size``, normalised to sum to 1."""
    center = window_size // 2
    weights = [exp(-((i - center) ** 2) / (2.0 * sigma ** 2)) for i in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
-
# Build a 2-D Gaussian kernel as the outer product of two 1-D Gaussian
# vectors (sigma = 1.5); ``channel`` replicates it so the result can be
# used as a grouped-conv filter on multi-channel images.
def create_window(window_size, channel=1):
    """Return a ``(channel, 1, window_size, window_size)`` Gaussian filter."""
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float()
    window = kernel_2d.unsqueeze(0).unsqueeze(0)
    return window.expand(channel, 1, window_size, window_size).contiguous()
-
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Single-scale SSIM between two batched NCHW image tensors.

    Args:
        img1, img2: Tensors of shape (N, C, H, W) with matching shapes.
        window_size: Side length of the Gaussian window when ``window`` is None.
        window: Optional pre-built filter of shape (C, 1, k, k).
        size_average: If True return a scalar mean, else a per-batch mean.
        full: If True also return the contrast-sensitivity term.
        val_range: Dynamic range L; inferred from ``img1`` when None.
    """
    # Infer the dynamic range when not given.  Common ranges are 255
    # (uint8-like), 1 (sigmoid outputs) and 2 (tanh outputs).
    if val_range is None:
        max_val = 255 if torch.max(img1) > 128 else 1
        min_val = -1 if torch.min(img1) < -0.5 else 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        # Shrink the window if the image is smaller than window_size.
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    # Local means via Gaussian filtering, one conv group per channel.
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2

    # Var(X) = E[X^2] - E[X]^2, and likewise for the covariance.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    # Stabilising constants from the SSIM paper.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)

    return (ret, cs) if full else ret
-
def compute_msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
    """Multi-scale SSIM (MS-SSIM) between two batched NCHW image tensors.

    Computes SSIM and its contrast-sensitivity term over 5 dyadic scales
    (2x average pooling between levels) and combines them with the standard
    MS-SSIM weights.  Returns a numpy scalar.
    """
    device = img1.device
    # Per-scale exponents from Wang et al., "Multi-scale structural
    # similarity for image quality assessment".
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)

        # Downsample by 2 for the next scale.
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = mssim ** weights
    # Standard MS-SSIM (Matlab reference, https://ece.uwaterloo.ca/~z70wang/research/iwssim/):
    # product of the contrast terms over the coarser scales times the SSIM
    # term at the finest-weighted (last) scale only.
    # BUGFIX: the previous `torch.prod(pow1[:-1] * pow2[-1])` broadcast the
    # last-scale SSIM factor into all four contrast factors, raising it to
    # the 4th power instead of applying it once.
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output.cpu().numpy()
-
def compute_bpp(out_net):
    """Bits per pixel implied by the network's likelihood outputs.

    ``out_net`` must contain 'x_hat' (N, C, H, W reconstruction) and
    'likelihoods' (dict of probability tensors); bpp = -sum(log2 p) / pixels.
    """
    n, _, h, w = out_net['x_hat'].size()
    num_pixels = n * h * w
    total = 0
    for likelihoods in out_net['likelihoods'].values():
        total += torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
    return total.item()
-
- # print(f'torch PSNR: {compute_psnr(t_x, out_net["x_hat"]):.2f}dB')
- # print(f'torch MS-SSIM: {compute_msssim(t_x, out_net["x_hat"]):.4f}')
- # print(f'torch Bit-rate: {compute_bpp(out_net):.3f} bpp')
-
-
if __name__ == '__main__':
    # Select the model family: 0 -> factorized prior, 1 -> hyperprior.
    choose = 1
    exp_name = ['bmshj2018-factorized', 'bmshj2018-hyperprior'][choose]
    # Checkpoint files on disk use a slightly different naming scheme.
    ckpt_name = 'bmshj2018-factorized-prior' if exp_name == 'bmshj2018-factorized' else 'bmshj2018-hyperprior'

    import argparse, os, glob
    import pandas as pd
    import time

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--input", default='example.png', type=str,
        help="Input filename.")
    parser.add_argument(
        "--qp", default=2, type=int,
        help="Quality parameter, choose from [1~7] (model0) or [1~8] (model1)"
    )
    parser.add_argument(
        "--model_type", default=0, type=int
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    args = parser.parse_args()

    if not os.path.exists('test_output'):
        os.makedirs('test_output')
    # Quality points to sweep.  NOTE(review): this loop overwrites args.qp,
    # so the --qp CLI argument is effectively ignored — confirm intended.
    qp_list = [1,2,5,7]
    imglist = glob.glob("/userhome/CAE-ADMM/kodak/test/*.png")
    assert imglist != []
    for index, qp in enumerate(qp_list):
        out_dir = 'test_output/qp' + str(qp)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        args.qp = qp

        # Resolve the model constructor imported at the top of the file
        # (e.g. bmshj2018_hyperprior) by its snake_case name.
        net_name = exp_name.replace('-', '_')
        t_net = locals()[net_name](quality=qp, pretrained=False).to(device)
        dict_route = 'compressai/models/%s-%d.pth.tar' % (ckpt_name, qp)
        print('load ckpt_name:', dict_route)
        torch_param_dict=torch.load(dict_route)
        # strict=False: presumably the checkpoint omits some buffers
        # (e.g. entropy-coder tables) — TODO confirm against the checkpoint.
        t_net.load_state_dict(torch_param_dict, strict=False)

        t_net.eval()

        # Per-image metric accumulators for this quality point.
        PSNR_all = np.array([])
        bpp_all = np.array([])
        MSSSIM_all = np.array([])

        start_time = time.time()
        for i, oneimg in enumerate(imglist):
            img = Image.open(oneimg).convert('RGB')
            t_x = transforms.ToTensor()(img).unsqueeze(0).to(device)

            # Pad so H and W are multiples of 64, as the networks require.
            # NOTE(review): metrics below are computed on the PADDED tensors;
            # the unpad offsets are never applied — confirm intended.
            t_x = pad(t_x, p=64)

            # forward
            with torch.no_grad():
                out_net = t_net.forward(t_x)
                out_net['x_hat'].clamp_(0, 1)
            #print(out_net.keys())

            test_psnr = compute_psnr(t_x, out_net["x_hat"])
            #print('------------------>', PSNR_all.shape, test_psnr.shape)
            PSNR_all = np.concatenate((PSNR_all, test_psnr), axis=0)
            test_bpp =compute_bpp(out_net)
            bpp_all = np.concatenate((bpp_all, [test_bpp]), axis=0)
            ms_ssim = compute_msssim(t_x, out_net["x_hat"])
            #print(type(ms_ssim), ms_ssim.shape)
            ms_ssim = np.array([ms_ssim.item()])
            MSSSIM_all = np.concatenate((MSSSIM_all, ms_ssim), axis=0)
            # if i>1:break
        end_time = time.time()
        time_cost = end_time - start_time

        # Append the mean as an extra element, then collapse to one scalar.
        # NOTE(review): mean([x1..xn, mean(x)]) == mean(x), so the extra
        # concatenate does not change the reported value — it is redundant.
        PSNR_all = np.concatenate((PSNR_all, [np.mean(PSNR_all)]), axis=0)
        PSNR_all = PSNR_all.reshape(-1, 1).mean()
        bpp_all = np.concatenate((bpp_all, [np.mean(bpp_all)]), axis=0)
        bpp_all = bpp_all.reshape(-1, 1).mean()
        MSSSIM_all = np.concatenate((MSSSIM_all, [np.mean(MSSSIM_all)]), axis=0)
        MSSSIM_all = MSSSIM_all.reshape(-1, 1).mean()

        # Write the CSV header only on the first quality point; append after.
        mode = 'w'
        header = True
        if index != 0:
            mode = 'a'
            header=False

        all_results = [{'qp':qp, 'bpp':bpp_all, 'PSNR':PSNR_all, 'MSSSIM':MSSSIM_all, 'time_cost':time_cost, 'GPU M':1822}]
        print(all_results)
        results2 = pd.DataFrame(data=all_results)
        results2.to_csv(f'test_output/%s_pytorch.csv' % exp_name,index=False, mode=mode, header=header)
|