|
- from opt_SemVM import config_parser
-
- from renderer import *
- from util.tps.defomer_utils import TPS_Deformer
- from util.utils import *
- from torch.utils.tensorboard import SummaryWriter
- import datetime
- from models.tensoRF import TensorVMSplit
- from models.tensoRF_Sem import TensorVMSplitSem
- from models.tensoRF_SemVM import TensorVMSplitSemVM
- from dataLoader import dataset_dict
- import sys
-
-
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-
class SimpleSampler:
    """Serves shuffled index batches over a pool of ``total`` items.

    Indices are drawn from a random permutation; once the permutation is
    nearly exhausted, a fresh one is generated (via ``np.random``, so seeding
    numpy makes the sampling deterministic).
    """

    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        # Start the cursor past the end so the very first nextids() call
        # triggers a shuffle.
        self.curr = total
        self.ids = None

    def nextids(self):
        """Return the next ``batch`` shuffled indices as a LongTensor."""
        self.curr += self.batch
        # Reshuffle when the *next* window would overrun the pool.
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        window = slice(self.curr, self.curr + self.batch)
        return self.ids[window]
-
-
class Mesh():
    """Wraps a TensoRF-style model to extract a dense alpha (opacity) grid,
    optionally warping query points with a TPS deformation first."""

    def __init__(self, tensorf):
        self.tensorf = tensorf

    def compute_alpha(self, xyz_locs, length=1, tps_function=None):
        """Evaluate alpha = 1 - exp(-sigma * length) at world-space points.

        xyz_locs: (N, 3) world coordinates — presumably inside the model's
        AABB; TODO confirm against callers. If ``tps_function`` is given, the
        points are deformed in normalized space, then mapped back to world
        space before density lookup.
        """
        model = self.tensorf
        if tps_function:
            # Deform in normalized coordinate space, then undo normalization.
            warped = tps_function(model.normalize_coord(xyz_locs))
            # warped = warped.clamp(-1, 1)
            xyz_locs = model.normalize_coord_reverse(warped)

        if model.alphaMask is not None:
            # Cheap pre-filter: only query density where the cached alpha
            # mask reports occupancy.
            alpha_mask = model.alphaMask.sample_alpha(xyz_locs) > 0
        else:
            alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool)

        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
        if alpha_mask.any():
            xyz_sampled = model.normalize_coord(xyz_locs[alpha_mask])
            # todo: tps ?
            density_feat = model.compute_densityfeature(xyz_sampled)
            sigma[alpha_mask] = model.feature2density(density_feat)

        return 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1])

    @torch.no_grad()
    def getDenseAlpha(self, tps_function=None, gridSize=None):
        """Sample alpha on a dense grid spanning the model's AABB.

        Returns (alpha, dense_xyz): ``alpha`` has shape ``gridSize`` and
        ``dense_xyz`` holds the matching world-space sample positions.
        """
        model = self.tensorf
        if gridSize is None:
            gridSize = model.gridSize

        axes = [torch.linspace(0, 1, gridSize[d]) for d in range(3)]
        samples = torch.stack(torch.meshgrid(*axes), -1).to(model.device)
        # Lerp the unit cube onto the model's axis-aligned bounding box.
        dense_xyz = model.aabb[0] * (1 - samples) + model.aabb[1] * samples

        alpha = torch.zeros_like(dense_xyz[..., 0])
        # Process one YZ slab at a time to bound peak memory.
        for ix in range(gridSize[0]):
            slab = dense_xyz[ix].view(-1, 3)
            alpha[ix] = self.compute_alpha(
                slab, model.stepSize, tps_function
            ).view((gridSize[1], gridSize[2]))
        return alpha, dense_xyz
