|
# Project star-import kept first so later explicit imports take precedence,
# as in the original ordering.
from model_baseline import *

# Standard library
import argparse
import glob
import math
import os
import os.path as osp
import random
import shutil
import sys
import time

sys.path.append("..")

# Third-party
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by test_epoch (F.pad); previously only implicit
import torch.optim as optim
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Project
from compressai.datasets import ImageFolder
from compressai.layers import GDN
from compressai.models import CompressionModel
from compressai.models.utils import conv, deconv
-
- dir = '/gpfs/userhome/zhm/CompressAIDemo/CompressAI/demo/'
- out_root_path = dir + 'decoded_files/'
- if not os.path.exists(out_root_path):
- print("not ex")
- os.system("mkdir "+out_root_path)
-
-
- #file
- out_root_path_file = open(osp.join(out_root_path,"details.txt"),'w')
- save_path = '/gpfs/userhome/zhm/CompressAIDemo/CompressAI/demo/model/'
-
- class MyDataset(Dataset):
- def __init__(self, input_path):
- super(MyDataset, self).__init__()
- self.input_list = []
- self.name_list = []
- self.num = 0
- for _ in range(1):
- for i in os.listdir(input_path):
- input_img = input_path + i
- input_name = i
- self.input_list.append(input_img)
- self.name_list.append(input_name)
- self.num = self.num + 1
-
- def __len__(self):
- return self.num
-
- def __getitem__(self, idx):
- img = np.array(cv2.imread(self.input_list[idx]))
- name = self.name_list[idx]
- input_np = img.astype(np.float32).transpose(2, 0, 1) / 255.0
- input_tensor = torch.from_numpy(input_np)
-
- return input_tensor, name
-
- def mse2psnr(mse):
- # 根据Hyper论文中的内容,将MSE->psnr(db)
- # return 10*math.log10(255*255/mse)
- return 10 * math.log10(1/ mse) #???
- #psnr calculate
- def psnr(img1, img2):
- mse = np.mean( (img1/255. - img2/255.) ** 2 )
- if mse < 1.0e-10:
- return 100
- PIXEL_MAX = 1
- return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
-
- ################################################################
- class RateDistortionLoss(nn.Module):
- """Custom rate distortion loss with a Lagrangian parameter."""
- def __init__(self, lmbda=1e-2):
- super().__init__()
- self.mse = nn.MSELoss()
- self.lmbda = lmbda
-
- def forward(self, output, target):
- N, _, H, W = target.size()
- out = {}
- num_pixels = N * H * W
-
- ########################################
- # 计算误差
- out['mse_loss'] = self.mse(output['x_hat'], target) #end to end
- out['ms_ssim'] = ms_ssim(output['x_hat'], target, data_range=1, size_average=False)[0] # (N,)
-
- ########################################
- out['psnr'] = mse2psnr(self.mse(output['x_hat'], target))
- return out
-
-
- class AverageMeter:
- """Compute running average."""
- def __init__(self):
- self.val = 0
- self.avg = 0
- self.sum = 0
- self.count = 0
-
- def update(self, val, n=1):
- self.val = val
- self.sum += val * n
- self.count += n
- self.avg = self.sum / self.count
-
- def save_pic(data,h,w,path):
- if osp.exists(path):
- os.system("rm "+path)
- print("rm "+path)
- '''
- reimage = data.cpu().clone()
- reimage[reimage > 1.0] = 1.0
- reimage = reimage.squeeze(0)
- reimage = transforms.ToPILImage()(reimage) # PIL格式
- reimage.save(path)
- '''
- img = data[:h, :w, :]
- cv2.imwrite(path,img)
-
- def test_epoch(epoch, test_dataloader, model, criterion):
- global out_root_path_file
-
- model.eval()
- device = next(model.parameters()).device
- loss = AverageMeter()
- psnr = AverageMeter()
- msssim = AverageMeter()
- enctime = AverageMeter()
- dectime = AverageMeter()
- bpp_estimate = AverageMeter()
-
- with torch.no_grad():
- for d in test_dataloader:
- # for ii in range(len(d)):
- # print(d[ii])
- # raise ValueError("stop")
- d1 = d[0].to(device)
- #====================
- #codec
- print("start codec")
- name = d[1][0]
- print(name)
-
-
- out={}
- N, _, H, W = d1.size()
- out = {}
- num_pixels = N * H * W
-
- hh, ww = d1.shape[2], d1.shape[3]
- pp = 64 # maximum 6 strides of 2
- new_hh = (hh + pp - 1) // pp * pp
- new_ww = (ww + pp - 1) // pp * pp
- padding_left = (new_ww - ww) // 2
- padding_right = new_ww - ww - padding_left
- padding_top = (new_hh - hh) // 2
- padding_bottom = new_hh - hh - padding_top
-
- d1_padded = F.pad(
- d1,
- (padding_left, padding_right, padding_top, padding_bottom),
- mode="constant",
- value=0,
- )
-
- out_net = model(d1_padded)
- # print("network z1hat",out_net['z1_hat'])
- output = out_net
- output["x_hat"] = F.pad(
- out_net["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom)
- )
-
- out['bpp'] = (torch.log(output['likelihoods']['y']).sum() / (-math.log(2) * num_pixels))
- # print("bpp_loss",out['bpp'].item())
- '''
- ###################################################################################################################
- #encode
- out_net_en = model.compress(d1_padded,name,output_path=out_root_path,device=device)
- print("bpp",out_net_en['bpp_real'])
- print("=============enc ok==========")
- # raise ValueError("enc ok!")
- #decode
- out_net = model.decompress(d1_padded,name,output_path=out_root_path,device=device)
- print("=============dec ok==========")
- out_net["x_hat"] = F.pad(
- out_net["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom)
- )
- # print(out_net_en)
- # print(out_net)
- # print("compare:",bool(out_net_en==out_net))
- # raise ValueError("stop")
- #====================
- '''
- out_criterion = criterion(out_net, d1)
- psnr.update(out_criterion['psnr'])
- msssim.update(out_criterion['ms_ssim'])
- bpp_estimate.update(out['bpp'].item())
-
-
- psnr_val = out_criterion['psnr']
- msssim_val = out_criterion['ms_ssim']
- bpp_estimate_val = out['bpp'].item()
-
-
- print_context = (name +
- f'\tPSNR (dB): {psnr_val :.3f} |'
- f'\tMS-SSIM: {msssim_val :.4f} |'
- f'\tEstimate_bpp: {bpp_estimate_val:.3f} \n'
- )
-
- out_root_path_file.write(print_context)
- print(print_context)
- ## post
- rec_d1 = output["x_hat"]
- rec_d1 = torch.clamp(rec_d1, min=0, max=1.0)
- rec_d1 = rec_d1.data[0].cpu().detach().numpy()
- rec_d1 = rec_d1.transpose(1, 2, 0) * 255.0
- rec_d1 = rec_d1.astype('uint8')
- '''
- ori_d1 = d1
- ori_d1 = torch.clamp(ori_d1, min=0, max=1.0)
- ori_d1 = ori_d1.data[0].cpu().detach().numpy()
- ori_d1 = ori_d1.transpose(1, 2, 0) * 255.0
- ori_d1 = ori_d1.astype('uint8')
- '''
- ##save pic
- save_pic(rec_d1, H, W, out_root_path + name)
- # save_pic(ori_d1, H, W, out_root_path + "test.png")
- ####
-
- out_root_path_file.close()
- print(f'Test epoch {epoch}: Average losses:'
- f'\tTime: {time.strftime("%Y-%m-%d %H:%M:%S")} |'
- f'\tEstimate_bpp: {bpp_estimate.avg:.3f} |'
- f'\tMS-SSIM: {msssim.avg :.4f} |'
- f'\tPSNR (dB): {psnr.avg :.3f} \n' # 平均一张图的PSNR
- )
-
- return loss.avg
-
-
- def parse_args(argv):
- parser = argparse.ArgumentParser(description='Example training script')
- # yapf: disable
- parser.add_argument(
- '-d',
- '--dataset',
- type=str,
- help='Training dataset')
- parser.add_argument(
- '-e',
- '--epochs',
- default=1000,
- type=int,
- help='Number of epochs (default: %(default)s)')
- parser.add_argument(
- '-lr',
- '--learning-rate',
- default=1e-4,
- type=float,
- help='Learning rate (default: %(default)s)')
- parser.add_argument(
- '-n',
- '--num-workers',
- type=int,
- default=16,
- help='Dataloaders threads (default: %(default)s)')
- parser.add_argument(
- '--lambda',
- dest='lmbda',
- type=float,
- default=1e-2,
- # default=0.0018,
- help='Bit-rate distortion parameter (default: %(default)s)')
- parser.add_argument(
- '--batch-size',
- type=int,
- default=16,
- help='Batch size (default: %(default)s)')
- parser.add_argument(
- '--test-batch-size',
- type=int,
- default=64,
- help='Test batch size (default: %(default)s)')
- parser.add_argument(
- '--aux-learning-rate',
- default=1e-3,
- help='Auxiliary loss learning rate (default: %(default)s)')
- parser.add_argument(
- '--patch-size',
- type=int,
- nargs=2,
- default=(256, 256),
- help='Size of the patches to be cropped (default: %(default)s)')
- parser.add_argument(
- '--cuda',
- type=int,
- default=0,
- help='Use cuda')
- parser.add_argument(
- '--pretrained',
- type=bool,
- default=True,
- dest='pretrained',
- help='if load from pretrained')
- parser.add_argument(
- '--logfile',
- type=str,
- default="train_log.txt",
- help='logfile_name')
- parser.add_argument(
- '--seed',
- type=float,
- help='Set random seed for reproducibility')
- # yapf: enable
- args = parser.parse_args(argv)
- return args
-
- def main(argv):
- args = parse_args(argv)
-
- if args.seed is not None:
- torch.manual_seed(args.seed)
- random.seed(args.seed)
-
- # if args.seed is not None:
- # torch.manual_seed(args.seed)
- # random.seed(args.seed)
-
- # train_transforms = transforms.Compose(
- # [transforms.RandomCrop(args.patch_size),
- # transforms.ToTensor()])
- #
- # test_transforms = transforms.Compose(
- # [transforms.CenterCrop(args.patch_size),
- # transforms.ToTensor()])
- train_transforms = transforms.Compose(
- [transforms.ToTensor()])
-
- test_transforms = transforms.Compose(
- [transforms.ToTensor()])
-
- train_dataset = MyDataset(input_path=args.dataset + 'train/')
- test_dataset = MyDataset(input_path=args.dataset + 'valid/')
- train_dataloader = DataLoader(train_dataset,
- batch_size=args.batch_size,
- num_workers=args.num_workers,
- shuffle=True,
- pin_memory=False)
-
- test_dataloader = DataLoader(test_dataset,
- batch_size=args.test_batch_size,
- num_workers=args.num_workers,
- shuffle=False,
- pin_memory=False)
-
-
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
- device = 'cuda' if (torch.cuda.is_available() and args.cuda != -1) else 'cpu'
- print(device)
- if device=='cuda':
- torch.cuda.set_device(args.cuda)
- ##去随机--2021.10.29
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- print('temp gpu device number:')
- print(torch.cuda.current_device())
-
- net = FactorizedPrior(192,192)
-
- #加载最新模型继续训练
- if args.pretrained:
- model = torch.load(save_path + 'checkpoint_best_loss.pth.tar', map_location=lambda storage, loc: storage)
- model.keys()
- # net.load_state_dict(torch.load('path/params.pkl'))
- net.load_state_dict(model['state_dict'])
- # 严格加载
- # net.load_state_dict(model['state_dict'])
- print("load model ok")
- else:
- print("train from none")
-
-
- net.entropy_bottleneck.update()
- net = net.to(device)
- optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
- aux_optimizer = optim.Adam(net.aux_parameters(), lr=args.aux_learning_rate)
- print("lambda:",args.lmbda)
- criterion = RateDistortionLoss(lmbda=args.lmbda)
-
- best_loss = 1e10
- # for epoch in range(args.epochs):
- # train_epoch(epoch, train_dataloader, net, criterion, optimizer,
- # aux_optimizer,log_file=args.logfile)
- for epoch in [0]: # 只跑一次
- # try:
- #验证集
- loss = test_epoch(epoch, test_dataloader, net, criterion)
-
- # is_best = loss < best_loss
- # best_loss = min(loss, best_loss)
- # if args.save:
- # save_checkpoint(
- # {
- # 'epoch': epoch + 1,
- # 'state_dict': net.state_dict(),
- # 'loss': loss,
- # 'optimizer': optimizer.state_dict(),
- # 'aux_optimizer': aux_optimizer.state_dict(),
- # }, is_best)
- # except:
- # print("val error")
- # if args.save:
- # state = {
- # 'epoch': epoch + 1,
- # 'state_dict': net.state_dict(),
- # 'loss': 'none',
- # 'optimizer': optimizer.state_dict(),
- # 'aux_optimizer': aux_optimizer.state_dict(),
- # }
- # torch.save(state, 'checkpoint.pth.tar')
-
- if __name__ == '__main__':
- main(sys.argv[1:])
|