|
- from torch.autograd import Variable
- from utils.img_utils import *
- from networks.rnet import RNet
- from networks.tnet import TNet
-
- import argparse
- import yaml
- import os
- import cv2
- import torch
- import numpy as np
- import os.path as osp
-
- '''
- 利用全景图像的反射场景信息辅助反射消除。
- '''
-
class Pano_RR():
    '''
    Reflection removal assisted by the reflection-scene information of a
    panoramic image.

    Two networks run in sequence: RNet takes the mixture image and the
    reflection-scene image and predicts an intermediate image (outR); TNet
    takes the mixture image and outR and predicts the final image (outT).
    '''

    def __init__(self, *arg, **kwarg):
        '''
        Initialise the models and read the configuration.

        ./config.yaml stores the checkpoint paths (rnet_path, tnet_path) and
        the network input size (input.width, input.height).
        Positional and keyword arguments are accepted but ignored.
        '''
        self.rnet = RNet().cuda()
        self.tnet = TNet().cuda()

        # yaml.load() without an explicit Loader is deprecated since
        # PyYAML 5.1 and raises TypeError on 6.0+; safe_load is enough for a
        # plain mapping of paths and sizes. The with-block closes the handle
        # that was previously left open for the lifetime of the object.
        with open('./config.yaml', 'rb') as config_file:
            self.config = yaml.safe_load(config_file)
        # Attribute kept for backward compatibility; the handle is closed.
        self.config_file = config_file

        self.rnet_path = self.config['rnet_path']
        self.tnet_path = self.config['tnet_path']
        self.w = int(self.config['input']['width'])
        self.h = int(self.config['input']['height'])

        # name -> (checkpoint path, restored DataParallel model)
        self._restored = {}

    def __call__(self, m, rs):
        '''
        Remove reflections from m with the help of the reflection scene rs.

        Both inputs are resized to the configured (width, height) before
        inference, so the prediction has the configured size, not the size
        of the inputs.

        Arg
        ---
            m: color image in RGB format
                shape: (H, W, 3)
                type: np.uint8

            rs: color image in RGB format
                shape: (H, W, 3)
                type: np.uint8
        Return
        ---
            outT: color image in RGB format
                shape: (H, W, 3)
                type: np.uint8
        '''
        rnet = self._restore('rnet', self.rnet, self.rnet_path)

        m = cv2.resize(m, (self.w, self.h))
        rs = cv2.resize(rs, (self.w, self.h))
        rs = photometric(rs)

        m, rs = self.var_process(m), self.var_process(rs)

        # Only the second output of RNet is used.
        _, outR = rnet(m, rs)
        outR = self._to_image(outR)

        tnet = self._restore('tnet', self.tnet, self.tnet_path)

        # outR goes through the same image -> tensor conversion as the inputs
        # before being fed to TNet.
        outR = self.var_process(outR)

        outT = tnet(m, outR)
        return self._to_image(outT)

    def _restore(self, name, net, path):
        '''
        Return net wrapped in DataParallel with the checkpoint at path loaded
        and eval mode enabled.

        The result is cached under name so the checkpoint is read from disk
        once instead of on every call; it is reloaded only if path changes.
        '''
        cached = self._restored.get(name)
        if cached is not None and cached[0] == path:
            return cached[1]

        model = torch.nn.DataParallel(net)
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        self._restored[name] = (path, model)
        return model

    def _to_image(self, out):
        '''
        Convert a network output holding 3 * h * w values into an image via
        MatrixToImage, after rearranging it to (h, w, 3) on the CPU.
        '''
        array = out.data.cpu().numpy().reshape(3, self.h, self.w)
        return MatrixToImage(array.transpose(1, 2, 0))

    def var_process(self, var):
        '''
        Convert an (H, W, C) array into a (1, C, H, W) float32 CUDA Variable
        with values divided by 255.
        '''
        var = var.astype(np.float32)
        var = var.transpose(2, 0, 1)
        var = torch.from_numpy(var)
        var = var.unsqueeze(0)
        var = var / 255.0
        var = Variable(var)
        var = var.cuda()
        return var
-
class Args():
    '''
    Container for the locations used by the script entry point.

    Arg
    ---
        data: directory the input samples are read from
            default: './data'

        results: directory the predictions are written to
            default: './results'
    '''

    def __init__(self, data='./data', results='./results'):
        self.data = data
        self.results = results
-
if __name__ == "__main__":
    # Run reflection removal on every sample under the data directory and
    # write each prediction to <results>/<sample>/outT.jpg.
    args = Args()
    data_path = args.data
    result_path = args.results

    pano_rr = Pano_RR()

    # Every entry of the data directory is expected to be a sub-directory
    # holding m.jpg and rs.jpg; sorted() makes the processing order stable.
    for sample in sorted(os.listdir(data_path)):
        # NOTE(review): cv2.imread yields BGR channel order, while
        # Pano_RR.__call__ documents RGB input - confirm which order the
        # networks expect.
        m = cv2.imread(osp.join(data_path, sample, 'm.jpg'))
        rs = cv2.imread(osp.join(data_path, sample, 'rs.jpg'))

        # cv2.imread returns None instead of raising when a file is missing
        # or unreadable (e.g. a stray non-sample entry in the data directory);
        # report and skip it rather than failing later inside cv2.resize.
        if m is None or rs is None:
            print('Skipping %s: could not read m.jpg and/or rs.jpg' % sample)
            continue

        pred = pano_rr(m, rs)

        save_path = osp.join(result_path, sample)
        if not osp.isdir(save_path):
            os.makedirs(save_path)
        cv2.imwrite(osp.join(save_path, 'outT.jpg'), pred)
|