-
@torch.no_grad()
def export_mesh(args, level=0.005):
    """Load a checkpointed model and export its iso-surface as a PLY mesh.

    ``args`` must provide ``ckpt`` (path to a ``.th`` checkpoint) and
    ``model_name`` (class name of the TensoRF variant to instantiate).
    ``level`` is the iso-value handed to marching cubes.
    """
    checkpoint = torch.load(args.ckpt, map_location=device)
    model_kwargs = checkpoint['kwargs']
    model_kwargs.update({'device': device})
    # NOTE(review): eval() on a config-supplied class name is only safe for
    # trusted configs — consider an explicit name -> class mapping.
    tensorf = eval(args.model_name)(**model_kwargs)
    tensorf.load(checkpoint)

    extractor = Mesh(tensorf)
    deformer = TPS_Deformer(tensorf.gridSize, device)
    alpha, _ = extractor.getDenseAlpha(deformer.get_deform)
    # alpha,_ = tensorf.getDenseAlpha()
    # Mesh path mirrors the checkpoint path, with '.th' swapped for '.ply'.
    ply_path = args.ckpt[:-3] + '.ply'
    convert_sdf_samples_to_ply(alpha.cpu(), ply_path, bbox=tensorf.aabb.cpu(), level=level)
-
-
@torch.no_grad()
def render_test(args):
    """Load a semantic TensoRF checkpoint and render train/test/path views.

    Expects ``args`` to provide: dataset_name, datadir, downsample_train,
    use_semantic, scene_file, ckpt, model_name, n_lamb_sem, ndc_ray, expname,
    and the render_train / render_test / render_path flags. Renders are
    written next to the checkpoint under ``imgs_*_all/`` folders.
    """
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    # NOTE(review): train_dataset/test_dataset are only bound inside this
    # 'replica' branch but are used unconditionally below — any other
    # dataset_name would raise NameError. Confirm this script is replica-only.
    if args.dataset_name == 'replica':
        use_semantic = args.use_semantic
        train_dataset = dataset(args.datadir,
                                split='train',
                                downsample=args.downsample_train,
                                is_stack=False,
                                use_semantic=use_semantic)
        # NOTE(review): test split reuses downsample_train — confirm intended.
        test_dataset = dataset(args.datadir,
                               split='test',
                               downsample=args.downsample_train,
                               is_stack=True,
                               use_semantic=use_semantic)
        # Both splits derive the semantic class set from the union of train
        # and test labels plus the scene file, so their label maps agree.
        train_dataset.get_sem_class_nums(
            train_sem=train_dataset.sem_samples["semantic"],
            test_sem=test_dataset.sem_samples["semantic"],
            scene_file=args.scene_file
        )
        test_dataset.get_sem_class_nums(
            train_sem=train_dataset.sem_samples["semantic"],
            test_sem=test_dataset.sem_samples["semantic"],
            scene_file=args.scene_file
        )

    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    # Rebuild the model from the kwargs stored in the checkpoint, forcing the
    # current device.
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # Semantic head width excludes one class — presumably the background /
    # void label; verify against the dataset's class numbering.
    num_valid_semantic_class = train_dataset.num_semantic_class - 1
    n_lamb_sem = args.n_lamb_sem
    # NOTE(review): eval() on a config-supplied class name — trusted configs only.
    tensorf = eval(args.model_name)(
        sem_dim=num_valid_semantic_class,
        semantic_n_comp=n_lamb_sem,
        **kwargs)
    tensorf.load(ckpt)
    renderer = tensorf.OctreeRender_trilinear_fast_sem
    if args.dataset_name == 'replica':
        tensorf.set_label_colour_map(args.scene_file)

    # All outputs go next to the checkpoint file.
    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        PSNRs_test = tensorf.evaluation(train_dataset, args, renderer, f'{logfolder}/imgs_train_all/',
                                        N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        tensorf.evaluation(test_dataset, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
                           N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)

    if args.render_path:
        # Render along the dataset's predefined camera path.
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        tensorf.evaluation_path(test_dataset, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
                                N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
-
if __name__ == '__main__':
    # Reproducibility: fix the default float dtype and seed both RNGs.
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    # This script always resolves the checkpoint to
    # <basedir>/<expname>/<expname>.th regardless of any --ckpt flag.
    logfolder = f'{args.basedir}/{args.expname}'
    args.ckpt = os.path.join(logfolder, args.expname + '.th')
    print(args)

    level = 0.001
    if args.export_mesh:
        export_mesh(args, level)
    elif args.render_test or args.render_path:
        render_test(args)
-
|