|
- import h5py
- import torch.utils.data as data
- import os, sys
- sys.path.append("../")
- import numpy as np
- import utils.data_util as utils
- from torchvision import transforms
-
-
def load_h5_data(h5_file_path='', up_ratio=4, skip_rate=1, npoint=256, use_random=True, use_norm=True):
    """Load paired input / ground-truth point sets from an HDF5 file.

    Two datasets named ``poisson_<N>`` are read from the file:

    * input: ``poisson_<npoint * 4>`` when ``use_random`` is true,
      otherwise ``poisson_<npoint>``;
    * ground truth: ``poisson_<npoint * up_ratio>``.

    Args:
        h5_file_path: Path of the HDF5 file to read.
        up_ratio: Multiplier applied to ``npoint`` to pick the ground-truth
            dataset name.
        skip_rate: Stride used to subsample the first axis of every
            returned array (``arr[::skip_rate]``).
        npoint: Base point count used to build the dataset names.
        use_random: Selects which input dataset is read (see above).
        use_norm: When true, the first three channels of the ground truth
            are centred on their per-sample mean and divided by the
            per-sample largest distance from that centre; the input is
            shifted and scaled with the same values.

    Returns:
        Tuple ``(input, gt, data_radius)``. ``data_radius`` is a 1-D array
        of ones with one entry per sample (before ``skip_rate`` is applied,
        then strided like the other two arrays).

    Raises:
        AssertionError: If input and ground truth hold a different number
            of samples.
    """
    num_4X_point = int(npoint * 4)
    num_out_point = int(npoint * up_ratio)

    print("h5_file_path : ", h5_file_path)
    if use_random:
        print("use random for input")
        input_key = 'poisson_%d' % num_4X_point  # 1024 -> 256
    else:
        print("Do not random for input")
        input_key = 'poisson_%d' % npoint
    gt_key = 'poisson_%d' % num_out_point  # 1024

    # Open read-only and close the handle as soon as both datasets have been
    # copied into memory (``[:]`` materialises them as arrays).
    with h5py.File(h5_file_path, 'r') as h5_file:
        points = h5_file[input_key][:]
        gt = h5_file[gt_key][:]

    assert len(points) == len(gt), "input and gt must hold the same number of samples"

    # Created unconditionally: it is sliced and returned below even when
    # normalisation is disabled.
    data_radius = np.ones(shape=(len(points)))

    if use_norm:
        print("Normalization the data")
        # Centre and scale are derived from the ground truth only and then
        # reused for the input so both stay in the same coordinate frame.
        centroid = np.mean(gt[:, :, 0:3], axis=1, keepdims=True)
        gt[:, :, 0:3] = gt[:, :, 0:3] - centroid
        furthest_distance = np.amax(np.sqrt(np.sum(gt[:, :, 0:3] ** 2, axis=-1)), axis=1, keepdims=True)
        scale = np.expand_dims(furthest_distance, axis=-1)
        gt[:, :, 0:3] = gt[:, :, 0:3] / scale
        points[:, :, 0:3] = points[:, :, 0:3] - centroid
        points[:, :, 0:3] = points[:, :, 0:3] / scale

    points = points[::skip_rate]
    gt = gt[::skip_rate]
    data_radius = data_radius[::skip_rate]
    print("total %d samples" % (len(points)))
    return points, gt, data_radius
-
class PUGAN_Dataset(data.Dataset):
    """Dataset yielding (input, ground truth, radius) point-cloud samples.

    All samples are loaded into memory once via ``load_h5_data``; each
    ``__getitem__`` call then draws ``npoint`` input points from the stored
    input cloud and, when ``use_norm`` is set, applies the augmentation
    helpers from ``utils``.
    """

    def __init__(self, h5_file_path='./PUGAN_poisson_256_poisson_1024.h5',
                 skip_rate=1, npoint=256, use_random=True, use_norm=True):
        """Load the HDF5 data and record the sampling configuration.

        Args:
            h5_file_path: Path of the HDF5 file handed to ``load_h5_data``.
            skip_rate: Sample stride handed to ``load_h5_data``.
            npoint: Number of input points drawn per sample in
                ``__getitem__``; also forwarded to ``load_h5_data``.
            use_random: Forwarded to ``load_h5_data``.
            use_norm: Forwarded to ``load_h5_data``; also switches the
                augmentation in ``__getitem__`` on or off.

        Raises:
            AssertionError: If the loaded input does not have 1024 points
                per sample.
        """
        super().__init__()

        self.h5_file_path = h5_file_path
        self.skip_rate = skip_rate
        self.npoint = npoint
        self.use_random = use_random
        self.use_norm = use_norm

        # skip_rate must be forwarded explicitly, otherwise the constructor
        # argument has no effect on the loaded data.
        self.input, self.gt, self.radius = load_h5_data(
            self.h5_file_path,
            skip_rate=self.skip_rate,
            npoint=self.npoint,
            use_random=self.use_random,
            use_norm=self.use_norm,
        )

        self.data_npoint = self.input.shape[1]  # 1024
        assert self.input.shape[1] == 1024, \
            "expected 1024 points per input sample, got %d" % self.input.shape[1]

    def __len__(self):
        """Return the number of loaded samples (size of the first axis)."""
        return self.input.shape[0]  # (24000, N, 3)

    def __getitem__(self, index):
        """Return ``(input_data, gt_data, radius_data)`` for one sample.

        The input is subsampled to ``self.npoint`` rows using indices from
        ``utils.nonuniform_sampling``. When ``self.use_norm`` is true the
        pair is passed through the rotate / random-scale / shift helpers in
        ``utils`` and the radius is multiplied by the scale they return.
        """
        input_data = self.input[index]
        gt_data = self.gt[index]
        radius_data = np.array([self.radius[index]])

        # 1024 -> 256
        sample_idx = utils.nonuniform_sampling(self.data_npoint, sample_num=self.npoint)
        input_data = input_data[sample_idx, :]

        if self.use_norm:
            # for data aug
            input_data, gt_data = utils.rotate_point_cloud_and_gt(input_data, gt_data)
            input_data, gt_data, scale = utils.random_scale_point_cloud_and_gt(
                input_data, gt_data, scale_low=0.8, scale_high=1.2)
            input_data, gt_data = utils.shift_point_cloud_and_gt(input_data, gt_data, shift_range=0.1)
            # Keep the radius consistent with the rescaled clouds.
            radius_data = radius_data * scale

        return input_data, gt_data, radius_data
-
-
-
|