import json
import math
import os

import numpy as np
import torch
from imgviz import label_colormap
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T
from tqdm import tqdm

from util.render_util import circle, gen_path
from util.ray_utils import *


class ReplicaDataset(Dataset):
    def __init__(self,
                 datadir,
                 split='train',
                 downsample=1.0,
                 is_stack=False,
                 N_vis=-1,
                 use_semantic=False,
                 scene_bbox=torch.tensor([[-6., -5.5, -3.], [6., 5.5, 3.]])
                 ):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()
        self.use_semantic = use_semantic
        # initial guess; read_meta() refits the bbox to the camera origins
        self.scene_bbox = scene_bbox

        # Replica poses are already in the OpenCV camera convention,
        # so no blender2opencv conversion is needed here.
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True

        self.near_far = [0.1, 50.0]

        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        # read_pfm comes from util.ray_utils; returns an (H, W) float32 depth map
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)
        return depth

    def load_data(self, root_dir, img_eval_interval=5):
        w, h = self.img_wh
        sequence = root_dir.split('/')[-1]
        with open(os.path.join(root_dir, "traj_w_c.txt"), 'r') as f:
            # one flattened row-major camera-to-world matrix per line
            all_c2w = np.loadtxt(f, delimiter=" ").reshape(-1, 4, 4)
        img_test_interval = 2
        N_image = 900
        idxs = list(range(0, N_image, img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)}) in {sequence}'):
            if self.split == 'test':
                # offset test frames so they never coincide with training frames
                i = i + img_test_interval
            c2w = torch.FloatTensor(all_c2w[i])
            self.poses += [c2w]

            img_fname = 'rgb_' + str(i) + '.png'
            image_path = os.path.join(root_dir, 'rgb', img_fname)
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (C, h, w) with C = 3 (RGB) or 4 (RGBA)
            img = img.view(-1, w * h).permute(1, 0)  # (h*w, C)
            if img.shape[-1] == 4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend alpha onto a white background
            self.all_rgbs += [img]

            if self.use_semantic:
                # todo: semantic image transform
                img_fname = 'semantic_class_' + str(i) + '.png'
                image_path = os.path.join(root_dir, 'semantic_class', img_fname)
                self.sem_paths += [image_path]
                img = Image.open(image_path)

                # keep the raw label map around for class statistics / remapping
                self.sem_samples['semantic'].append(np.array(img))

                if self.downsample != 1.0:
                    # nearest-neighbour resizing keeps label ids intact
                    img = img.resize(self.img_wh, Image.NEAREST)
                # note: for 8-bit label images ToTensor rescales values to [0, 1]
                img = self.transform(img)  # (1, h, w) single-channel label map
                img = img.view(-1, w * h).permute(1, 0)  # (h*w, 1)
                self.all_sems += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
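
    # Per-frame shape summary for the defaults above (640x480, downsample 1.0):
    # each entry appended per frame is rays (307200, 6) = [origin | direction],
    # rgbs (307200, 3), and, when use_semantic is set, labels (307200, 1).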

    def read_meta(self):
        # camera intrinsics: Replica renders are 640x480 with a 90-degree
        # horizontal FOV; the pinhole camera uses the same value for fx and fy
        rsz_h, rsz_w = 480, 640
        hfov = 90
        fx = rsz_w / 2.0 / math.tan(math.radians(hfov / 2.0))
        fy = fx
        cx = (rsz_w - 1.0) / 2.0
        cy = (rsz_h - 1.0) / 2.0
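        # Worked example with the defaults above: fx = 640 / 2 / tan(45 deg)
        # = 320.0 px, and the principal point is cx = 319.5, cy = 239.5.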
        self.meta = {}
        self.meta['w'] = rsz_w
        self.meta['h'] = rsz_h
        self.meta['fx'] = fx
        self.meta['cx'] = cx
        self.meta['fy'] = fy
        self.meta['cy'] = cy
        self.sem_samples = {}
        self.sem_samples['semantic'] = []

        w, h = int(self.meta['w'] / self.downsample), int(self.meta['h'] / self.downsample)
        self.img_wh = [w, h]
        # scale the intrinsics with the downsample factor so the ray grid
        # matches the resized images (a no-op at the default downsample of 1.0)
        self.focal_x = self.meta['fx'] / self.downsample
        self.focal_y = self.meta['fy'] / self.downsample
        self.cx, self.cy = self.meta['cx'] / self.downsample, self.meta['cy'] / self.downsample

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal_x, self.focal_y], center=[self.cx, self.cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal_x, 0, self.cx],
                                        [0, self.focal_y, self.cy],
                                        [0, 0, 1]]).float()

        self.image_paths = []
        self.sem_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_sems = []
        self.all_masks = []
        self.all_depth = []

        # a scene root holds Sequence_1/ and Sequence_2/; when pointed at the
        # root, load both sequences (the [-2] fallback tolerates trailing slashes)
        name = self.root_dir.split('/')[-1] if self.root_dir.split('/')[-1] else \
            self.root_dir.split('/')[-2]
        if name not in ['Sequence_1', 'Sequence_2']:
            root_dir1 = os.path.join(self.root_dir, 'Sequence_1')
            root_dir2 = os.path.join(self.root_dir, 'Sequence_2')
            self.load_data(root_dir1, 5)
            self.load_data(root_dir2, 5)
        else:
            self.load_data(self.root_dir)

        # refit the scene bounding box to the actual camera origins
        self.scene_bbox = self.get_scene_box(self.all_rays)
        self.poses = torch.stack(self.poses)

        # circular render path around the scene centre, for visualisation
        center = torch.mean(self.scene_bbox, dim=0)
        radius = torch.norm(self.scene_bbox[1] - center) * 0.1
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up, frames=200)
        self.render_path[:, :3, 3] += center

        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (N_images*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (N_images*h*w, 3)
            if self.use_semantic:  # torch.cat on an empty list would raise
                self.all_sems = torch.cat(self.all_sems, 0)  # (N_images*h*w, 1)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (N_images, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (N_images, h, w, 3)
            if self.use_semantic:
                self.all_sems = torch.stack(self.all_sems, 0).reshape(-1, *self.img_wh[::-1], 1)  # (N_images, h, w, 1)

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # K @ w2c[:3] yields an (N, 3, 4) projection matrix per image
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    def world2ndc(self, points, lindisp=None):
        # despite the name, this normalises points to the scene bbox (roughly
        # [-1, 1] per axis); lindisp is unused and kept for interface parity
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
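
    # Example (with the constructor's default bbox, before read_meta refits it):
    # the corner (6, 5.5, 3) maps to (1, 1, 1) and the centre maps to the origin.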

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        if self.use_semantic:
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx],
                      'sems': self.all_sems[idx]}
        else:
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}
        return sample

    def get_sem_class_nums(self, train_sem, test_sem, scene_file):
        # number of semantic classes present, including the void class 0
        self.semantic_classes = np.unique(
            np.concatenate(
                (np.unique(train_sem), np.unique(test_sem))
            ).astype(np.uint8))
        self.num_semantic_class = self.semantic_classes.shape[0]

        # select the colours of the existing classes from the full colour map
        json_class_mapping = os.path.join(scene_file, "info_semantic.json")
        with open(json_class_mapping, "r") as f:
            annotations = json.load(f)
        total_num_classes = len(annotations["classes"])
        self.colour_map_np = label_colormap(total_num_classes)[self.semantic_classes]

        self.get_sem_remap()

    def get_sem_remap(self):
        # remap the existing semantic labels to a contiguous range [0, num_class - 1]
        self.sem_samples["semantic"] = np.asarray(self.sem_samples["semantic"])
        self.sem_samples["semantic_clean"] = self.sem_samples["semantic"].copy()
        self.sem_samples["semantic_remap"] = self.sem_samples["semantic"].copy()
        self.sem_samples["semantic_remap_clean"] = self.sem_samples["semantic_clean"].copy()

        for i in range(self.num_semantic_class):
            self.sem_samples["semantic_remap"][self.sem_samples["semantic"] == self.semantic_classes[i]] = i
            self.sem_samples["semantic_remap_clean"][self.sem_samples["semantic_clean"] == self.semantic_classes[i]] = i
    def get_scene_box(self, all_rays):
        # axis-aligned bbox around all camera origins, padded by 5.0 units per side
        all_rays_o = torch.stack(all_rays)[..., :3]
        all_rays_o = all_rays_o.reshape(-1, 3)
        scene_min = torch.min(all_rays_o, 0)[0] - 5.0
        scene_max = torch.max(all_rays_o, 0)[0] + 5.0
        scene_bbox = torch.stack([scene_min, scene_max]).reshape(-1, 3)
        print("scene box:", scene_bbox)
        return scene_bbox
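
    # Example with hypothetical numbers: camera origins spanning x in [-1, 2]
    # give a box covering x in [-6, 7]; the 5.0 margin is a heuristic so the
    # box also encloses geometry the cameras see but do not pass through.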

    def reload_data(self, root_dir):
        # reload images (and labels) from another directory, reusing the
        # intrinsics, rays, and poses computed in read_meta
        self.all_rgbs = []
        self.all_sems = []
        w, h = self.img_wh
        N_image = 180
        img_eval_interval = 1
        idxs = list(range(0, N_image, img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):
            img_fname = str(i) + '.png'
            image_path = os.path.join(root_dir, 'rgb', img_fname)
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (C, h, w) with C = 3 or 4
            img = img.view(-1, w * h).permute(1, 0)  # (h*w, C)
            if img.shape[-1] == 4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend alpha onto a white background
            self.all_rgbs += [img]

            if self.use_semantic:
                img_fname = 'semantic_class_' + str(i) + '.png'
                image_path = os.path.join(root_dir, 'semantic_class', img_fname)
                self.sem_paths += [image_path]
                img = Image.open(image_path)
                self.sem_samples['semantic'].append(np.array(img))
                if self.downsample != 1.0:
                    # nearest-neighbour resizing keeps label ids intact
                    img = img.resize(self.img_wh, Image.NEAREST)
                img = self.transform(img)  # (1, h, w) single-channel label map
                img = img.view(-1, w * h).permute(1, 0)  # (h*w, 1)
                self.all_sems += [img]

        if not self.is_stack:
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (N_images*h*w, 3)
            if self.use_semantic:  # torch.cat on an empty list would raise
                self.all_sems = torch.cat(self.all_sems, 0)  # (N_images*h*w, 1)
        else:
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (N_images, h, w, 3)
            if self.use_semantic:
                self.all_sems = torch.stack(self.all_sems, 0).reshape(-1, *self.img_wh[::-1], 1)  # (N_images, h, w, 1)
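

# A minimal usage sketch (assumes a Replica scene laid out the way this loader
# expects: <scene>/Sequence_{1,2}/{rgb, semantic_class, traj_w_c.txt}; the path
# below is a placeholder).
if __name__ == "__main__":
    dataset = ReplicaDataset(datadir='./data/Replica/room_0',
                             split='train',
                             downsample=1.0,
                             is_stack=False,
                             use_semantic=True)
    print(len(dataset), 'rays')
    sample = dataset[0]
    print({k: tuple(v.shape) for k, v in sample.items()})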