|
- import os
-
- import torch
- import numpy as np
-
- from dataLoader import dataset_dict
- from models.tensoRF_SemVM import TensorVMSplitSemVM
- from opt_SemVM import config_parser
- from util.tps.defomer_utils import TPS_Deformer
-
# Torch device used for checkpoint loading and the TPS deformer: the (current)
# CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
def save_model(args, tensorf):
    """Save ``tensorf`` to ``<basedir>/<expname>/<expname>.th``.

    Creates the experiment folder if it does not exist yet, delegates the
    actual write to ``tensorf.save`` and prints a confirmation line.

    Args:
        args: parsed config; only ``basedir`` and ``expname`` are read.
        tensorf: model object exposing a ``save(path)`` method.
    """
    exp_dir = f'{args.basedir}/{args.expname}'
    ckpt_path = f'{exp_dir}/{args.expname}.th'
    os.makedirs(exp_dir, exist_ok=True)
    tensorf.save(ckpt_path)
    print("save model successfully")
-
-
def load_dataset(args):
    """Build the stacked train/test datasets for ``args.dataset_name``.

    Only ``'replica'`` is supported. For it, both splits are built with
    ``is_stack=True`` and ``use_semantic=args.use_semantic``. Then
    ``get_sem_class_nums`` is called on each split with the semantic samples
    of *both* splits and ``args.scene_file``.

    Args:
        args: parsed config; reads ``dataset_name``, ``datadir``,
            ``downsample_train``, ``use_semantic`` and ``scene_file``.

    Returns:
        ``(train_dataset, test_dataset)``.

    Raises:
        NotImplementedError: if ``args.dataset_name`` is not ``'replica'``.
            Previously such names crashed with an unbound-local / name error,
            because the datasets were only constructed in the replica branch.
    """
    if args.dataset_name != 'replica':
        raise NotImplementedError(
            f"load_dataset only supports dataset_name='replica', "
            f"got {args.dataset_name!r}"
        )

    dataset = dataset_dict[args.dataset_name]
    use_semantic = args.use_semantic

    # NOTE(review): the test split is also built with downsample_train;
    # confirm this is intended rather than a separate test downsample factor.
    train_dataset = dataset(args.datadir,
                            split='train',
                            downsample=args.downsample_train,
                            is_stack=True,
                            use_semantic=use_semantic)
    test_dataset = dataset(args.datadir,
                           split='test',
                           downsample=args.downsample_train,
                           is_stack=True,
                           use_semantic=use_semantic)

    # Each split sees the labels of both splits, presumably so that both agree
    # on the semantic class set. sem_samples is re-read for every call in case
    # get_sem_class_nums updates it.
    for split_dataset in (train_dataset, test_dataset):
        split_dataset.get_sem_class_nums(
            train_sem=train_dataset.sem_samples["semantic"],
            test_sem=test_dataset.sem_samples["semantic"],
            scene_file=args.scene_file,
        )

    return train_dataset, test_dataset
-
-
def load_model(args, dataset):
    """Load the checkpoint at ``args.ckpt`` and rebuild the model it describes.

    Steps:
        1. Read the checkpoint with ``torch.load``, mapped to the module-level
           ``device``.
        2. Take the constructor kwargs stored under ``ckpt['kwargs']`` and
           override their ``device`` entry.
        3. Fill in ``semantic_n_comp`` (from ``args.n_lamb_sem``) and
           ``sem_dim`` (``dataset.num_semantic_class - 1``) when the
           checkpoint does not provide them.
        4. Instantiate the class named by ``args.model_name`` and call its
           ``load(ckpt)``.

    Args:
        args: parsed config; reads ``ckpt``, ``n_lamb_sem`` and ``model_name``.
        dataset: only used for ``num_semantic_class`` when ``sem_dim`` is
            missing from the checkpoint kwargs.

    Returns:
        The instantiated model with the checkpoint loaded.

    Raises:
        ValueError: if ``args.model_name`` does not name a callable defined
            or imported in this module.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})

    # Defaults for checkpoints saved without these constructor arguments.
    if 'semantic_n_comp' not in kwargs:
        kwargs['semantic_n_comp'] = args.n_lamb_sem
    if 'sem_dim' not in kwargs:
        # NOTE(review): the -1 presumably drops one reserved/void class — confirm.
        kwargs['sem_dim'] = dataset.num_semantic_class - 1

    # Resolve the model class by name instead of eval()-ing a config string.
    model_cls = globals().get(args.model_name)
    if not callable(model_cls):
        raise ValueError(f"Unknown model_name {args.model_name!r}")

    tensorf = model_cls(**kwargs)
    tensorf.load(ckpt)
    return tensorf
-
def get_voxel_grid(tensorf):
    """Print ``tensorf.gridSize`` and return it unchanged.

    The grid size is used by the caller to build the TPS deformer.
    NOTE(review): an earlier note read "313 shape[0]", presumably an observed
    first-axis size for one scene; it has not been verified.
    """
    grid_size = tensorf.gridSize
    print(grid_size)
    return grid_size
-
-
def update_dataset(train_dataset, deformer):
    """Unfinished hook for applying ``deformer`` to ``train_dataset``'s rays.

    Currently it only reads ``train_dataset.all_rays``, so a dataset without
    that attribute raises AttributeError. ``deformer`` is unused, nothing is
    modified, and the function returns ``None``.
    """
    # TODO: deform the rays with `deformer`; nothing is written back yet.
    _rays = train_dataset.all_rays
    return None
-
-
-
if __name__ == '__main__':
    # Reproducible setup: float32 default dtype and fixed RNG seeds.
    torch.set_default_dtype(torch.float32)
    SEED = 48
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    args = config_parser()
    # The checkpoint path is always derived from <basedir>/<expname>,
    # overwriting whatever value config_parser put in args.ckpt.
    logfolder = f'{args.basedir}/{args.expname}'
    args.ckpt = os.path.join(logfolder, args.expname + '.th')

    train_dataset, test_dataset = load_dataset(args)
    print(args)
    tensorf = load_model(args, train_dataset)
    gridSize = get_voxel_grid(tensorf)

    renderer = tensorf.OctreeRender_trilinear_fast_sem
    # near_far is read here but not used below.
    white_bg, near_far, ndc_ray = (
        train_dataset.white_bg,
        train_dataset.near_far,
        args.ndc_ray,
    )

    if args.dataset_name == 'replica':
        tensorf.set_label_colour_map(args.scene_file)

    deformer = TPS_Deformer(gridSize, device)

    # Render every training view through the TPS deformation.
    print("========> train_dataset c2w")
    out_dir = f'{logfolder}/imgs_tps_train_all_with_reverse/'
    tensorf.tps_rendering_eval(
        train_dataset, args, renderer, out_dir,
        N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray,
        device=device, tps_function=deformer.get_deform,
    )
|