|
- # -*- coding: utf-8 -*-
- # @Time : 16/1/2019 5:04 PM
- # @Description :
- # @Author : li rui hui
- # @Email : ruihuili@gmail.com
- # Modified by Guocheng Qian
-
- import tensorflow as tf
- from Common.visu_utils import plot_pcd_three_views, point_cloud_three_views
- from Common.ops import add_scalar_summary, add_hist_summary
- from Upsampling.data_loader import Fetcher
- # from Upsampling.data_loader import PCdataset
- from Common import model_utils
- from Common import pc_util
- from Common.loss_utils import pc_distance, get_uniform_loss, get_repulsion_loss, generator_loss
- from tf_ops.sampling.tf_sampling import farthest_point_sample
- import logging
- import os
- from tqdm import tqdm
- from glob import glob
- import math
- from time import time
- from termcolor import colored
- import numpy as np
- from Common.model_utils import get_model_cls
- import tensorlayer as tl
- from tensorlayer.dataflow import Dataloader
-
-
-
-
class Model(object):
    """Point-cloud upsampling model wrapper.

    Owns the training loop (`train`) and the patch-based inference
    pipeline (`test` / `pc_prediction` / `patch_prediction`).
    """

    def __init__(self, opts):
        # Parsed command-line options / hyper-parameters shared by all methods.
        self.opts = opts
        # Optimizer steps taken so far; drives LR decay and logging cadence.
        self.global_step = 0
