import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def average_poses(poses):
    """
    Calculate the average pose, which is then used to center all poses
    using @center_poses. Its computation is as follows:
    1. Compute the center: the average of pose centers.
    2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute x' = z cross product y', then normalize it as the x axis.
    5. Compute the y axis: x cross product z.

    Note that at step 3, we cannot directly use y' as the y axis since it is
    not necessarily orthogonal to the z axis. We therefore compute x first,
    then recompute y from z and x.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose
    """
    # 1. Compute the center
    center = poses[..., 3].mean(0)  # (3)

    # 2. Compute the z axis
    z = normalize(poses[..., 2].mean(0))  # (3)

    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = poses[..., 1].mean(0)  # (3)

    # 4. Compute the x axis
    x = normalize(np.cross(z, y_))  # (3)

    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = np.cross(x, z)  # (3)

    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)

    return pose_avg
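
# Illustrative sanity check (hypothetical values, not part of the loader):
# two identity-rotation poses with centers at x = 0 and x = 2 should average
# to a pose centered at x = 1.
#   >>> eye3 = np.eye(3)
#   >>> p = np.stack([np.hstack([eye3, [[t], [0.], [0.]]]) for t in (0., 2.)])
#   >>> average_poses(p)[:, 3]
#   array([1., 0., 0.])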


def center_poses(poses, blender2opencv):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34
    Inputs:
        poses: (N_images, 3, 4)
        blender2opencv: (4, 4) coordinate transform applied to the poses first
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg_homo: (4, 4) the average pose in homogeneous coordinates
    """
    poses = poses @ blender2opencv
    pose_avg = average_poses(poses)  # (3, 4)
    # convert the average pose to homogeneous coordinates: the top three rows
    # hold the pose, the last row is already (0, 0, 0, 1) from np.eye(4)
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinates

    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)

    return poses_centered, pose_avg_homo
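
# Illustrative usage (hypothetical `p` of shape (N, 3, 4)): centering maps
# the average pose to (approximately) the identity.
#   >>> centered, avg_homo = center_poses(p, np.eye(4))
#   >>> centered.shape, avg_homo.shape
#   ((N, 3, 4), (4, 4))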


def viewmatrix(z, up, pos):
    """Build a (4, 4) camera-to-world matrix from a viewing direction z, an
    up vector and a camera position."""
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.eye(4)
    # vec0 is negated so that the x axis matches the convention produced by
    # average_poses (x = z cross y)
    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
    return m
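
# Illustrative check (hypothetical values): a camera at the origin looking
# along -z with +y up yields the rotation diag(1, 1, -1).
#   >>> m = viewmatrix(np.array([0., 0., -1.]), np.array([0., 1., 0.]), np.zeros(3))
#   >>> m[:3, :3]
#   array([[ 1.,  0.,  0.],
#          [ 0.,  1.,  0.],
#          [ 0.,  0., -1.]])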


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    """Generate N camera-to-world matrices on a spiral around the average
    pose c2w, all looking at a point `focal` units in front of it (along its
    -z axis). `zdelta` is accepted for API compatibility but currently unused."""
    render_poses = []
    rads = np.array(list(rads) + [1.])

    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        # camera center on the spiral, expressed in world coordinates
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        # viewing direction: from the focus point towards the spiral point
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    # average (center) pose of all cameras
    c2w = average_poses(c2ws_all)

    # average up vector
    up = normalize(c2ws_all[:, :3, 1].sum(0))

    # find a reasonable "focus depth" for this dataset: a weighted harmonic
    # mean of the closest and farthest depths
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth)

    # radii for the spiral path: 90th percentile of the camera positions
    zdelta = near_fars.min() * .2
    tt = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(render_poses)
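
# Illustrative usage (hypothetical inputs): given N camera-to-world matrices
# `c2ws` of shape (N, 3, 4) and their near/far bounds of shape (N, 2), this
# returns 120 (4, 4) poses tracing two turns of a spiral for video rendering.
#   >>> path = get_spiral(c2ws, near_fars)
#   >>> path.shape
#   (120, 4, 4)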


class LLFFDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):
        """
        datadir: root directory of the LLFF scene (must contain poses_bounds.npy)
        split: 'train' or 'test'
        downsample: downsampling factor relative to the full resolution
        is_stack: if True, keep rays/rgbs stacked per image instead of
            concatenated into one flat batch
        hold_every: every hold_every-th image is held out for testing
        """

        self.root_dir = datadir
        self.split = split
        self.hold_every = hold_every
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        # identity here: the axis flip is done explicitly in read_meta
        # (the OpenGL->OpenCV alternative would be
        # np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]))
        self.blender2opencv = np.eye(4)
        self.read_meta()
        self.white_bg = False

        # rays are converted to NDC, so depth along each ray lies in [0, 1]
        self.near_far = [0.0, 1.0]
        self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_meta(self):
        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)
        # NOTE: the image folder is hard-coded to the 4x-downsampled images
        # shipped with LLFF scenes; images are resized again below if
        # self.downsample differs from the folder's factor
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
        if self.split in ['train', 'test']:
            assert len(poses_bounds) == len(self.image_paths), \
                'Mismatch between number of images and number of poses! Please rerun COLMAP!'

        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)

        # Step 1: rescale focal length according to training resolution
        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
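
        # Worked example (hypothetical numbers): for a 4032x3024 capture with
        # focal = 3260 px and downsample = 4, img_wh = (1008, 756) and the
        # focal length scales by the same factor on both axes:
        # 3260 * 1008 / 4032 = 815 px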

        # Step 2: correct poses
        # Original poses have rotation in the form "down right back"; change
        # to "right up back". See https://github.com/bmild/nerf/issues/34
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        # (N_images, 3, 4) exclude H, W, focal
        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)

        # Step 3: correct scale so that the nearest depth is at a little more than 1.0
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter,
        # so the nearest depth ends up at 1 / 0.75 = 1.33
        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor

        # build the spiral rendering path
        N_views = 120
        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)

        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)

        # hold out every self.hold_every-th image for testing; train on the rest
        i_test = np.arange(0, self.poses.shape[0], self.hold_every)
        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))

        self.all_rays = []
        self.all_rgbs = []
        for i in img_list:
            image_path = self.image_paths[i]
            c2w = torch.FloatTensor(self.poses[i])

            img = Image.open(image_path).convert('RGB')
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)

            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            # project the rays into NDC space (forward-facing scenes)
            rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)

            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (N_images*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (N_images*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (N_images, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (N_images, h, w, 3)
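
        # e.g. with 20 training images at 1008x756 (hypothetical numbers):
        # is_stack=False yields all_rays of shape (20*756*1008, 6) for random
        # ray batching; is_stack=True yields all_rgbs of shape
        # (20, 756, 1008, 3) so whole images can be evaluated at once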

    def define_transforms(self):
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}

        return sample
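
# Minimal smoke test, kept as a comment because this module uses a relative
# import (the scene path below is hypothetical; it must contain
# poses_bounds.npy and an images_4/ folder):
#
#   dataset = LLFFDataset('/path/to/nerf_llff_data/fern', split='train')
#   print(dataset.all_rays.shape, dataset.all_rgbs.shape)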