|
- import torch
- from torch.utils.data import Dataset
- from tqdm import tqdm
- import os
- from PIL import Image
- from torchvision import transforms as T
-
- from .ray_utils import *
-
def trans_t(t):
    """Return a 4x4 float32 homogeneous matrix that translates by ``t`` along z.

    The upper-left 3x3 part is the identity; ``t`` is placed in row 2, column 3.
    """
    matrix = torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ])
    return matrix.float()
-
def rot_phi(phi):
    """Return a 4x4 float32 homogeneous rotation about the x axis.

    ``phi`` is an angle in radians (it is passed straight to ``np.cos`` /
    ``np.sin``). The x row/column is left untouched; the y/z sub-block is
    ``[[cos, -sin], [sin, cos]]``.
    """
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    matrix = torch.Tensor([
        [1, 0, 0, 0],
        [0, cos_phi, -sin_phi, 0],
        [0, sin_phi, cos_phi, 0],
        [0, 0, 0, 1],
    ])
    return matrix.float()
-
def rot_theta(th):
    """Return a 4x4 float32 homogeneous rotation about the y axis.

    ``th`` is an angle in radians (it is passed straight to ``np.cos`` /
    ``np.sin``). Note the sign layout used here: ``-sin`` sits in the top
    row (x row, z column) and ``+sin`` in the z row, x column.
    """
    cos_th = np.cos(th)
    sin_th = np.sin(th)
    matrix = torch.Tensor([
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1],
    ])
    return matrix.float()