-
- def decayed_learning_rate(self, initial_learning_rate, step, decay_steps, decay_rate):
- return initial_learning_rate * decay_rate ** (step / decay_steps)
-
- def train(self):
- # data
- fetchworker = Fetcher(self.opts)
- fetchworker.start()
- # training_set = PCdataset(self.opts, is_training=True)
- # train_dataset = tl.dataflow.FromGenerator(training_set,output_types=(tl.float32,tl.float32,tl.float32))
- # train_dataloader = Dataloader(train_dataset, batch_size=self.opts.batch_size, shuffle=True)
-
-
- # optimizer
- self.G_optimizers = tl.optimizers.Adam(lr=self.opts.base_lr_g, beta_1=self.opts.beta)
-
-
- with open(os.path.join(self.opts.log_dir, 'args.txt'), 'w') as log:
- for arg in sorted(vars(self.opts)):
- log.write(arg + ': ' + str(getattr(self.opts, arg)) + '\n') # log of arguments
-
- start = time()
-
- #logging.info("========== Building Model ==========")
- model_cls = get_model_cls(self.opts.model)
- self.G = model_cls(self.opts, is_training=True, name='generator')
- #self.G = Generator(self.opts, self.is_training, name='generator')
-
-
- restore_epoch = 1
- if self.opts.restore:
- restore_epoch = 45
- restore_model_path = self.opts.log_dir + '/model-' + str(restore_epoch) + '.npz'
- print(restore_model_path)
- input = tl.layers.Input(shape=(64,256,3))
- self.G.init_build(input)
- tl.files.load_and_assign_npz(name=restore_model_path, network=self.G)
- restore_epoch += 1
- print("Resume training from Epoch % d" % restore_epoch)
- else:
- os.makedirs(os.path.join(self.opts.log_dir, 'plots'))
-
-
- for epoch in range(restore_epoch, self.opts.max_epochs+1):
- logging.info('**** EPOCH %03d ****\t' % (epoch))
- #print('========================', len(train_dataloader))
- for batch_idx in range(fetchworker.num_batches):
- #for batch_idx, batch_data in enumerate(train_dataloader):
-
- batch_input_x, batch_input_y, batch_radius = fetchworker.fetch()
- # batch_input_x, batch_input_y, batch_radius = batch_data
-
- batch_input_x = tf.convert_to_tensor(batch_input_x, dtype=tf.float32)
- batch_input_y = tf.convert_to_tensor(batch_input_y, dtype=tf.float32)
- batch_radius = tf.convert_to_tensor(batch_radius, dtype=tf.float32)
-
- with tf.GradientTape() as tape:
- # X -> Y
-
- self.G_y = self.G(batch_input_x)
-
- self.dis_loss = self.opts.fidelity_w * pc_distance(self.G_y, batch_input_y,
- radius=batch_radius,
- threshold=self.opts.cd_threshold,
- dis_type=self.opts.loss_type)
-
- self.pu_loss = self.dis_loss
- if self.opts.repulse:
- self.repulsion_loss = self.opts.repulsion_w * get_repulsion_loss(self.G_y)
- self.pu_loss += self.repulsion_loss
-
- if self.opts.uniform:
- self.uniform_loss = self.opts.uniform_w * get_uniform_loss(self.G_y)
- self.pu_loss += self.uniform_loss
-
- if self.opts.reg:
- self.pu_loss += tf.compat.v1.losses.get_regularization_loss()
-
-
- self.total_gen_loss = self.pu_loss
-
-
- # Update G network
- for i in range(self.opts.gen_update):
- # get previously generated images
- grads = tape.gradient(self.total_gen_loss, self.G.trainable_weights)
- self.G_optimizers.apply_gradients(zip(grads, self.G.trainable_weights))
-
-
- self.global_step += 1
- #print('----------------------', self.global_step)
-
- if self.global_step % self.opts.steps_per_print == 0:
- logging.info('-----------EPOCH %d Step %d:-------------' % (epoch, self.global_step))
- logging.info(' G_loss : {}'.format(self.total_gen_loss))
- if self.opts.use_gan:
- logging.info(' D_loss : {}'.format(self.D_loss))
- logging.info(' Time Cost : {}'.format(time() - start))
- start = time()
-
-
- learning_rate_g = tl.where(tl.greater_equal(self.global_step, self.opts.start_decay_step),
- self.decayed_learning_rate(self.opts.base_lr_g, self.global_step, self.opts.lr_decay_steps, self.opts.lr_decay_rate),
- self.opts.base_lr_g)
- #print(learning_rate_g)
- learning_rate_g = tl.maximum(learning_rate_g, self.opts.lr_clip)
- self.G_optimizers.learning_rate = learning_rate_g
-
-
- if (epoch % self.opts.epoch_per_save) == 0:
- tl.files.save_npz(self.G.all_weights, name= self.opts.log_dir + '/model-' + str(epoch) + '.npz')
- print(colored('Model saved at %s' % self.opts.log_dir, 'white', 'on_blue'))
-
- fetchworker.shutdown()
-
- def patch_prediction(self, patch_point):
- # normalize the point clouds
- patch_point, centroid, furthest_distance = pc_util.normalize_point_cloud(patch_point)
- patch_point = np.expand_dims(patch_point, axis=0)
- patch_point = tl.convert_to_tensor(patch_point, dtype = tl.float32)
-
- restore_epoch = 50
- restore_model_path = self.opts.log_dir + '/model-' + str(restore_epoch) + '.npz'
- #print('------Load pretrained model-------', restore_model_path)
- #input = tl.layers.Input(shape=(1,256,3))
- self.Gen.init_build(patch_point)
- tl.files.load_and_assign_npz(name=restore_model_path, network=self.Gen)
- #print("Test at Epoch % d" % restore_epoch)
-
- pred_pc = self.Gen(patch_point)
- for i in range(round(math.pow(self.opts.up_ratio, 1 / 4)) - 1):
- self.pred_pc =self.Gen(pred_pc)
-
- pred = [pred_pc.numpy()]
- # pred1 = self.sess.run([self.pred_pc], feed_dict={self.inputs: pred})
- pred = np.squeeze(centroid + pred * furthest_distance, axis=0)
- #print('===============pred_pc==================', pred.shape)
- return pred
-
- def pc_prediction(self, pc):
- # get patch seed from farthestsampling
- points = tf.convert_to_tensor(np.expand_dims(pc, axis=0), dtype=tf.float32)
- start = time()
- print('------------------patch_num_point:', self.opts.patch_num_point)
- seed1_num = int(pc.shape[0] / self.opts.patch_num_point * self.opts.patch_num_ratio)
-
- # FPS sampling
- seed = farthest_point_sample(seed1_num, points).numpy()[0]
- seed_list = seed[:seed1_num]
- print("farthest distance sampling cost", time() - start)
- print("number of patches: %d" % len(seed_list))
- input_list = []
- up_point_list = []
-
- patches = pc_util.extract_knn_patch(pc[np.asarray(seed_list), :], pc, self.opts.patch_num_point)
-
- patch_time = 0.
- for point in tqdm(patches, total=len(patches)):
- start = time()
- up_point = self.patch_prediction(point)
- end = time()
- patch_time += end-start
-
- up_point = np.squeeze(up_point, axis=0)
- input_list.append(point)
- up_point_list.append(up_point)
- return input_list, up_point_list, patch_time/len(patches)
-
-
- def test(self):
- # self.inputs = tf.placeholder(tf.float32, shape=[1, self.opts.patch_num_point, 3])
- # is_training = tf.placeholder_with_default(False, shape=[], name='is_training')
- self.pc_radius = tf.ones(self.opts.batch_size)
- model_cls = get_model_cls(self.opts.model)
- self.Gen = model_cls(self.opts, is_training=False, name='generator')
-
-
- samples = glob(self.opts.test_data)
- point = pc_util.load(samples[0])
- self.opts.num_point = point.shape[0]
- out_point_num = int(self.opts.num_point * self.opts.up_ratio)
-
- total_time = 0.
- for point_path in samples:
- print(point_path)
- pc = pc_util.load(point_path)[:, :3]
-
- pc, centroid, furthest_distance = pc_util.normalize_point_cloud(pc)
-
- if self.opts.test_jitter:
- pc = pc_util.jitter_perturbation_point_cloud(pc[np.newaxis, ...], sigma=self.opts.jitter_sigma,
- clip=self.opts.jitter_max)
- pc = pc[0, ...]
-
- input_list, pred_list, avg_patch_time = self.pc_prediction(pc)
-
- total_time += avg_patch_time
-
- pred_pc = np.concatenate(pred_list, axis=0)
- pred_pc = (pred_pc * furthest_distance) + centroid
-
- pred_pc = np.reshape(pred_pc, [-1, 3])
- path = os.path.join(self.opts.out_folder, point_path.split('/')[-1][:-4] + '.ply')
- idx = farthest_point_sample(out_point_num, pred_pc[np.newaxis, ...]).numpy()[0]
- pred_pc = pred_pc[idx, 0:3]
-
-
- np.savetxt(path[:-4] + '.xyz', pred_pc, fmt='%.6f')
-
- logging.info('Average Inference Time: {} ms'.format(total_time / len(samples) * 1000.))
|