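"""KITTI-360 dataset loader for ray-based (NeRF-style) training.

Loads rectified perspective images (optionally also the second camera,
image_01), PSPNet semantic maps, and camera poses for one drive of
KITTI-360, and precomputes per-pixel rays, an adaptive scene bounding box,
and a circular novel-view render path.
"""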
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from imgviz import label_colormap
from tqdm import tqdm
from PIL import Image
from torchvision import transforms as T

from utils.ray_utils import *
from utils.render_util import *


class KITTI360Dataset(Dataset):
    def __init__(self, datadir,
                 split, start, end,
                 near, far,
                 test_ids=None,
                 downsample=1.0, is_stack=False, N_vis=-1,
                 ndc_ray=False, use_01=False):
        self.N_vis = N_vis
        self.ndc_ray = ndc_ray
        self.use_01 = use_01

        self.root_dir = datadir
        self.split = split
        self.start = start
        self.end = end
        self.test_ids = test_ids

        self.is_stack = is_stack

        intrinsic_path = os.path.join(self.root_dir, "calibration", "perspective.txt")
        self.load_intrinsic(intrinsic_path)
        self.img_wh = (int(self.img_w / downsample), int(self.img_h / downsample))

        self.define_transforms()

        self.downsample = downsample
        # near/far bounds used when sampling points along rays and when
        # normalising rendered depth
        self.near_far = [near, far]

        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True

    def read_meta(self):
        w, h = self.img_wh

        # camera intrinsics: scale to the downsampled resolution, then drop the
        # last column of the 3x4 projection matrix to obtain the 3x3 K
        self.K[:2] = self.K[:2] / self.downsample
        self.intrinsics = torch.tensor(self.K[:, :-1]).float()
        self.focal_x, self.focal_y = self.intrinsics[0][0], self.intrinsics[1][1]

        # normalized ray directions for all pixels, identical for every image
        # (shared H, W and focal length)
        self.directions = get_ray_directions(h, w, [self.focal_x, self.focal_y])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
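        # with unit-norm directions, a sampled distance t along a ray equals the
        # Euclidean distance from the camera centre, so t can be read as depth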

        # camera extrinsics: cam0_to_world.txt stores, per line, a frame id
        # followed by the 16 entries of a flattened 4x4 camera-to-world matrix
        self.c2w_dict = {}
        all_c2w_file = os.path.join(self.root_dir, "data_poses", "2013_05_28_drive_0000_sync", "cam0_to_world.txt")
        for line in open(all_c2w_file, 'r').readlines():
            value = list(map(float, line.strip().split(" ")))
            frame = int(value[0])
            self.c2w_dict[frame] = np.array(value[1:]).reshape(4, 4)
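        # note: cam0_to_world.txt only lists frames that have a pose, so every
        # frame id requested below must be present in this dictionary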

        self.image_paths = []
        self.sem_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_sems = []
        self.all_masks = []
        self.all_depth = []
        # raw semantic images are cached here for label remapping (see remap_sem_label)
        self.sem_samples = {}
        self.sem_samples["sem_img"] = []

        # split the frame ids for training and testing
        if self.split == "train":
            ids = list(range(self.start, self.end))
        elif self.split == "test":
            assert self.test_ids is not None, "the test split requires explicit test_ids"
            ids = self.test_ids
        else:
            raise ValueError(f"unknown split: {self.split}")
        for i in tqdm(ids, desc=f'Loading data {self.split} ({len(ids)})'):
            c2w = torch.FloatTensor(self.c2w_dict[i])
            self.poses += [c2w]

            # KITTI-360 frame names are zero-padded to 10 digits
            image_path = os.path.join(self.root_dir, "2013_05_28_drive_0000_sync", "image_00", "data_rect", f"{i:010d}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)
            img = img.view(3, -1).permute(1, 0)  # (h*w, 3)
            self.all_rgbs += [img]

            sem_image_path = os.path.join(self.root_dir, "pspnet", "2013_05_28_drive_0000_sync", f"seq0_{self.start}-{self.end}", f"0000_{i:010d}.png")
            self.sem_paths += [sem_image_path]
            sem_img = Image.open(sem_image_path)  # pixel intensity encodes the train id
            self.sem_samples["sem_img"].append(np.array(sem_img))
            if self.downsample != 1.0:
                # nearest-neighbour resampling: label ids must not be interpolated
                sem_img = sem_img.resize(self.img_wh, Image.NEAREST)
            # keep integer label ids; T.ToTensor() would rescale them into [0, 1]
            sem_img = torch.from_numpy(np.array(sem_img)).long()
            sem_img = sem_img.view(-1, 1)  # (h*w, 1)
            self.all_sems += [sem_img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
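            # get_rays rotates the shared camera-space directions into world space
            # and broadcasts the camera origin, so each row is [origin | direction]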

        if self.use_01:
            # second rectified perspective camera (image_01)
            self.K_01[:2] = self.K_01[:2] / self.downsample
            self.intrinsics_01 = torch.tensor(self.K_01[:, :-1]).float()
            self.focal_x_01, self.focal_y_01 = self.intrinsics_01[0][0], self.intrinsics_01[1][1]
            # ray directions for all pixels, shared across images (same H, W, focal)
            self.directions_01 = get_ray_directions(h, w, [self.focal_x_01, self.focal_y_01])  # (h, w, 3)
            self.directions_01 = self.directions_01 / torch.norm(self.directions_01, dim=-1, keepdim=True)

            # camera extrinsics for image_01: poses.txt stores the vehicle pose
            # per frame; calib_cam_to_pose.txt maps the camera to the vehicle frame
            self.c2w_dict_01 = {}

            pose_file = os.path.join(self.root_dir, 'data_poses', "2013_05_28_drive_0000_sync", 'poses.txt')
            poses = np.loadtxt(pose_file)
            frames = poses[:, 0].astype(int)
            poses = np.reshape(poses[:, 1:], [-1, 3, 4])
            fileCameraToPose = os.path.join(self.root_dir, "calibration", 'calib_cam_to_pose.txt')
            for line in open(fileCameraToPose, 'r').readlines():
                camera_name = line.strip().split(" ")[0]
                value = list(map(float, line.strip().split(" ")[1:]))
                if camera_name == "image_01:":
                    camToPose = np.array(value).reshape(3, 4)
                    break
            camToPose = np.concatenate((camToPose, np.array([0., 0., 0., 1.]).reshape(1, 4)))
            for frame, pose in zip(frames, poses):
                pose = np.concatenate((pose, np.array([0., 0., 0., 1.]).reshape(1, 4)))
                self.c2w_dict_01[frame] = np.matmul(np.matmul(pose, camToPose), np.linalg.inv(self.R_rect))
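            # i.e. c2w_01 = vehicle_pose @ cam_to_pose @ inv(R_rect): rectified
            # camera coords -> unrectified camera -> vehicle frame -> world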

            for i in tqdm(ids, desc=f'Loading 01 data {self.split} ({len(ids)})'):
                c2w_01 = torch.FloatTensor(self.c2w_dict_01[i])
                self.poses += [c2w_01]

                image_path = os.path.join(self.root_dir, "2013_05_28_drive_0000_sync", "image_01", "data_rect", f"{i:010d}.png")
                self.image_paths += [image_path]
                img = Image.open(image_path)
                if self.downsample != 1.0:
                    img = img.resize(self.img_wh, Image.LANCZOS)
                img = self.transform(img)  # (3, h, w)
                img = img.view(3, -1).permute(1, 0)  # (h*w, 3)
                self.all_rgbs += [img]

                rays_o, rays_d = get_rays(self.directions_01, c2w_01)  # both (h*w, 3)
                self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

            # note: no semantic maps are loaded for image_01, so all_sems has
            # fewer entries than all_rays when use_01 is enabled

        # adaptive scene bounding box from all ray origins; train.py uses this as the aabb
        all_rays_o = torch.stack(self.all_rays)[..., :3]
        all_rays_o = all_rays_o.reshape(-1, 3)
        scene_min = torch.min(all_rays_o, 0)[0] - 50.0  # TODO: expose this margin as an argument
        scene_max = torch.max(all_rays_o, 0)[0] + 50.0
        self.scene_bbox = torch.stack([scene_min, scene_max]).reshape(-1, 3)

        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
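        # center and radius are consumed by world2ndc() below to normalise
        # world-space points into the unit cube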

        self.poses = torch.stack(self.poses)

        # training: flatten rays from every image into one large batch
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (num_imgs*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (num_imgs*h*w, 3)
            self.all_sems = torch.cat(self.all_sems, 0)  # (num_imgs*h*w, 1)
        # evaluation and test: keep one stacked entry per image
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (num_imgs, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (num_imgs, h, w, 3)
            self.all_sems = torch.stack(self.all_sems, 0).reshape(-1, *self.img_wh[::-1], 1)  # (num_imgs, h, w, 1)

        if self.ndc_ray:
            ndc_rays_o, ndc_rays_d = ndc_rays(self.img_h, self.img_w, self.focal_x, self.near_far[0], self.all_rays[:, :3], self.all_rays[:, 3:6])
            self.all_rays = torch.cat([ndc_rays_o, ndc_rays_d], dim=-1)

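        # novel-view render path: a 200-frame circle around the scene centre,
        # with `up` taken as the mean camera y-axis over all loaded poses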
        center = torch.mean(self.scene_bbox, dim=0)
        radius = torch.norm(self.scene_bbox[1] - center) * 0.1
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up, frames=200)
        self.render_path[:, :3, 3] += center

    def load_intrinsic(self, intrinsic_file):
        # parse the rectified projection matrices, image size and rectification
        # rotation from KITTI-360's calibration/perspective.txt
        intrinsic_loaded = False
        width, height = 0, 0
        R_rect = np.eye(4)
        with open(intrinsic_file) as f:
            intrinsics = f.read().splitlines()
        for line in intrinsics:
            line = line.split(' ')
            if line[0] == 'P_rect_00:':
                K = [float(x) for x in line[1:]]
                K = np.reshape(K, [3, 4])
                self.K = K
            elif line[0] == 'P_rect_01:':
                K = [float(x) for x in line[1:]]
                K = np.reshape(K, [3, 4])
                intrinsic_loaded = True
                self.K_01 = K
            elif line[0] == 'R_rect_01:':
                R_rect[:3, :3] = np.array([float(x) for x in line[1:]]).reshape(3, 3)
            elif line[0] == "S_rect_01:":
                width = int(float(line[1]))
                height = int(float(line[2]))
        assert intrinsic_loaded
        assert width > 0 and height > 0
        self.img_w, self.img_h = width, height
        self.R_rect = R_rect

    def remap_sem_label(self, train_sem_imgs, test_sem_imgs):
        # classes that actually occur in the training and test images
        # (may need extending if image_01 semantics are ever loaded)
        self.semantic_classes = np.unique(
            np.concatenate((np.unique(train_sem_imgs), np.unique(test_sem_imgs))).astype(np.uint8))
        self.num_semantic_class = self.semantic_classes.shape[0]
        self.num_valid_semantic_class = self.num_semantic_class

        self.sem_samples["sem_img"] = np.asarray(self.sem_samples["sem_img"])
        self.sem_samples["sem_remap"] = self.sem_samples["sem_img"].copy()

        # remap raw label ids to contiguous indices [0, num_semantic_class)
        for i in range(self.num_semantic_class):
            self.sem_samples["sem_remap"][self.sem_samples["sem_img"] == self.semantic_classes[i]] = i
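        # e.g. raw train ids {7, 11, 21} would be remapped to {0, 1, 2}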

    def set_label_colour_map(self):
        # colour map restricted to the classes present, for turning label maps
        # (1, h, w) into colour images (3, h, w)
        valid_colour_map = label_colormap()[self.semantic_classes]
        self.label_colour_map = valid_colour_map

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # full projection: K @ world-to-camera, taking the top 3 rows of inverse(c2w)
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]
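        # resulting shape: (num_poses, 3, 4), one projection matrix per view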

    def world2ndc(self, points, lindisp=None):
        # normalise world-space points by the scene bbox centre and half-extent,
        # mapping points inside the bbox into [-1, 1]
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def get_sem_loss(self, sem_map, sem_train):
        # cross-entropy between per-ray class logits and integer label targets
        sem_loss_fun = torch.nn.CrossEntropyLoss()
        return sem_loss_fun(sem_map, sem_train.squeeze().long())

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        img = self.all_rgbs[idx]
        rays = self.all_rays[idx]
        sems = self.all_sems[idx]

        sample = {'rays': rays,
                  'rgbs': img,
                  'sems': sems}

        return sample
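

# Minimal usage sketch. The dataset root, frame range, and near/far bounds
# below are hypothetical placeholders; utils.ray_utils / utils.render_util
# must be importable for this to run.
if __name__ == "__main__":
    dataset = KITTI360Dataset(
        datadir="/data/KITTI-360",   # hypothetical dataset root
        split="train",
        start=3353, end=3433,        # hypothetical frame range
        near=0.1, far=100.0,         # hypothetical near/far bounds
        downsample=1.0,
        is_stack=False,              # flattened per-ray tensors for training
    )
    print(f"{len(dataset)} rays, bbox {dataset.scene_bbox.tolist()}")
    sample = dataset[0]
    # with is_stack=False each item is one ray: rays (6,), rgbs (3,), sems (1,)
    print(sample['rays'].shape, sample['rgbs'].shape, sample['sems'].shape)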