-
-
def pose_spherical(theta, phi, radius):
    """Build a 4x4 pose matrix from two angles (in degrees) and a radius.

    The result is the product, applied right-to-left, of:
      1. ``trans_t(radius)``  - translation by ``radius`` along z,
      2. ``rot_phi(phi)``     - rotation by ``phi`` degrees,
      3. ``rot_theta(theta)`` - rotation by ``theta`` degrees,
      4. a fixed matrix that negates x and swaps the y and z axes.

    Returns a (4, 4) tensor. The original code names it ``c2w``, so it is
    presumably a camera-to-world pose - confirm against the consumers.
    """
    # Degrees -> radians, written exactly as value / 180 * pi.
    phi_rad = phi / 180. * np.pi
    theta_rad = theta / 180. * np.pi

    # Negate x, swap y and z.
    axis_flip = torch.Tensor(np.array([
        [-1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]))

    # Parenthesised so the multiplication order is innermost-first.
    return axis_flip @ (rot_theta(theta_rad) @ (rot_phi(phi_rad) @ trans_t(radius)))
-
class NSVF(Dataset):
    """NSVF generic dataset: posed RGB(A) images converted to per-pixel rays.

    Files read from ``datadir``:
      * ``bbox.txt``       - the first six numbers form the scene bounding
                             box, viewed as two xyz rows,
      * ``intrinsics.txt`` - the first token of the first line is the focal
                             length,
      * ``pose/``          - one text matrix per image (``np.loadtxt``),
      * ``rgb/``           - the images themselves.

    File-name prefixes select the split: ``0_`` train, ``1_`` val, ``2_``
    test (test falls back to ``1_`` when no ``2_`` pose file exists). Any
    other split name keeps every file.
    """

    def __init__(self, datadir, split='train', downsample=1.0, wh=(800, 800), is_stack=False):
        """Load the whole split into memory.

        Args:
            datadir: dataset root directory.
            split: ``'train'``, ``'val'`` or ``'test'``.
            downsample: divisor applied to ``wh`` to get the working size.
            wh: (width, height) before downsampling; any indexable pair.
            is_stack: for the train split, keep rays/colours per image
                instead of flattening them into one long list.
        """
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0] / downsample), int(wh[1] / downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.5, 6.0]
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2, 3)
        # Stored for consumers; it is not applied to the poses in this class.
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        # Box midpoint and half-extent, shaped (1, 1, 3) for broadcasting.
        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def bbox2corners(self):
        """Return the 8 corners of ``scene_bbox`` as an (8, 3) tensor."""
        # Four copies of the (2, 3) [row0, row1] pair; ``repeat`` copies, so
        # ``scene_bbox`` itself is not modified.
        corners = self.scene_bbox.unsqueeze(0).repeat(4, 1, 1)
        # In copy i, swap coordinate i between the two rows. Together with
        # the untouched fourth copy this enumerates all eight corners.
        for i in range(3):
            corners[i, [0, 1], i] = corners[i, [1, 0], i]
        return corners.view(-1, 3)

    @staticmethod
    def _split_prefix(split, pose_files):
        """Return the file-name prefix for ``split``, or ``None`` for no filter."""
        if split == 'train':
            return '0_'
        if split == 'val':
            return '1_'
        if split == 'test':
            # Fall back to the validation files when there is no test pose.
            has_test = any(name.startswith('2_') for name in pose_files)
            return '2_' if has_test else '1_'
        return None

    def _load_rgb(self, image_path):
        """Load one image as an (h*w, 3) tensor, blending alpha over white."""
        # The context manager closes the file handle once pixels are decoded.
        with Image.open(image_path) as raw:
            img = raw if self.downsample == 1.0 else raw.resize(self.img_wh, Image.LANCZOS)
            pixels = self.transform(img)  # (channels, h, w)
        pixels = pixels.view(pixels.shape[0], -1).permute(1, 0)  # (h*w, channels)
        if pixels.shape[-1] == 4:
            # RGBA -> RGB: rgb * alpha + (1 - alpha), i.e. a white background.
            pixels = pixels[:, :3] * pixels[:, -1:] + (1 - pixels[:, -1:])
        return pixels

    def read_meta(self):
        """Read intrinsics, poses and images; fill the ray and colour buffers.

        Sets ``intrinsics``, ``directions``, ``render_path``, ``poses``,
        ``all_rays`` and ``all_rgbs``.
        """
        with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
            focal = float(f.readline().split()[0])
        # Principal point (400, 400) and the 800x800 reference size are
        # hard-coded; the first two rows are rescaled to the working size.
        self.intrinsics = np.array([[focal, 0, 400.0], [0, focal, 400.0], [0, 0, 1]])
        self.intrinsics[:2] *= (np.array(self.img_wh) / np.array([800, 800])).reshape(2, 1)

        pose_dir = os.path.join(self.root_dir, 'pose')
        rgb_dir = os.path.join(self.root_dir, 'rgb')
        pose_files = sorted(os.listdir(pose_dir))
        img_files = sorted(os.listdir(rgb_dir))

        prefix = self._split_prefix(self.split, pose_files)
        if prefix is not None:
            pose_files = [name for name in pose_files if name.startswith(prefix)]
            img_files = [name for name in img_files if name.startswith(prefix)]

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.img_wh[1], self.img_wh[0],
            [self.intrinsics[0, 0], self.intrinsics[1, 1]],
            center=self.intrinsics[:2, 2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)

        # 40 poses from pose_spherical: angles evenly spaced over
        # [-180, 180), second angle -30 degrees, radius 4.
        self.render_path = torch.stack(
            [pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)

        self.poses = []
        self.all_rays = []
        self.all_rgbs = []

        assert len(img_files) == len(pose_files), \
            f'found {len(img_files)} images but {len(pose_files)} poses'
        # zip() has no length, so pass the total for a meaningful progress bar.
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), total=len(img_files),
                                          desc=f'Loading data {self.split} ({len(img_files)})'):
            self.all_rgbs.append(self._load_rgb(os.path.join(rgb_dir, img_fname)))  # (h*w, 3)

            c2w = torch.FloatTensor(np.loadtxt(os.path.join(pose_dir, pose_fname)))
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays.append(torch.cat([rays_o, rays_d], 1))  # (h*w, 6)

        self.poses = torch.stack(self.poses)
        if self.split == 'train':
            if self.is_stack:
                self.all_rays = torch.stack(self.all_rays, 0).reshape(-1, *self.img_wh[::-1], 6)  # (n_images, h, w, 6)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (n_images, h, w, 3)
            else:
                self.all_rays = torch.cat(self.all_rays, 0)  # (n_images*h*w, 6)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (n_images*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (n_images, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (n_images, h, w, 3)

    def define_transforms(self):
        """Set ``self.transform`` to torchvision's ``ToTensor``."""
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        """Set ``self.proj_mat`` to ``intrinsics @ inverse(poses)[:, :3]``.

        One matrix per loaded pose; with (n, 4, 4) poses the result is
        (n, 3, 4).
        """
        self.proj_mat = torch.from_numpy(self.intrinsics[:3, :3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:, :3]

    def world2ndc(self, points):
        """Normalise ``points`` so the scene box maps to [-1, 1] per axis.

        Computes ``(points - center) / radius`` on the device of ``points``.
        """
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        """Number of ray entries for 'train', number of images otherwise."""
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        """Return ``{'rays': ..., 'rgbs': ...}`` for entry ``idx``.

        Both buffers are indexed the same way for every split; what one
        entry holds (a single ray or a whole image) depends on how
        ``read_meta`` laid the buffers out.
        """
        return {'rays': self.all_rays[idx],
                'rgbs': self.all_rgbs[idx]}
|