- import json
- import os
- import sys
-
- import imageio
- import torch
- from imgviz import label_colormap
- from tqdm.auto import tqdm
- from dataLoader.ray_utils import get_rays, ndc_rays_blender
- from util.semantic_util import calculate_segmentation_metrics
- from util.utils import *
- from .tensorBase import *
-
-
- class TensorVMSplitSemVM(TensorBase):
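- # TensoRF (VM-split) radiance field extended with a semantic branch: density,
- # appearance, and semantics are each factorized into plane/line components,
- # and the semantic branch predicts sem_dim logits that are alpha-composited
- # along each ray like the RGB values.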
- def __init__(self, aabb, gridSize, device, sem_dim, sem_n_comp, **kargs):
- self.sem_dim = sem_dim
- self.ignore_label = -1
- self.sem_n_comp = sem_n_comp
- super().__init__(aabb, gridSize, device, **kargs)
-
- def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1, tps_function=None):
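- # rays_chunk packs per-ray origin (cols 0:3) and view direction (cols 3:6);
- # the trailing entry is reused as a background depth fill. Returns the
- # composited RGB map, depth map, and semantic logit map.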
- # sample points
- viewdirs = rays_chunk[:, 3:6]
- if ndc_ray:
- xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,
- N_samples=N_samples)
- dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
- rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
- dists = dists * rays_norm
- viewdirs = viewdirs / rays_norm
- else:
- xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,
- N_samples=N_samples)
- dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
- viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
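- # Optional thin-plate-spline (TPS) deformation: warp the view directions
- # directly, and warp sample positions in normalized [-1, 1] space before
- # mapping them back to world coordinates.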
- if tps_function:
- viewdirs = tps_function(viewdirs)
- viewdirs = viewdirs.clamp(-1, 1)
- xyz_sampled = self.normalize_coord(xyz_sampled)
- xyz_sampled = tps_function(xyz_sampled)
- xyz_sampled = xyz_sampled.clamp(-1, 1)
- xyz_sampled = self.normalize_coord_reverse(xyz_sampled)
- if self.alphaMask is not None:
- alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
- alpha_mask = alphas > 0
- ray_invalid = ~ray_valid
- ray_invalid[ray_valid] |= (~alpha_mask)
- ray_valid = ~ray_invalid
-
- sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
- rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
- sem = torch.zeros((*xyz_sampled.shape[:2], self.sem_dim), device=xyz_sampled.device)
-
- if ray_valid.any():
- xyz_sampled = self.normalize_coord(xyz_sampled)
- sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
-
- validsigma = self.feature2density(sigma_feature)
- sigma[ray_valid] = validsigma
-
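- # standard volume-rendering quadrature: per-sample alpha and compositing weights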
- alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
-
- app_mask = weight > self.rayMarch_weight_thres
-
- if app_mask.any():
- app_features = self.compute_appfeature(xyz_sampled[app_mask])
- valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
- rgb[app_mask] = valid_rgbs
- semantic_features = self.compute_semanticfeature(xyz_sampled[app_mask])
- sem[app_mask] = semantic_features
-
- acc_map = torch.sum(weight, -1)
- rgb_map = torch.sum(weight[..., None] * rgb, -2)
- sem_map = torch.sum(weight[..., None] * sem, -2)
- if white_bg or (is_train and torch.rand((1,)) < 0.5):
- rgb_map = rgb_map + (1. - acc_map[..., None])
- sem_map = sem_map + (1. - acc_map[..., None])
-
- rgb_map = rgb_map.clamp(0, 1)
- with torch.no_grad():
- depth_map = torch.sum(weight * z_vals, -1)
- depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]
-
- return rgb_map, depth_map, sem_map
-
- def init_svd_volume(self, res, device):
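- # One plane/line factorization per field (density, appearance, semantics);
- # the basis matrices project the concatenated components to app_dim / sem_dim.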
- self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
- self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
- self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)
- self.sem_plane, self.sem_line = self.init_one_svd(self.sem_n_comp, self.gridSize, 0.1, device)
- self.sem_basis_mat = torch.nn.Linear(sum(self.sem_n_comp), self.sem_dim, bias=False).to(device)
-
- def init_one_svd(self, n_component, gridSize, scale, device):
- plane_coef, line_coef = [], []
- for i in range(len(self.vecMode)):
- vec_id = self.vecMode[i]
- mat_id_0, mat_id_1 = self.matMode[i]
- plane_coef.append(torch.nn.Parameter(
- scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0]))))
- line_coef.append(
- torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))
-
- return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)
-
- def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001):
- grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
- {'params': self.density_plane, 'lr': lr_init_spatialxyz},
- {'params': self.app_line, 'lr': lr_init_spatialxyz},
- {'params': self.app_plane, 'lr': lr_init_spatialxyz},
- {'params': self.sem_line, 'lr': lr_init_spatialxyz},
- {'params': self.sem_plane, 'lr': lr_init_spatialxyz},
- ]
- if isinstance(self.basis_mat, torch.nn.Linear):
- grad_vars += [{'params': self.basis_mat.parameters(), 'lr': lr_init_network}]
- else:
- grad_vars += [{'params': self.basis_mat, 'lr': lr_init_spatialxyz}]
- if isinstance(self.sem_basis_mat, torch.nn.Linear):
- grad_vars += [{'params': self.sem_basis_mat.parameters(), 'lr': lr_init_network}]
- else:
- grad_vars += [{'params': self.sem_basis_mat, 'lr': lr_init_spatialxyz}]
- if isinstance(self.renderModule, torch.nn.Module):
- grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}]
- return grad_vars
-
- def vectorDiffs(self, vector_comps):
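- # Mean absolute off-diagonal entry of each line component's Gram matrix,
- # penalizing correlation (non-orthogonality) between components.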
- total = 0
-
- for idx in range(len(vector_comps)):
- n_comp, n_size = vector_comps[idx].shape[1:-1]
-
- dotp = torch.matmul(vector_comps[idx].view(n_comp, n_size),
- vector_comps[idx].view(n_comp, n_size).transpose(-1, -2))
- non_diagonal = dotp.view(-1)[1:].view(n_comp - 1, n_comp + 1)[..., :-1]
- total = total + torch.mean(torch.abs(non_diagonal))
- return total
-
- def vector_comp_diffs(self):
-
- return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line) + self.vectorDiffs(self.sem_line)
-
- def density_L1(self):
- total = 0
- for idx in range(len(self.density_plane)):
- total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))
- return total
-
- def TV_loss_density(self, reg):
- total = 0
- for idx in range(len(self.density_plane)):
- total = total + reg(self.density_plane[idx]) * 1e-2
- return total
-
- def TV_loss_app(self, reg):
- total = 0
- for idx in range(len(self.app_plane)):
- total = total + reg(self.app_plane[idx]) * 1e-2
- return total
-
- def TV_loss_sem(self, reg):
- total = 0
- for idx in range(len(self.sem_plane)):
- total = total + reg(self.sem_plane[idx]) * 1e-2
- return total
-
- def compute_densityfeature(self, xyz_sampled):
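- # Bilinearly sample each plane/line factor pair at the projected sample
- # coordinates and sum their products into a scalar density feature per point.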
-
- # plane + line basis
- coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
- xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
- coordinate_line = torch.stack(
- (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
- coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
- 1, 2)
-
- sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
- for idx_plane in range(len(self.density_plane)):
- plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
- align_corners=True).view(-1, *xyz_sampled.shape[:1])
- line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
- align_corners=True).view(-1, *xyz_sampled.shape[:1])
- sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)
-
- return sigma_feature
-
- def compute_semanticfeature(self, xyz_sampled):
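- # Same plane/line sampling as the appearance branch, but projected through
- # sem_basis_mat to per-point semantic features of size sem_dim.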
- # plane + line basis
- coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
- xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
- coordinate_line = torch.stack(
- (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
- coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
- 1, 2)
-
- plane_coef_point, line_coef_point = [], []
- for idx_plane in range(len(self.sem_plane)):
- plane_coef_point.append(F.grid_sample(self.sem_plane[idx_plane], coordinate_plane[[idx_plane]],
- align_corners=True).view(-1, *xyz_sampled.shape[:1]))
- line_coef_point.append(F.grid_sample(self.sem_line[idx_plane], coordinate_line[[idx_plane]],
- align_corners=True).view(-1, *xyz_sampled.shape[:1]))
- plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)
-
- feature = (plane_coef_point * line_coef_point).T
- if isinstance(self.sem_basis_mat, torch.nn.Linear):
- sem_feature = self.sem_basis_mat(feature)
- else:
- sem_feature = torch.matmul(feature, self.sem_basis_mat)
- return sem_feature
-
- def compute_appfeature(self, xyz_sampled):
-
- # plane + line basis
- coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
- xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
- coordinate_line = torch.stack(
- (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
- coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
- 1, 2)
-
- plane_coef_point, line_coef_point = [], []
- for idx_plane in range(len(self.app_plane)):
- plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
- align_corners=True).view(-1, *xyz_sampled.shape[:1]))
- line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
- align_corners=True).view(-1, *xyz_sampled.shape[:1]))
- plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)
-
- feature = (plane_coef_point * line_coef_point).T
-
- if isinstance(self.basis_mat, torch.nn.Linear):
- app_feature = self.basis_mat(feature)
- else:
- app_feature = torch.matmul(feature, self.basis_mat)
-
- return app_feature
-
- @torch.no_grad()
- def up_sampling_VM(self, plane_coef, line_coef, res_target):
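- # Bilinearly resample every plane and line factor grid to the target
- # resolution, re-wrapping the results as fresh Parameters.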
-
- for i in range(len(self.vecMode)):
- vec_id = self.vecMode[i]
- mat_id_0, mat_id_1 = self.matMode[i]
- plane_coef[i] = torch.nn.Parameter(
- F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
- align_corners=True))
- line_coef[i] = torch.nn.Parameter(
- F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
- return plane_coef, line_coef
-
- @torch.no_grad()
- def upsample_volume_grid(self, res_target):
- self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
- self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)
- self.sem_plane, self.sem_line = self.up_sampling_VM(self.sem_plane, self.sem_line, res_target)
- self.update_stepSize(res_target)
- print(f'upsampling to {res_target}')
-
- @torch.no_grad()
- def shrink(self, new_aabb):
-
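- # Crop all factor grids to the voxel range covered by new_aabb; if the grid
- # resolution differs from the alpha mask's, recompute the AABB actually
- # spanned by the kept voxels before updating the step size.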
- print("====> shrinking ...")
- xyz_min, xyz_max = new_aabb
- t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
- t_l, b_r = torch.round(t_l).long(), torch.round(b_r).long() + 1
- b_r = torch.stack([b_r, self.gridSize]).amin(0)
-
- for i in range(len(self.vecMode)):
- mode0 = self.vecMode[i]
- self.density_line[i] = torch.nn.Parameter(
- self.density_line[i].data[..., t_l[mode0]:b_r[mode0], :]
- )
- self.app_line[i] = torch.nn.Parameter(
- self.app_line[i].data[..., t_l[mode0]:b_r[mode0], :]
- )
-
- self.sem_line[i] = torch.nn.Parameter(
- self.sem_line[i].data[..., t_l[mode0]:b_r[mode0], :]
- )
- mode0, mode1 = self.matMode[i]
- self.density_plane[i] = torch.nn.Parameter(
- self.density_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]]
- )
- self.app_plane[i] = torch.nn.Parameter(
- self.app_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]]
- )
- self.sem_plane[i] = torch.nn.Parameter(
- self.sem_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]]
- )
-
- if not torch.all(self.alphaMask.gridSize == self.gridSize):
- t_l_r, b_r_r = t_l / (self.gridSize - 1), (b_r - 1) / (self.gridSize - 1)
- correct_aabb = torch.zeros_like(new_aabb)
- correct_aabb[0] = (1 - t_l_r) * self.aabb[0] + t_l_r * self.aabb[1]
- correct_aabb[1] = (1 - b_r_r) * self.aabb[0] + b_r_r * self.aabb[1]
- print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
- new_aabb = correct_aabb
-
- newSize = b_r - t_l
- self.aabb = new_aabb
- self.update_stepSize((newSize[0], newSize[1], newSize[2]))
-
- @torch.no_grad()
- def filtering_rays(self, all_rays, all_rgbs, all_sems, N_samples=256, chunk=10240 * 5, bbox_only=False):
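- # Keep rays that either intersect the scene AABB (bbox_only) or hit at least
- # one alpha-positive sample; rays, colors, and semantic labels are filtered together.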
- print('========> filtering rays ...')
- tt = time.time()
-
- N = torch.tensor(all_rays.shape[:-1]).prod()
-
- mask_filtered = []
- idx_chunks = torch.split(torch.arange(N), chunk)
- for idx_chunk in idx_chunks:
- rays_chunk = all_rays[idx_chunk].to(self.device)
-
- rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
- if bbox_only:
- vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
- rate_a = (self.aabb[1] - rays_o) / vec
- rate_b = (self.aabb[0] - rays_o) / vec
- t_min = torch.minimum(rate_a, rate_b).amax(-1) # .clamp(min=near, max=far)
- t_max = torch.maximum(rate_a, rate_b).amin(-1) # .clamp(min=near, max=far)
- mask_inbbox = t_max > t_min
-
- else:
- xyz_sampled, _, _ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
- mask_inbbox = (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)
-
- mask_filtered.append(mask_inbbox.cpu())
-
- mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])
-
- print(f'Ray filtering done! takes {time.time() - tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
- return all_rays[mask_filtered], all_rgbs[mask_filtered], all_sems[mask_filtered]
-
- def normalize_coord_reverse(self, xyz_sampled):
- # inverse of normalize_coord: map xyz_sampled from [-1, 1] back to world coordinates
- return (xyz_sampled + 1) / self.invaabbSize + self.aabb[0]
-
- def compute_volume(self, xyz_locs, gridSize, length=1):
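- # Evaluate density, appearance, and semantic features for one slice of a
- # dense grid; `length` is the step size used to convert sigma to alpha.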
-
- if self.alphaMask is not None:
- alphas = self.alphaMask.sample_alpha(xyz_locs)
- alpha_mask = alphas > 0
- else:
- alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool)
-
- sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
-
- if alpha_mask.any():
- xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
- sigma_feature = self.compute_densityfeature(xyz_sampled)
- validsigma = self.feature2density(sigma_feature)
- sigma[alpha_mask] = validsigma
-
- alpha = 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1])
- rgb = torch.zeros((xyz_locs.shape[0], 27), device=xyz_locs.device)
- sem = torch.zeros((xyz_locs.shape[0], self.sem_dim), device=xyz_locs.device)
- if alpha_mask.any():
- app_features = self.compute_appfeature(xyz_sampled)
- rgb[alpha_mask] = app_features
- semantic_features = self.compute_semanticfeature(xyz_sampled)
- sem[alpha_mask] = semantic_features
- alpha = alpha.view((gridSize[1], gridSize[2]))
- sigma = sigma.view((gridSize[1], gridSize[2]))
- sem = sem.view((gridSize[1], gridSize[2], self.sem_dim))
- rgb = rgb.view((gridSize[1], gridSize[2], 27))
- return alpha, sigma, sem, rgb
-
- @torch.no_grad()
- def getVolume(self, gridSize, sparse=True):
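- # Densely sample the AABB slice by slice, returning per-voxel sigma,
- # semantic and appearance features, and voxel coordinates in grid space
- # (optionally keeping only voxels with alpha above a small threshold).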
-
- gridSize = self.gridSize if gridSize is None else gridSize
- gridSize = torch.tensor(gridSize).long()
- samples = torch.stack(torch.meshgrid(
- torch.linspace(0, 1, gridSize[0]),
- torch.linspace(0, 1, gridSize[1]),
- torch.linspace(0, 1, gridSize[2]),
- ), -1).to(self.device)
- dense_xyz = self.aabb[0] * (1 - samples) + self.aabb[1] * samples
- alpha = torch.zeros_like(dense_xyz[..., 0])
- sigma = torch.zeros_like(dense_xyz[..., 0])
- sem = torch.zeros((gridSize[0], gridSize[1], gridSize[2], self.sem_dim)).to(self.device)
- rgb = torch.zeros((gridSize[0], gridSize[1], gridSize[2], 27)).to(self.device)
- for i in range(gridSize[0]):
- alpha[i], sigma[i], sem[i], rgb[i] = self.compute_volume(dense_xyz[i].view(-1, 3), gridSize, self.stepSize)
- dense_xyz = self.normalize_coord(dense_xyz)
- dense_xyz = (dense_xyz + 1) / 2.0 * gridSize
- # TODO: verify the extracted volume against the rendered output.
- if sparse:
- level = 0.005
- dense_xyz = dense_xyz[alpha > level]
- sigma = sigma[alpha > level]
- sem = sem[alpha > level]
- rgb = rgb[alpha > level]
- del alpha, samples
- return sigma, sem, rgb, dense_xyz
-
-
- def save(self, path):
- kwargs = self.get_kwargs()
- kwargs.update({'sem_dim': self.sem_dim, 'sem_n_comp': self.sem_n_comp})
- print("saving semantic TensoRF checkpoint")
-
- ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
- if self.alphaMask is not None:
- alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
- ckpt.update({'alphaMask.shape': alpha_volume.shape})
- ckpt.update({'alphaMask.mask': np.packbits(alpha_volume.reshape(-1))})
- ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
- torch.save(ckpt, path)
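-
- # A minimal loading sketch (assumes the checkpoint layout written above and
- # that `kwargs` does not already carry a device entry):
- # ckpt = torch.load(path, map_location=device)
- # model = TensorVMSplitSemVM(**ckpt['kwargs'], device=device)
- # model.load_state_dict(ckpt['state_dict'])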
-
- def OctreeRender_trilinear_fast_sem(self,
- rays,
- chunk=4096,
- N_samples=-1,
- ndc_ray=False,
- white_bg=True,
- is_train=False,
- device='cuda',
- tps_function=None):
-
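- # Chunked renderer: split the ray batch, run forward() per chunk, and
- # concatenate the resulting maps; the None placeholders keep the six-slot
- # (rgb, alpha, depth, weight, uncertainty, semantic logits) return layout.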
- rgbs, depth_maps, sem_logit_maps = [], [], []
- N_rays_all = rays.shape[0]
- for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
- rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
-
- rgb_map, depth_map, sem_logit_map = self.forward(
- rays_chunk,
- is_train=is_train,
- white_bg=white_bg,
- ndc_ray=ndc_ray,
- N_samples=N_samples,
- tps_function=tps_function)
-
- rgbs.append(rgb_map)
- depth_maps.append(depth_map)
- sem_logit_maps.append(sem_logit_map)
-
- return torch.cat(rgbs), None, torch.cat(depth_maps), None, None, torch.cat(sem_logit_maps)
-
- @torch.no_grad()
- def evaluation_path(self, test_dataset, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
- white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
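- # hard labels from logits (argmax over softmax equals argmax over raw logits)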
- logits_2_label = lambda x: torch.argmax(torch.nn.functional.softmax(x, dim=-1), dim=-1)
- sem_maps = []
- PSNRs, rgb_maps, depth_maps = [], [], []
- ssims, l_alex, l_vgg = [], [], []
- os.makedirs(savePath, exist_ok=True)
- os.makedirs(savePath + "/rgbd", exist_ok=True)
- os.makedirs(savePath + "/sem", exist_ok=True)
-
- try:
- tqdm._instances.clear()
- except Exception:
- pass
-
- near_far = test_dataset.near_far
- for idx, c2w in tqdm(enumerate(c2ws)):
-
- W, H = test_dataset.img_wh
-
- c2w = torch.FloatTensor(c2w)
- rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
- if ndc_ray:
- rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
- rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
-
- rgb_map, _, depth_map, _, _, sem_logit_map = renderer(rays, chunk=8192, N_samples=N_samples,
- ndc_ray=ndc_ray, white_bg=white_bg, device=device)
- rgb_map = rgb_map.clamp(0.0, 1.0)
-
- rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
-
- depth_map, _ = visualize_depth_numpy(depth_map.numpy(), near_far)
-
- rgb_map = (rgb_map.numpy() * 255).astype('uint8')
-
- sem_map = logits_2_label(sem_logit_map)
- sem_map = sem_map.cpu().numpy()
- sem_map_label = sem_map
- sem_map = self.label_colour_map[sem_map]
- sem_map = self.get_vis_sem_clean(sem_map, sem_map_label)
- sem_map = sem_map.reshape(H, W, 3).astype('uint8')
-
- sem_maps.append(sem_map)
- rgb_maps.append(rgb_map)
- depth_maps.append(depth_map)
- if savePath is not None:
- imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
- rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
- imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
- imageio.imwrite(f'{savePath}/sem/{prtx}{idx:03d}.png', sem_map)
-
- imageio.mimwrite(f'{savePath}/{prtx}semantic_video.mp4', np.stack(sem_maps), fps=10, quality=8)
- imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=10, quality=8)
- imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=10, quality=8)
-
- if PSNRs:
- psnr = np.mean(np.asarray(PSNRs))
- if compute_extra_metrics:
- ssim = np.mean(np.asarray(ssims))
- l_a = np.mean(np.asarray(l_alex))
- l_v = np.mean(np.asarray(l_vgg))
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
- else:
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
-
- return PSNRs
-
- @torch.no_grad()
- def evaluation(self, dataset, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
- white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
- logits_2_label = lambda x: torch.argmax(torch.nn.functional.softmax(x, dim=-1), dim=-1)
- sem_maps = []
- PSNRs, rgb_maps, depth_maps = [], [], []
- ssims, l_alex, l_vgg = [], [], []
- sem_maps_metric = []
- os.makedirs(savePath, exist_ok=True)
- os.makedirs(savePath + "/rgbd", exist_ok=True)
- os.makedirs(savePath + "/sem", exist_ok=True)
-
- try:
- tqdm._instances.clear()
- except Exception:
- pass
-
- semantic_clean = dataset.sem_samples['semantic_remap_clean']
- if torch.is_tensor(semantic_clean):
- semantic_clean = semantic_clean.detach().cpu().numpy().astype(int)
- else:
- semantic_clean = semantic_clean.astype(int)
- semantic_clean = semantic_clean - 1  # shift so the void class maps to ignore_label (-1)
- near_far = dataset.near_far
- img_eval_interval = 1 if N_vis < 0 else max(dataset.all_rays.shape[0] // N_vis, 1)
- idxs = list(range(0, dataset.all_rays.shape[0], img_eval_interval))
- for idx, samples in tqdm(enumerate(dataset.all_rays[0::img_eval_interval]), file=sys.stdout):
-
- W, H = dataset.img_wh
- rays = samples.view(-1, samples.shape[-1])
-
- rgb_map, _, depth_map, _, _, sem_logit_map = renderer(rays, chunk=4096, N_samples=N_samples,
- ndc_ray=ndc_ray, white_bg=white_bg, device=device)
-
- # the semantic logit map is converted to hard labels below and scored for segmentation metrics
- rgb_map = rgb_map.clamp(0.0, 1.0)
- rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
-
- sem_map = logits_2_label(sem_logit_map)
- sem_map = sem_map.cpu().numpy()
- sem_maps_metric.append(sem_map.reshape(H, W))
- sem_map = self.label_colour_map[sem_map]
- sem_map = sem_map.reshape(H, W, 3).astype('uint8')
- sem_map = self.get_vis_sem_clean(sem_map, semantic_clean[idx])
- sem_maps.append(sem_map)
-
- depth_map, _ = visualize_depth_numpy(depth_map.numpy(), near_far)
- if len(dataset.all_rgbs):
- gt_rgb = dataset.all_rgbs[idxs[idx]].view(H, W, 3)
- loss = torch.mean((rgb_map - gt_rgb) ** 2)
- PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
-
- if compute_extra_metrics:
- ssim = rgb_ssim(rgb_map, gt_rgb, 1)
- l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', self.device)
- l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', self.device)
- ssims.append(ssim)
- l_alex.append(l_a)
- l_vgg.append(l_v)
-
- rgb_map = (rgb_map.numpy() * 255).astype('uint8')
- rgb_maps.append(rgb_map)
- depth_maps.append(depth_map)
-
- if savePath is not None:
- imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
- rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
- imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
- imageio.imwrite(f'{savePath}/sem/{prtx}{idx:03d}.png', sem_map)
-
- imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=10, quality=10)
- imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=10, quality=10)
-
- imageio.mimwrite(f'{savePath}/{prtx}semantic_video.mp4', np.stack(sem_maps), fps=10, quality=10)
-
- if PSNRs:
- psnr = np.mean(np.asarray(PSNRs))
- if compute_extra_metrics:
- ssim = np.mean(np.asarray(ssims))
- l_a = np.mean(np.asarray(l_alex))
- l_v = np.mean(np.asarray(l_vgg))
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
- else:
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
- sem_maps_metric = np.stack(sem_maps_metric)
- # Segmentation metrics over the rendered views; this assumes every view was
- # evaluated (N_vis < 0) so predictions align with semantic_clean.
- miou, miou_valid_class, total_accuracy, class_average_accuracy, ious = calculate_segmentation_metrics(
- semantic_clean, sem_maps_metric, self.sem_dim, self.ignore_label)
- print(f'======> miou: {miou}, miou_valid_class: {miou_valid_class} <======')
- print(f'======> total_accuracy: {total_accuracy} <======')
- print(f'======> class_average_accuracy: {class_average_accuracy} <======')
- print(f'======> per-class ious: {ious} <======')
- return PSNRs
-
- def set_label_colour_map(self, scene_file):
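- # Build the label -> RGB palette from a Replica-style info_semantic.json
- # (instance id -> semantic label id), dropping the void class entry.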
- info_mesh_file = os.path.join(scene_file, "info_semantic.json")
- with open(info_mesh_file, "r") as f:
- annotations = json.load(f)
- instance_id_to_semantic_label_id = np.array(annotations["id_to_label"])
- instance_id_to_semantic_label_id[instance_id_to_semantic_label_id <= 0] = 0
- semantic_classes = np.unique(instance_id_to_semantic_label_id)
- num_classes = len(semantic_classes)  # including the void class (0)
- label_colour_map = label_colormap()[semantic_classes]
- valid_colour_map = label_colour_map[1:]
- self.label_colour_map = valid_colour_map
-
- def get_vis_sem_clean(self, vis_sem, semantic_clean_idx):
- vis_sem[semantic_clean_idx == self.ignore_label, :] = 0
- return vis_sem
-
- @torch.no_grad()
- def tps_rendering_eval(self, dataset, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
- white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda', tps_function=None):
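- # Same loop as `evaluation`, but rays are rendered through the TPS-deformed
- # field (tps_function); per-view class-index maps are saved alongside the
- # colour-coded semantic maps.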
- logits_2_label = lambda x: torch.argmax(torch.nn.functional.softmax(x, dim=-1), dim=-1)
- sem_maps = []
- sem_classes = []
- PSNRs, rgb_maps, depth_maps = [], [], []
- ssims, l_alex, l_vgg = [], [], []
- os.makedirs(savePath, exist_ok=True)
- os.makedirs(savePath + "/rgb", exist_ok=True)
- os.makedirs(savePath + "/rgbd", exist_ok=True)
- os.makedirs(savePath + "/sem", exist_ok=True)
-
- try:
- tqdm._instances.clear()
- except Exception:
- pass
-
- semantic_clean = dataset.sem_samples['semantic_remap_clean']
- if torch.is_tensor(semantic_clean):
- semantic_clean = semantic_clean.detach().cpu().numpy().astype(int)
- else:
- semantic_clean = semantic_clean.astype(int)
- semantic_clean = semantic_clean - 1  # shift so the void class maps to ignore_label (-1)
- near_far = dataset.near_far
- img_eval_interval = 1 if N_vis < 0 else max(dataset.all_rays.shape[0] // N_vis, 1)
- idxs = list(range(0, dataset.all_rays.shape[0], img_eval_interval))
- for idx, samples in tqdm(enumerate(dataset.all_rays[0::img_eval_interval]), file=sys.stdout):
-
- W, H = dataset.img_wh
- rays = samples.view(-1, samples.shape[-1])
-
- rgb_map, _, depth_map, _, _, sem_logit_map = renderer(rays, chunk=4096, N_samples=N_samples,
- ndc_ray=ndc_ray, white_bg=white_bg, device=device,
- tps_function=tps_function)
-
- # semantic logits are converted to hard labels below for visualization and saving
- rgb_map = rgb_map.clamp(0.0, 1.0)
- rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
-
- sem_map = logits_2_label(sem_logit_map)
- sem_map = sem_map.cpu().numpy()
- sem_class = sem_map.reshape(H, W).astype('uint8')
- sem_classes.append(sem_class)
- sem_label = sem_map.copy()
- sem_map = self.label_colour_map[sem_map]
- sem_map = self.get_vis_sem_clean(sem_map, sem_label)
- sem_map = sem_map.reshape(H, W, 3).astype('uint8')
- sem_maps.append(sem_map)
-
- depth_map, _ = visualize_depth_numpy(depth_map.numpy(), near_far)
- if len(dataset.all_rgbs):
- gt_rgb = dataset.all_rgbs[idxs[idx]].view(H, W, 3)
- loss = torch.mean((rgb_map - gt_rgb) ** 2)
- PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
-
- if compute_extra_metrics:
- ssim = rgb_ssim(rgb_map, gt_rgb, 1)
- l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', self.device)
- l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', self.device)
- ssims.append(ssim)
- l_alex.append(l_a)
- l_vgg.append(l_v)
-
- rgb_map = (rgb_map.numpy() * 255).astype('uint8')
- rgb_maps.append(rgb_map)
- depth_maps.append(depth_map)
-
- if savePath is not None:
- imageio.imwrite(f'{savePath}/rgb/{prtx}{idx:03d}.png', rgb_map)
- rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
- imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
- imageio.imwrite(f'{savePath}/sem/vis_sem_class_{idx:03d}.png', sem_map)
- imageio.imwrite(f'{savePath}/sem/semantic_class_{idx:03d}.png', sem_class)
- imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=10, quality=10)
- imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=10, quality=10)
-
- imageio.mimwrite(f'{savePath}/{prtx}semantic_video.mp4', np.stack(sem_maps), fps=10, quality=10)
-
- if PSNRs:
- psnr = np.mean(np.asarray(PSNRs))
- if compute_extra_metrics:
- ssim = np.mean(np.asarray(ssims))
- l_a = np.mean(np.asarray(l_alex))
- l_v = np.mean(np.asarray(l_vgg))
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
- else:
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
-
- return rgb_maps, depth_maps, sem_maps
-
- @torch.no_grad()
- def tps_rendering_eval_path(self, test_dataset, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
- white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda', tps_function=None):
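- # Novel-view rendering along the given c2w trajectory with the TPS
- # deformation; there is no ground truth here, so PSNRs stays empty.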
- logits_2_label = lambda x: torch.argmax(torch.nn.functional.softmax(x, dim=-1), dim=-1)
- sem_maps = []
- PSNRs, rgb_maps, depth_maps = [], [], []
- ssims, l_alex, l_vgg = [], [], []
- os.makedirs(savePath, exist_ok=True)
- os.makedirs(savePath + "/rgbd", exist_ok=True)
- os.makedirs(savePath + "/sem", exist_ok=True)
-
- try:
- tqdm._instances.clear()
- except Exception:
- pass
-
- near_far = test_dataset.near_far
- for idx, c2w in tqdm(enumerate(c2ws)):
-
- W, H = test_dataset.img_wh
-
- c2w = torch.FloatTensor(c2w)
- rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
- if ndc_ray:
- rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
- rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
-
- rgb_map, _, depth_map, _, _, sem_logit_map = renderer(rays, chunk=8192, N_samples=N_samples,
- ndc_ray=ndc_ray, white_bg=white_bg, device=device,
- tps_function=tps_function)
- rgb_map = rgb_map.clamp(0.0, 1.0)
-
- rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
-
- depth_map, _ = visualize_depth_numpy(depth_map.numpy(), near_far)
-
- rgb_map = (rgb_map.numpy() * 255).astype('uint8')
-
- sem_map = logits_2_label(sem_logit_map)
- sem_map = sem_map.cpu().numpy()
- sem_map_label = sem_map
- sem_map = self.label_colour_map[sem_map]
- sem_map = self.get_vis_sem_clean(sem_map, sem_map_label)
- sem_map = sem_map.reshape(H, W, 3).astype('uint8')
-
- sem_maps.append(sem_map)
- rgb_maps.append(rgb_map)
- depth_maps.append(depth_map)
- if savePath is not None:
- imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
- rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
- imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
- imageio.imwrite(f'{savePath}/sem/{prtx}{idx:03d}.png', sem_map)
-
- imageio.mimwrite(f'{savePath}/{prtx}semantic_video.mp4', np.stack(sem_maps), fps=10, quality=8)
- imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=10, quality=8)
- imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=10, quality=8)
-
- if PSNRs:
- psnr = np.mean(np.asarray(PSNRs))
- if compute_extra_metrics:
- ssim = np.mean(np.asarray(ssims))
- l_a = np.mean(np.asarray(l_alex))
- l_v = np.mean(np.asarray(l_vgg))
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
- else:
- np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
- return rgb_maps, depth_maps, sem_maps