import argparse
import os
import time
import warnings

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors
import numpy as np
import pandas as pd
import scipy.signal as signal
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

import pytorch_msssim
from end2end_compression_model_version14 import CompressionModel

warnings.simplefilter("ignore")

# Anomaly detection helps localise NaN/Inf gradients but slows training.
torch.autograd.set_detect_anomaly(True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def remove_noise(depth):
    '''
    depth: [1, H, W]
    Replace noisy (zero-valued) pixels in depth with median-filtered values.
    '''
    C, H, W = depth.shape
    img = depth[0]
    # mask marks the pixels we treat as noise or holes (zero depth); only
    # these are replaced, valid measurements pass through unchanged.
    mask = (img == 0)
    filter_img = signal.medfilt2d(input=img, kernel_size=3)
    modified_img = mask * filter_img + (1 - mask) * img  # [H, W]
    return modified_img.reshape(C, H, W).astype('float32')
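
# A minimal sketch of the blend above, on hypothetical values: for a 3x3
# img of 5s whose centre pixel is 0, the centre has mask == 1 and is
# replaced by the 3x3 median of its neighbourhood (5), while the border
# pixels have mask == 0 and keep their values. Note that medfilt2d
# zero-pads the border, so isolated zeros at the image edge may survive.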

def smooth_noise(depth):
    '''
    depth: [1, H, W]
    Suppress noise in depth with a 3x3 median filter.
    '''
    C, H, W = depth.shape
    img = depth[0]
    filter_img = signal.medfilt2d(input=img, kernel_size=3)
    return filter_img.reshape(C, H, W).astype('float32')

def colorize(tensor, vmin=0, vmax=0.4, cmap="turbo"):
    '''Map a [H, W] array to an RGB image via a matplotlib colormap.'''
    assert tensor.ndim == 2
    normalizer = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    tensor = mapper.to_rgba(tensor)[..., :3]
    return tensor
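
# Typical use (hypothetical values): rgb = colorize(depth[0], vmin=0,
# vmax=0.4) turns a [H, W] depth map into a [H, W, 3] float RGB array in
# [0, 1]; values outside [vmin, vmax] saturate at the colormap's end colors.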

def save_img(writer, tensor, tag, step, vmin=0, vmax=0.4, color=True):
    '''
    Save images to TensorBoard.
    '''
    grid = make_grid(tensor.detach(), nrow=1)
    grid = grid.cpu().numpy()  # CHW
    if color:
        grid = grid[0]  # HW
        grid = colorize(grid, vmin, vmax).transpose(2, 0, 1)  # CHW
    writer.add_image(tag, grid, step)


class MyDataset(Dataset):
    def __init__(self, label='train', scale=1):
        super(MyDataset, self).__init__()
        self.depth_list = []
        self.mask_list = []
        self.num = 0
        self.scale = scale
        self.label = label

        if label == 'train':
            # Hold out every 20th sample of the training CSVs for validation.
            self._load_split('./data_list/train_indoors.csv', lambda i: i % 20 != 0)
            self._load_split('./data_list/train_outdoor.csv', lambda i: i % 20 != 0)
        elif label == 'val':
            self._load_split('./data_list/train_indoors.csv', lambda i: i % 20 == 0)
            self._load_split('./data_list/train_outdoor.csv', lambda i: i % 20 == 0)
        elif label == 'test':
            self._load_split('./data_list/val_indoors.csv', lambda i: True)
            self._load_split('./data_list/val_outdoor.csv', lambda i: True)

        print('finished loading', label, 'dataset:', self.num)

    def _load_split(self, csv_path, keep):
        '''Append (depth, mask) path pairs from csv_path for row indices where keep(i) is True.'''
        data = pd.read_csv(csv_path, header=None)
        depth_df = data.iloc[:, 1]
        mask_df = data.iloc[:, 2]
        for i in range(len(depth_df)):
            if keep(i):
                self.depth_list.append(depth_df[i])
                self.mask_list.append(mask_df[i])
                self.num += 1

    def __len__(self):
        return self.num

    def __getitem__(self, idx):
        depth_path = self.depth_list[idx]
        mask_path = self.mask_list[idx]

        depth = np.load(depth_path).astype(np.float32).transpose(2, 0, 1) / self.scale  # [1, H, W]

        # Split the integerised depth (in units of 1 / (scale * 1000)) into a
        # low bit-plane (value % 512) and a high bit-plane (value // 512).
        depth_1 = ((np.floor(self.scale * 1000 * depth) % 512) / (self.scale * 1000)).astype(np.float32)
        depth_2 = (np.floor(self.scale * 1000 * depth / 512) / (self.scale * 1000)).astype(np.float32)

        mask = np.load(mask_path).astype(np.float32)  # [H, W]
        H, W = mask.shape
        mask = mask.reshape(1, H, W)  # [1, H, W]

        # Zero out regions where the validity mask is 0.
        depth_1[mask == 0] = 0
        depth_2[mask == 0] = 0

        img = np.zeros((2, H, W))
        img[0:1, :, :] = depth_1
        img[1:2, :, :] = depth_2
        img = img.astype(np.float32)

        img_tensor = torch.from_numpy(img)

        # Random 64x256 crop. randint's upper bound is exclusive, so +1 keeps
        # the last valid start position reachable (and avoids a ValueError
        # when the input is exactly crop-sized).
        x = np.random.randint(0, img_tensor.shape[1] - 64 + 1)
        y = np.random.randint(0, img_tensor.shape[2] - 256 + 1)

        img = img_tensor[:, x:x + 64, y:y + 256]

        return img
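
# A worked example of the bit-plane split in __getitem__, with hypothetical
# values: for scale = 100, a (scaled) depth d yields v = floor(scale * 1000 * d).
# If v = 1300, the low plane stores 1300 % 512 = 276 and the high plane
# stores 1300 // 512 = 2; since 512 * 2 + 276 = 1300, the two planes
# partition v losslessly, which is what the entropy model compresses.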


if __name__ == "__main__":
    # This script trains the lossless entropy-coding model for depth maps.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="Size of the batches")
    parser.add_argument("--epoch_num", type=int, default=200, help="Number of training epochs")

    parser.add_argument("--epoch_decay", type=int, default=20, help="Number of epochs between LR decay steps")
    parser.add_argument("--lr_decay", type=float, default=0.75, help="Multiplicative LR decay factor")

    parser.add_argument("--optimizer", type=str, default='Adam', help="SGD or Adam")

    parser.add_argument("--n_channel_1", type=int, default=64, help="Channel numbers of depth_model")
    parser.add_argument("--n_channel_3", type=int, default=64, help="Channel numbers of residual_model")

    parser.add_argument("--learning_rate", type=float, default=0.0001, help="Learning rate during training")

    parser.add_argument("--resume_1", type=str, default=None,
                        help="Checkpoint path to resume CompressionModel")

    parser.add_argument("--loss_lambda", type=float, default=100, help="Lambda value in loss function")
    parser.add_argument("--loss_beta", type=float, default=0, help="Beta value in loss function")
    parser.add_argument("--loss_gamma", type=float, default=0, help="Gamma value in loss function")

    parser.add_argument("--loss_lambda_res_esti_1", type=float, default=100, help="Lambda value in residual-estimate loss")
    parser.add_argument("--loss_beta_res_esti_1", type=float, default=0, help="Beta value in residual-estimate loss")
    parser.add_argument("--loss_gamma_res_esti_1", type=float, default=0, help="Gamma value in residual-estimate loss")

    parser.add_argument("--qstep_residual", type=float, default=1000, help="qstep in residual loss function")
    parser.add_argument("--store_path", type=str, default='./model/', help="Dir path to save trained model")

    parser.add_argument("--scale", type=float, default=100, help="Scale to modify the depth")

    opt = parser.parse_args()
    batch_size = opt.batch_size
    epoch_num = opt.epoch_num
    optimizer = opt.optimizer

    epoch_decay = opt.epoch_decay
    lr_decay = opt.lr_decay

    n_channel_1 = opt.n_channel_1
    n_channel_3 = opt.n_channel_3

    qstep_residual = opt.qstep_residual

    loss_lambda = opt.loss_lambda
    loss_beta = opt.loss_beta
    loss_gamma = opt.loss_gamma

    loss_lambda_res_esti_1 = opt.loss_lambda_res_esti_1
    loss_beta_res_esti_1 = opt.loss_beta_res_esti_1
    loss_gamma_res_esti_1 = opt.loss_gamma_res_esti_1

    resume_1 = opt.resume_1

    scale = opt.scale

    store_path = opt.store_path + str(time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()))
    if not os.path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
        print('created dir', store_path)

    print(opt)

    argsDict = opt.__dict__
    log_path = store_path + '/log_train.txt'

    # Opening in append mode creates the log file if it does not exist.
    with open(log_path, 'a') as f:
        f.write('------------------ start ------------------\n')
        for eachArg, value in argsDict.items():
            f.write(eachArg + ' : ' + str(value) + '\n')
        f.write('------------------- end -------------------\n')

    train_data = MyDataset(label='train', scale=scale)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16)

    val_data = MyDataset(label='val', scale=scale)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=16)

    loss_mse = nn.MSELoss()
    loss_l1 = nn.L1Loss()
    loss_msssim = pytorch_msssim.MSSSIM(normalize=True).to(device)

    # TensorBoard visualisation
    writer = SummaryWriter()
    LR = opt.learning_rate

    if resume_1 is not None:
        print('resume_1: ', resume_1)
        compression_model = torch.load(resume_1, map_location=torch.device('cpu')).to(device)
    else:
        compression_model = CompressionModel(depth_N=n_channel_1, residual_N=n_channel_3,
                                             scale=scale, qstep_residual=qstep_residual).to(device)

    if optimizer == 'Adam':
        optimizer_1 = torch.optim.Adam(compression_model.parameters(), lr=LR)
    elif optimizer == 'SGD':
        optimizer_1 = torch.optim.SGD(compression_model.parameters(), lr=LR, momentum=0.9)

    step_count = 0

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_1, epoch_decay, gamma=lr_decay)
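
    # With the defaults (learning_rate=1e-4, epoch_decay=20, lr_decay=0.75),
    # StepLR yields LR(epoch) = 1e-4 * 0.75 ** (epoch // 20): 1e-4 for epochs
    # 0-19, 7.5e-5 for epochs 20-39, 5.625e-5 for epochs 40-59, and so on.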

    for epoch in range(epoch_num):

        with open(log_path, 'a') as f:
            f.write('LR: %.8f\n' % (lr_scheduler.get_last_lr()[0]))

        currentBPP = 0

        ######################## train ####################################

        for step, batch_data in enumerate(train_loader):
            # batch_x [batchsize, 2, 64, 256]
            batch_x = batch_data

            step_count += 1

            # Normalise by batch * H * W (the two bit-planes share the pixels).
            num_pixels = batch_x.size()[0] * batch_x.size()[2] * batch_x.size()[3]
            batch_x = batch_x.to(device)
            batch_x_ori = batch_x.clone()

            optimizer_1.zero_grad()
            compression_model.train()
            lossyDepth, residual, residual_esti_1, xp1, xp2, xp_residual_1 = compression_model(batch_x, True)

            mse_loss = loss_mse(batch_x, lossyDepth)
            msssim_loss = 1 - loss_msssim(batch_x, lossyDepth)
            l1_loss = loss_l1(batch_x, lossyDepth)

            RECloss = loss_lambda * mse_loss + loss_beta * msssim_loss + loss_gamma * l1_loss

            mse_loss_res_esti_1 = loss_mse(residual, residual_esti_1)
            msssim_loss_res_esti_1 = 1 - loss_msssim(residual, residual_esti_1)
            l1_loss_res_esti_1 = loss_l1(residual, residual_esti_1)
            RECloss_res_esti_1 = loss_lambda_res_esti_1 * mse_loss_res_esti_1 \
                + loss_beta_res_esti_1 * msssim_loss_res_esti_1 \
                + loss_gamma_res_esti_1 * l1_loss_res_esti_1

            train_bpp1 = torch.sum(torch.log(xp1)) / (-np.log(2) * num_pixels)
            train_bpp2 = torch.sum(torch.log(xp2)) / (-np.log(2) * num_pixels)
            train_bpp1_residual = torch.sum(torch.log(xp_residual_1)) / (-np.log(2) * num_pixels)
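
            # The ideal code length of a symbol with probability p is
            # -log2(p) = -ln(p) / ln(2) bits, so summing ln(p) over a batch
            # and dividing by -ln(2) * num_pixels gives bits per pixel for
            # each of the two bit-planes and for the residual.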

            l_rec = train_bpp1 + train_bpp2 + train_bpp1_residual + \
                (RECloss + RECloss_res_esti_1)

            l_rec.backward()
            orig_grad_norm = clip_grad_norm_(compression_model.parameters(), 10)
            optimizer_1.step()

            residual_max = residual.max().item()
            residual_min = residual.min().item()

            # Build the log line once and reuse it for stdout and the log file.
            msg = ('train epoch: %d step: %d total_loss: %.4f mse_loss: %.8f msssim_loss: %.4f l1_loss: %.4f '
                   'mse_loss_res_esti_1: %.8f msssim_loss_res_esti_1: %.4f l1_loss_res_esti_1: %.4f '
                   'e1_loss: %.4f e2_loss: %.4f e1_loss_residual: %.4f r_max: %.4f, r_min: %.4f '
                   % (epoch, step, l_rec.item(), mse_loss.item(), msssim_loss.item(), l1_loss.item(),
                      mse_loss_res_esti_1.item(), msssim_loss_res_esti_1.item(), l1_loss_res_esti_1.item(),
                      train_bpp1.item(), train_bpp2.item(), train_bpp1_residual.item(),
                      residual_max, residual_min))
            print(msg)
            with open(log_path, 'a') as f:
                f.write(msg + '\n')

            writer.add_scalar('train/entropy_bpp',
                              train_bpp1.item() + train_bpp2.item() + train_bpp1_residual.item(), step_count)

        lr_scheduler.step()

        ########################### validate ##############################

        val_steps = 0
        for step, batch_data in enumerate(val_loader):

            batch_x = batch_data
            compression_model.eval()

            num_pixels = batch_x.size()[0] * batch_x.size()[2] * batch_x.size()[3]
            batch_x = batch_x.to(device)

            batch_x_ori = batch_x.clone()

            with torch.no_grad():

                lossyDepth, residual, residual_esti_1, xp1, xp2, xp_residual_1 = compression_model(batch_x, False)

                residual_max = residual.max().item()
                residual_min = residual.min().item()

                val_bpp1 = torch.sum(torch.log(xp1)) / (-np.log(2) * num_pixels)
                val_bpp2 = torch.sum(torch.log(xp2)) / (-np.log(2) * num_pixels)
                val_bpp1_residual = torch.sum(torch.log(xp_residual_1)) / (-np.log(2) * num_pixels)

                l_rec = val_bpp1 + val_bpp2 + val_bpp1_residual

                msg = ('val epoch: %d step: %d e1_loss: %.4f e2_loss: %.4f e1_loss_residual: %.4f '
                       'r_max: %.4f, r_min: %.4f '
                       % (epoch, step, val_bpp1.item(), val_bpp2.item(), val_bpp1_residual.item(),
                          residual_max, residual_min))
                print(msg)
                with open(log_path, 'a') as f:
                    f.write(msg + '\n')

                writer.add_scalar('val/entropy_bpp',
                                  val_bpp1.item() + val_bpp2.item() + val_bpp1_residual.item(), step_count)

                currentBPP = currentBPP + val_bpp1.item() + val_bpp2.item() + val_bpp1_residual.item()
                val_steps = val_steps + 1

        currentBPP = currentBPP / val_steps

        # Save a checkpoint after every epoch, tagged with the epoch index
        # and the mean validation bpp.
        torch.save(compression_model, store_path + '/compression_model_%d_%.8f.pkl' %
                   (epoch, currentBPP))
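
        # Note: saving the model object pickles the entire module, so loading
        # these .pkl checkpoints requires CompressionModel to be importable
        # from the same path; saving compression_model.state_dict() would be
        # the more portable alternative.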

    writer.close()