|
- from opt import config_parser
-
- from renderer import *
- from util.utils import *
- from torch.utils.tensorboard import SummaryWriter
- import datetime
- from models.tensoRF import TensorVMSplit
- from models.tensoRF_Sem import TensorVMSplitSem
- from dataLoader import dataset_dict
- import sys
-
-
-
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-wide default rendering function, referenced by render_test() and reconstruction().
renderer = OctreeRender_trilinear_fast
-
-
class SimpleSampler:
    """Hand out consecutive index batches from a periodically reshuffled permutation.

    A fresh permutation of ``range(total)`` is drawn (through NumPy's global
    RNG) whenever the next slice would run past ``total``; indices left over
    at the end of a permutation that do not fill a whole batch are skipped.
    """

    def __init__(self, total, batch):
        self.total, self.batch = total, batch
        # Starting the cursor at ``total`` makes the first nextids() call shuffle.
        self.curr = total
        self.ids = None

    def nextids(self):
        """Return the next slice of (at most ``batch``) indices as a LongTensor."""
        start = self.curr + self.batch
        if start + self.batch > self.total:
            # Not enough indices left for a full batch: reshuffle and restart.
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            start = 0
        self.curr = start
        return self.ids[start:start + self.batch]
-
-
@torch.no_grad()
def export_mesh(args):
    """Restore a model from ``args.ckpt`` and write its dense alpha grid as a ``.ply`` mesh.

    The mesh is written next to the checkpoint, using the checkpoint path with
    its file extension replaced by ``.ply``.

    Args:
        args: Parsed options. Reads ``args.ckpt`` (checkpoint path) and
            ``args.model_name`` (name of the model class to instantiate).
    """
    import os  # local import so the block does not rely on star-imports for ``os``

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # NOTE(security): eval() turns a command-line string into code. Only run this
    # with trusted ``model_name`` values (and trusted checkpoints, which are unpickled).
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    alpha, _ = tensorf.getDenseAlpha()
    # splitext() instead of a fixed ``[:-3]`` slice: works for any extension length
    # and leaves extension-less paths intact.
    mesh_path = f'{os.path.splitext(args.ckpt)[0]}.ply'
    convert_sdf_samples_to_ply(alpha.cpu(), mesh_path, bbox=tensorf.aabb.cpu(), level=0.005)
-
-
@torch.no_grad()
def render_test(args):
    """Render images from a saved checkpoint without training.

    Depending on ``args.render_train`` / ``args.render_test`` /
    ``args.render_path`` the train split, the test split and/or the dataset's
    render path are evaluated; images are written below the checkpoint's
    directory.

    Args:
        args: Parsed options. ``args.ckpt`` must point to an existing
            checkpoint; otherwise a message is printed and nothing is rendered.

    Returns:
        None.
    """
    # Validate the checkpoint before the (potentially slow) dataset load. The
    # explicit None test avoids the TypeError os.path.exists(None) would raise.
    if args.ckpt is None or not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    # NOTE(security): eval() turns a command-line string into code. Only run this
    # with trusted ``model_name`` values (and trusted checkpoints, which are unpickled).
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        PSNRs_train = evaluation(train_dataset, tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_train)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
                   N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset, tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
                        N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
