|
- from opt import config_parser
-
- from renderer import *
- from util.utils import *
- from torch.utils.tensorboard import SummaryWriter
- import datetime
- from models.tensoRF import TensorVMSplit
- from models.tensoRF_Sem import TensorVMSplitSem
- from dataLoader import dataset_dict
- import sys
-
-
-
# Module-wide compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-
-
class SimpleSampler:
    """Serve consecutive slices of a random permutation of ``range(total)``.

    A new permutation is drawn with ``np.random.permutation`` whenever, after
    advancing by ``batch``, a full window of ``batch`` indices would run past
    ``total``. Any leftover tail of the previous permutation is therefore
    skipped instead of being returned as a short batch.
    """

    def __init__(self, total, batch):
        self.ids = None
        self.batch = batch
        self.total = total
        # Begin at the end of the range so that (for a positive batch) the
        # first nextids() call immediately draws a permutation.
        self.curr = total

    def nextids(self):
        """Advance by one batch and return the matching slice of ``self.ids``."""
        self.curr = self.curr + self.batch
        window_overruns = self.total < self.curr + self.batch
        if window_overruns:
            shuffled = np.random.permutation(self.total)
            self.ids = torch.LongTensor(shuffled)
            self.curr = 0
        begin = self.curr
        return self.ids[begin:begin + self.batch]
-
-
@torch.no_grad()
def export_mesh(args):
    """Load a checkpoint and write a ``.ply`` file derived from its alpha grid.

    The output path is the checkpoint path with its file extension replaced
    by ``.ply``.

    Args:
        args: parsed options. Reads ``args.ckpt`` (checkpoint path) and
            ``args.model_name`` (name of the model class to instantiate).
    """
    # Local import so this function does not depend on `os` leaking in
    # through the file's wildcard imports.
    import os

    ckpt = torch.load(args.ckpt, map_location=device)
    model_kwargs = ckpt['kwargs']
    model_kwargs.update({'device': device})
    # NOTE(review): eval() on a command-line value executes arbitrary code.
    # Kept because the set of valid model classes is not visible from here;
    # only run this with trusted arguments.
    tensorf = eval(args.model_name)(**model_kwargs)
    tensorf.load(ckpt)

    alpha, _ = tensorf.getDenseAlpha()
    # Strip the actual extension instead of blindly dropping three characters:
    # "x.th" still becomes "x.ply", while "x.pth" no longer becomes "x..ply".
    mesh_path = f'{os.path.splitext(args.ckpt)[0]}.ply'
    convert_sdf_samples_to_ply(alpha.cpu(), mesh_path, bbox=tensorf.aabb.cpu(), level=0.005)
-
-
@torch.no_grad()
def render_test(args):
    """Render train / test / path views from a saved semantic checkpoint.

    Only the ``'replica'`` dataset is supported: the function relies on the
    semantic samples and class counts that are set up for that dataset.

    Args:
        args: parsed options. Reads ``dataset_name``, ``ckpt``, ``datadir``,
            ``downsample_train``, ``use_semantic``, ``scene_file``,
            ``ndc_ray``, ``model_name``, ``expname`` and the
            ``render_train`` / ``render_test`` / ``render_path`` switches.

    Returns:
        None. Prints a message and returns early when ``args.ckpt`` does not
        exist.

    Raises:
        NotImplementedError: if ``args.dataset_name`` is not ``'replica'``.
    """
    # Previously any other dataset fell through to an unbound `test_dataset`
    # and died with a NameError; fail with an explicit message instead.
    if args.dataset_name != 'replica':
        raise NotImplementedError(
            f"render_test only supports the 'replica' dataset, got {args.dataset_name!r}"
        )

    # Check the checkpoint before the (potentially slow) dataset loading so a
    # wrong path is reported immediately.
    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    use_semantic = args.use_semantic
    train_dataset = dataset(args.datadir,
                            split='train',
                            downsample=args.downsample_train,
                            is_stack=False,
                            use_semantic=use_semantic)
    test_dataset = dataset(args.datadir,
                           split='test',
                           downsample=args.downsample_train,
                           is_stack=True,
                           use_semantic=use_semantic)
    # Both splits receive the train and test semantic samples. The arguments
    # are re-read on every iteration, exactly as the two original calls did,
    # in case the first call rebinds `sem_samples`.
    for split_dataset in (train_dataset, test_dataset):
        split_dataset.get_sem_class_nums(
            train_sem=train_dataset.sem_samples["semantic"],
            test_sem=test_dataset.sem_samples["semantic"],
            scene_file=args.scene_file
        )

    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # NOTE(review): the "- 1" presumably drops a void/ignore label - confirm
    # against the dataset's get_sem_class_nums implementation.
    num_valid_semantic_class = train_dataset.num_semantic_class - 1
    # NOTE(review): eval() on a command-line value executes arbitrary code;
    # only run this with trusted arguments.
    tensorf = eval(args.model_name)(
        sem_dim=num_valid_semantic_class,
        **kwargs)
    tensorf.load(ckpt)
    renderer = tensorf.OctreeRender_trilinear_fast_sem
    tensorf.set_label_colour_map(args.scene_file)

    logfolder = os.path.dirname(args.ckpt)
    # Keyword arguments shared by every evaluation call below.
    eval_kwargs = dict(N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)

    if args.render_train:
        # NOTE(review): unlike the test/path outputs, this folder is not
        # nested under args.expname; left as-is to keep existing output paths.
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        PSNRs_test = tensorf.evaluation(train_dataset, args, renderer, f'{logfolder}/imgs_train_all/',
                                        **eval_kwargs)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        tensorf.evaluation(test_dataset, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
                           **eval_kwargs)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        tensorf.evaluation_path(test_dataset, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
                                **eval_kwargs)
-
if __name__ == '__main__':

    # Fixed default dtype and seeds for torch and numpy so runs are repeatable.
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.export_mesh:
        export_mesh(args)
    # render_test() has its own args.render_train branch, so that flag has to
    # trigger the call as well; previously --render_train alone did nothing.
    if args.render_train or args.render_test or args.render_path:
        render_test(args)
-
|