import os
import pickle

import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from util.ray_utils import *
from util.render_util import *


class ReplicaDatasetDMNeRF(Dataset):
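    """Replica scene loader for DM-NeRF-style training.

    Reads the RGB frames, per-pixel instance-segmentation maps and camera
    trajectory (traj_w_c.txt) of one Replica sequence, precomputes a ray for
    every pixel, and derives an adaptive scene bbox from the camera origins.
    """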
    def __init__(self, datadir, split='train',
                 near=0.1, far=50,
                 scene_bbox_stretch=5.5,
                 downsample=1.0, is_stack=False, N_vis=-1,
                 train_gen=False,
                 img_total_num=900,
                 sem_info_path=None):
        self.N_vis = N_vis

        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.train_gen = train_gen
        # must be set before read_meta() so the resize branches there see the real factor
        self.downsample = downsample

        img_w, img_h = 640, 480
        self.img_wh = (int(img_w / downsample), int(img_h / downsample))
        self.define_transforms()  # tensor transforms

        self.img_total_num = img_total_num

        # Replica near/far bounds; used in sample_ray (tensorBase.py) to clip samples along each ray
        self.near_far = [near, far]
        self.scene_bbox_stretch = scene_bbox_stretch
        self.sem_info_path = sem_info_path
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True

    def read_meta(self):
        w, h = self.img_wh

        hfov = 90
        # pinhole intrinsics: focal length in pixels from the horizontal field of view
        self.focal_x = 0.5 * w / np.tan(0.5 * np.radians(hfov))
        self.focal_y = self.focal_x
        cx = (w - 1.) / 2
        cy = (h - 1.) / 2
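        # e.g. with hfov = 90 and w = 640: tan(45 deg) = 1, so focal_x = 640 / 2 = 320 px;
        # the principal point (cx, cy) is the pixel-grid centre, (319.5, 239.5) at full resolution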

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal_x, self.focal_y])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal_x, 0, cx], [0, self.focal_y, cy], [0, 0, 1]]).float()

        # load camera-to-world poses for all frames: one flattened 4x4 matrix per line
        traj_file = os.path.join(self.root_dir, "traj_w_c.txt")
        self.Ts_full = np.loadtxt(traj_file, delimiter=" ").reshape(-1, 4, 4)

        self.image_paths = []
        self.sem_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_sems = []
        self.all_masks = []
        self.all_depth = []
        # maps between raw Replica instance ids and the compact labels used for training
        self.sem_samples = {}
        self.sem_samples["sem_img"] = []
        self.sem_samples["label_ins_map"] = {}
        self.sem_samples["ins_label_map"] = {}

        if self.train_gen:
            self.idxs = list(range(0, self.img_total_num))
        else:
            img_eval_interval = 5
            if self.split == "train":
                self.idxs = list(range(0, self.img_total_num, img_eval_interval))
            elif self.split == "test":
                self.idxs = list(range(img_eval_interval // 2, self.img_total_num, img_eval_interval))
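        # with the default 900 frames this selects train frames 0, 5, 10, ... and test
        # frames 2, 7, 12, ..., so the two splits interleave without overlapping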

        for i in tqdm(self.idxs, desc=f'Loading data {self.split} ({len(self.idxs)})'):
            c2w = torch.FloatTensor(self.Ts_full[i])
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, "rgb", f"rgb_{i}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)
            img = img.view(3, -1).permute(1, 0)  # (h*w, 3)
            self.all_rgbs += [img]

            sem_image_path = os.path.join(self.root_dir, 'semantic_instance', f"semantic_instance_{i}.png")
            self.sem_paths += [sem_image_path]
            sem_img = Image.open(sem_image_path)  # type: Image
            self.sem_samples["sem_img"].append(np.array(sem_img))

            if self.downsample != 1.0:
                # nearest-neighbour keeps label ids intact; LANCZOS would blend them into invalid values
                sem_img = sem_img.resize(self.img_wh, Image.NEAREST)
            sem_img = self.transform(sem_img)
            sem_img = sem_img.view(-1, h * w).permute(1, 0)  # (h*w, 1)
            self.all_sems += [sem_img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        # adaptive scene bbox (aabb)
        all_rays_o = torch.stack(self.all_rays)[..., :3]  # origins for all images, (N_imgs, h*w, 3)
        all_rays_o = all_rays_o.reshape(-1, 3)

        scene_min = torch.min(all_rays_o, 0)[0] - self.scene_bbox_stretch
        scene_max = torch.max(all_rays_o, 0)[0] + self.scene_bbox_stretch

        self.scene_bbox = torch.stack([scene_min, scene_max]).reshape(-1, 3)
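        # the box encloses every camera origin, padded by scene_bbox_stretch world units per
        # side; train.py consumes it as the model's aabb, so the stretch must cover the scene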

        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

        self.poses = torch.stack(self.poses)

        if not self.is_stack:
            # flatten across images for pixel-wise ray sampling during training
            self.all_rays = torch.cat(self.all_rays, 0)  # (num_imgs*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (num_imgs*h*w, 3)
            self.all_sems = torch.cat(self.all_sems, 0)  # (num_imgs*h*w, 1)
        else:
            # keep frames separate for image-wise evaluation
            self.all_rays = torch.stack(self.all_rays, 0)  # (num_imgs, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (num_imgs, h, w, 3)
            self.all_sems = torch.stack(self.all_sems, 0).reshape(-1, *self.img_wh[::-1], 1)  # (num_imgs, h, w, 1)

        # camera path for rendering images from new viewpoints
        center = torch.mean(self.scene_bbox, dim=0)
        radius = torch.norm(self.scene_bbox[1] - center) * 0.1
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up, frames=200)
        self.render_path[:, :3, 3] += center
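        # circle() and gen_path() come in via the util.* wildcard imports; the 200-frame orbit
        # uses 10% of the bbox radius around the y-axis and is shifted onto the scene centre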

    def save_c2ws(self, save_path):
        # one flattened 4x4 camera-to-world matrix per line; 'ab' appends on re-runs
        os.makedirs(save_path, exist_ok=True)
        with open(f'{save_path}/{self.split}_traj_w_c.txt', 'ab') as f:
            for i in self.idxs:
                np.savetxt(f, self.Ts_full[i].reshape(1, -1), delimiter=" ")

    def remap_sem_gt_label(self, train_sem_imgs, test_sem_imgs, sem_info_path,
                           save_map=False, init_semantic_class=True):
        if init_semantic_class:
            # every instance id that appears in either split (the cast assumes ids fit in uint8)
            self.semantic_classes = np.unique(
                np.concatenate((np.unique(train_sem_imgs), np.unique(test_sem_imgs))).astype(np.uint8))
            self.num_semantic_class = self.semantic_classes.shape[0]
            self.num_valid_semantic_class = self.num_semantic_class

            self.sem_samples["sem_img"] = np.asarray(self.sem_samples["sem_img"])
            self.sem_samples["sem_remap"] = self.sem_samples["sem_img"].copy()

            # remap raw instance ids to a compact [0, num_classes) range, keeping both directions
            for i in range(self.num_semantic_class):
                self.sem_samples["sem_remap"][self.sem_samples["sem_img"] == self.semantic_classes[i]] = i
                self.sem_samples["label_ins_map"][i] = self.semantic_classes[i]
                self.sem_samples["ins_label_map"][self.semantic_classes[i]] = i

        if save_map:
            with open(os.path.join(self.root_dir, "semantic_instance", "label2ins_map.pkl"), "wb") as f:
                pickle.dump(self.sem_samples["label_ins_map"], f)

            with open(os.path.join(self.root_dir, "semantic_instance", "ins2label_map.pkl"), "wb") as f:
                pickle.dump(self.sem_samples["ins_label_map"], f)

    def set_label_colour_map(self, sem_info_path=None):
        if not sem_info_path:
            sem_info_path = self.sem_info_path
        color_f = os.path.join(sem_info_path, 'ins_rgb.hdf5')
        with h5py.File(color_f, 'r') as f:
            ins_rgbs = f['datasets'][:]  # (num_instances, 3) RGB colour per instance id

        def label_colour_map(sem_map):
            # colour a flattened (h*w,) semantic map; unknown labels stay black
            color_map = np.zeros(shape=(int(self.img_wh[0] * self.img_wh[1]), 3))
            valid_labels = set(range(ins_rgbs.shape[0]))
            for label in np.unique(sem_map):
                if label in valid_labels:
                    color_map[sem_map == label] = ins_rgbs[label]
            return color_map

        self.label_colour_map = label_colour_map

    def inv_map_sem_gt_label(self,
                             sem_map  # (H*W,)
                             ):
        # todo: check before tps
        # map compact training labels back to the original Replica instance ids
        gt_sem_map = np.zeros_like(sem_map)
        for remap_sem_value in self.sem_samples["label_ins_map"].keys():
            if remap_sem_value in sem_map:
                gt_sem_map[sem_map == remap_sem_value] = self.sem_samples["label_ins_map"][remap_sem_value]

        return gt_sem_map

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # per-image world-to-pixel projection: K @ (c2w)^-1, keeping the top 3 rows
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    def world2ndc(self, points, lindisp=None):
        # despite the name, this normalises world points into the scene bbox (roughly [-1, 1] per axis)
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def get_sem_loss(self, sem_map, sem_train):
        sem_loss_fun = torch.nn.CrossEntropyLoss()
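        # sem_map is expected to hold (N_rays, num_classes) logits and sem_train (N_rays, 1)
        # labels; CrossEntropyLoss wants (N,) integer targets, hence squeeze().long() below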
        return sem_loss_fun(sem_map, sem_train.squeeze().long())

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        img = self.all_rgbs[idx]
        rays = self.all_rays[idx]
        sems = self.all_sems[idx]

        sample = {'rays': rays,
                  'rgbs': img,
                  'sems': sems}

        return sample
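

# Minimal usage sketch (hypothetical paths, assuming a Replica sequence exported with
# rgb/, semantic_instance/ and traj_w_c.txt as loaded above):
#
#     dataset = ReplicaDatasetDMNeRF('/data/replica/office_0', split='train',
#                                    sem_info_path='/data/replica/office_0')
#     sample = dataset[0]        # {'rays': (6,), 'rgbs': (3,), 'sems': (1,)}
#     aabb = dataset.scene_bbox  # adaptive bbox consumed by the trainer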