-
def reconstruction(args):
    """Train a model on ``args.dataset_name`` and optionally render the results.

    The routine builds the train/test datasets, creates the model named by
    ``args.model_name`` (or restores it from ``args.ckpt``), optimises it with
    Adam for ``args.n_iters`` iterations, saves it to
    ``<logfolder>/<expname>.th`` and finally runs whichever of
    ``args.render_train`` / ``args.render_test`` / ``args.render_path`` is set.
    Scalars are logged to TensorBoard in the log folder.

    Args:
        args: Parsed options as produced by ``config_parser()``.
    """
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    # init resolution
    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list
    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    # init parameters
    aabb = train_dataset.scene_bbox.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))

    # NOTE(security): eval() turns a command-line string into code. Only run this
    # with trusted ``model_name`` values (and trusted checkpoints, which are unpickled).
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device': device})
        tensorf = eval(args.model_name)(**kwargs)
        tensorf.load(ckpt)
    else:
        tensorf = eval(args.model_name)(
            aabb, reso_cur, device,
            density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh,
            app_dim=args.data_dim_color, near_far=near_far,
            shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre,
            density_shift=args.density_shift, distance_scale=args.distance_scale,
            pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
            featureC=args.featureC, step_ratio=args.step_ratio,
            fea2denseAct=args.fea2denseAct)

    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    # Per-iteration multiplicative decay so that the learning rate reaches
    # lr_decay_target_ratio * initial after lr_decay_iters steps.
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    # Voxel counts for each upsampling step, spaced linearly in log space
    # between N_voxel_init and N_voxel_final (the initial value is dropped).
    N_voxel_list = (torch.round(torch.exp(torch.linspace(
        np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]

    torch.cuda.empty_cache()
    PSNRs, PSNRs_test = [], [0]

    allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
    if not args.ndc_ray:
        allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")

    # Defined up front so the first alpha-mask update cannot hit an unbound name
    # when the grid is already at or above 256**3 voxels.
    reso_mask = reso_cur
    # Defined up front so the post-training logging below works even if the loop body never runs.
    iteration = 0

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:

        ray_idx = trainingSampler.nextids()
        rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)

        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(
            rays_train, tensorf, chunk=args.batch_size,
            N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, device=device, is_train=True)

        loss = torch.mean((rgb_map - rgb_train) ** 2)

        # Accumulate the regularisers out-of-place. ``total_loss += ...`` would modify
        # the very tensor ``loss`` refers to and leak the penalties into the logged MSE/PSNR.
        total_loss = loss
        if Ortho_reg_weight > 0:
            loss_reg = tensorf.vector_comp_diffs()
            total_loss = total_loss + Ortho_reg_weight*loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
        if L1_reg_weight > 0:
            loss_reg_L1 = tensorf.density_L1()
            total_loss = total_loss + L1_reg_weight*loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        if TV_weight_density > 0:
            TV_weight_density *= lr_factor
            loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_app > 0:
            TV_weight_app *= lr_factor
            loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss = loss.detach().item()

        PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', loss, global_step=iteration)

        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
                + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
                + f' mse = {loss:.6f}'
            )
            PSNRs = []

        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
            PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_vis/',
                                    N_vis=args.N_vis, prtx=f'{iteration:06d}_', N_samples=nSamples,
                                    white_bg=white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)

        if iteration in update_AlphaMask_list:

            # Only follow the grid resolution while it stays below 256**3 voxels.
            if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256**3:
                reso_mask = reso_cur
            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                tensorf.shrink(new_aabb)
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)

            # The length check keeps a single-entry schedule from raising IndexError.
            if (not args.ndc_ray and len(update_AlphaMask_list) > 1
                    and iteration == update_AlphaMask_list[1]):
                # filter rays outside the bbox
                allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs)
                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)

        if iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
            tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            # Parameters are re-collected after upsampling, so the optimizer is rebuilt.
            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    tensorf.save(f'{logfolder}/{args.expname}.th')

    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        PSNRs_train = evaluation(train_dataset, tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
                                 N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_train)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
                                N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_path:
        c2ws = test_dataset.render_path
        print('========>', c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset, tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
                        N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
-
-
def semantic_reconstruction(args):
    """Train a model jointly on colour (MSE) and semantic labels (cross-entropy).

    The flow mirrors ``reconstruction``: build datasets, create or restore the
    model named by ``args.model_name``, optimise for ``args.n_iters``
    iterations, save to ``<logfolder>/<expname>.th`` and run the renders
    selected by ``args.render_train`` / ``args.render_test`` /
    ``args.render_path``. Rendering and evaluation go through methods of the
    model itself (``OctreeRender_trilinear_fast_sem``, ``evaluation``,
    ``evaluation_path``).

    Args:
        args: Parsed options as produced by ``config_parser()``.
    """
    # init dataset
    dataset = dataset_dict[args.dataset_name]

    if args.dataset_name == 'replica':
        use_semantic = args.use_semantic
        train_dataset = dataset(args.datadir,
                                split='train',
                                downsample=args.downsample_train,
                                is_stack=False,
                                use_semantic=use_semantic)
        test_dataset = dataset(args.datadir,
                               split='test',
                               downsample=args.downsample_train,
                               is_stack=True,
                               use_semantic=use_semantic)
        # Both splits are given the train and test label sets together.
        train_dataset.get_sem_class_nums(
            train_sem=train_dataset.sem_samples["semantic"],
            test_sem=test_dataset.sem_samples["semantic"],
            scene_file=args.scene_file
        )
        test_dataset.get_sem_class_nums(
            train_sem=train_dataset.sem_samples["semantic"],
            test_sem=test_dataset.sem_samples["semantic"],
            scene_file=args.scene_file
        )
    else:
        # NOTE(review): this branch builds the datasets without semantics, yet the code
        # below reads ``num_semantic_class`` and ``sem_samples`` — confirm the dataset provides them.
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    # init resolution
    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list
    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/semantic', exist_ok=True)

    summary_writer = SummaryWriter(logfolder)

    # init parameters
    aabb = train_dataset.scene_bbox.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))

    # One class fewer than the dataset reports; the loss below shifts labels by -1
    # and ignores the resulting -1.
    num_valid_semantic_class = train_dataset.num_semantic_class - 1
    # NOTE(security): eval() turns a command-line string into code. Only run this
    # with trusted ``model_name`` values (and trusted checkpoints, which are unpickled).
    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device': device})
        tensorf = eval(args.model_name)(**kwargs)
        tensorf.load(ckpt)
    else:
        tensorf = eval(args.model_name)(
            aabb, reso_cur, device,
            density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh,
            app_dim=args.data_dim_color, near_far=near_far,
            shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre,
            density_shift=args.density_shift, distance_scale=args.distance_scale,
            pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe,
            featureC=args.featureC, step_ratio=args.step_ratio,
            fea2denseAct=args.fea2denseAct,
            sem_dim=num_valid_semantic_class,
        )
    # Local renderer: the model's own semantic-aware method, used for training and evaluation below.
    renderer = tensorf.OctreeRender_trilinear_fast_sem
    if args.dataset_name == 'replica':
        tensorf.set_label_colour_map(args.scene_file)

    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    # Per-iteration multiplicative decay so that the learning rate reaches
    # lr_decay_target_ratio * initial after lr_decay_iters steps.
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio ** (1 / args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio ** (1 / args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    # Voxel counts for each upsampling step, spaced linearly in log space
    # between N_voxel_init and N_voxel_final (the initial value is dropped).
    N_voxel_list = (torch.round(torch.exp(torch.linspace(
        np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list) + 1))).long()).tolist()[1:]

    torch.cuda.empty_cache()
    PSNRs, PSNRs_test = [], [0]

    allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
    allsems = train_dataset.sem_samples["semantic_remap"]
    # One label per ray, kept as a column so it can be indexed alongside the rays.
    allsems = allsems.reshape(allrgbs.shape[0], 1)
    allsems = torch.tensor(allsems)
    if not args.ndc_ray:
        allrays, allrgbs, allsems = tensorf.filtering_rays(allrays, allrgbs, allsems, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} "
          f"appearance: {TV_weight_app} ")

    # Cross-entropy for semantics: labels are shifted down by one, so label 0
    # becomes -1 and is skipped through ``ignore_index``.
    ignore_label = -1
    CrossEntropyLoss = nn.CrossEntropyLoss(ignore_index=ignore_label)
    crossentropy_loss = lambda logit, label: CrossEntropyLoss(logit, label - 1)
    # Fixed weight of the semantic term relative to the colour MSE (loop-invariant).
    weights_sem = 4e-2

    # Defined up front so the first alpha-mask update cannot hit an unbound name
    # when the grid is already at or above 256**3 voxels.
    reso_mask = reso_cur
    # Defined up front so the post-training logging below works even if the loop body never runs.
    iteration = 0

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:

        ray_idx = trainingSampler.nextids()
        rays_train = allrays[ray_idx]
        rgb_train = allrgbs[ray_idx].to(device)
        sem_train = allsems[ray_idx].to(device)

        rgb_map, alphas_map, depth_map, weights, uncertainty, sem_map = renderer(
            rays_train, chunk=args.batch_size,
            N_samples=nSamples, white_bg=white_bg,
            ndc_ray=ndc_ray, device=device, is_train=True)

        loss_rgb_mse = torch.mean((rgb_map - rgb_train) ** 2)

        # squeeze(-1) drops only the label column; a bare squeeze() would also
        # collapse the batch axis for a single-ray batch.
        loss_sem_CE = crossentropy_loss(sem_map, sem_train.squeeze(-1).long()) * weights_sem

        total_loss = loss_rgb_mse + loss_sem_CE
        if Ortho_reg_weight > 0:
            loss_reg = tensorf.vector_comp_diffs()
            total_loss = total_loss + Ortho_reg_weight * loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
        if L1_reg_weight > 0:
            loss_reg_L1 = tensorf.density_L1()
            total_loss = total_loss + L1_reg_weight * loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        if TV_weight_density > 0:
            TV_weight_density *= lr_factor
            loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_app > 0:
            TV_weight_app *= lr_factor
            loss_tv = tensorf.TV_loss_app(tvreg) * TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        loss_sem_CE = loss_sem_CE.detach().item()
        loss_rgb_mse = loss_rgb_mse.detach().item()

        PSNRs.append(-10.0 * np.log(loss_rgb_mse) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', loss_rgb_mse, global_step=iteration)
        summary_writer.add_scalar('train/CE', loss_sem_CE, global_step=iteration)
        # todo: sem metrics?

        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
                + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
                + f' mse = {loss_rgb_mse:.6f}'
                + f' CE = {loss_sem_CE:.6f}'
            )
            PSNRs = []
            # todo: sem metrics?

        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis != 0:
            PSNRs_test = tensorf.evaluation(test_dataset, args, renderer, f'{logfolder}/imgs_vis/',
                                            N_vis=args.N_vis, prtx=f'{iteration:06d}_', N_samples=nSamples,
                                            white_bg=white_bg, ndc_ray=ndc_ray,
                                            compute_extra_metrics=False)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)

        if iteration in update_AlphaMask_list:

            # Only follow the grid resolution while it stays below 256**3 voxels.
            if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256 ** 3:
                reso_mask = reso_cur
            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                tensorf.shrink(new_aabb)
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)

            # The length check keeps a single-entry schedule from raising IndexError.
            if (not args.ndc_ray and len(update_AlphaMask_list) > 1
                    and iteration == update_AlphaMask_list[1]):
                # filter rays outside the bbox
                allrays, allrgbs, allsems = tensorf.filtering_rays(allrays, allrgbs, allsems)
                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)

        if iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur, args.step_ratio))
            tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            # Parameters are re-collected after upsampling, so the optimizer is rebuilt.
            grad_vars = tensorf.get_optparam_groups(args.lr_init * lr_scale, args.lr_basis * lr_scale)
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    tensorf.save(f'{logfolder}/{args.expname}.th')

    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir,
                                split='train',
                                downsample=args.downsample_train,
                                is_stack=True,
                                use_semantic=args.use_semantic)
        PSNRs_train = tensorf.evaluation(train_dataset, args, renderer, f'{logfolder}/imgs_train_all/',
                                         N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray,
                                         device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_train)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = tensorf.evaluation(test_dataset, args, renderer, f'{logfolder}/imgs_test_all/',
                                        N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray,
                                        device=device)
        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_path:
        c2ws = test_dataset.render_path
        print('========>', c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        tensorf.evaluation_path(test_dataset, c2ws, renderer, f'{logfolder}/imgs_path_all/',
                                N_vis=-1, N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=device)
-
-
if __name__ == '__main__':

    # Fixed dtype and seeds so that runs are repeatable.
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.export_mesh:
        export_mesh(args)

    # render_test() also handles ``render_train``, so a render-only run that asks
    # only for train-set renders must be routed there instead of into training.
    render_requested = args.render_train or args.render_test or args.render_path
    if args.render_only and render_requested:
        render_test(args)
    elif args.use_semantic:
        semantic_reconstruction(args)
    else:
        reconstruction(args)
|