|
- import os
- os.environ['TL_BACKEND'] = 'tensorflow'
- import logging
- import glob
- import numpy as np
- import tensorlayer as tl
- import tensorflow as tf
- # import tensorflow.compat.v1 as tf
- import argparse
- import importlib
- from tqdm import trange, tqdm
- from utils import grid  # framework-agnostic helper (no DL framework required)
- import utils
- from ops.chamfer_dist import chamfer_dist
- from utils import color_space  # framework-agnostic helper (no DL framework required)
-
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
-
- RANDOM_SEED = 42
- np.random.seed(RANDOM_SEED)
- # tf.set_random_seed(RANDOM_SEED)
-
class Trainer():
    """Fits a grid-based generative model to a single colored point cloud.

    Loads a point cloud (xyz geometry + RGB colors), builds a 2D sampling grid
    sized to the cloud, instantiates the model module given by name, optionally
    resumes from ``best_model.npz``, and trains with a chamfer-distance loss.
    """

    def __init__(self, checkpoint_dir, input_pattern, grid_steps, model, max_steps,
                 grid_steps_factor=1.0, train_flag=True):
        """Build the trainer.

        Args:
            checkpoint_dir: Directory for checkpoints and the training log.
            input_pattern: Glob pattern / path of the input point cloud file.
            grid_steps: Grid step spec (e.g. 'auto' or '512,512,1').
            model: Name of the module exposing a ``Model`` class.
            max_steps: Maximum number of training steps.
            grid_steps_factor: Multiplier applied to the x/y grid steps.
            train_flag: When True, set up logging for training.
        """
        # Running loss terms; updated as a side effect of loss_fn each step.
        self.rec_loss = tl.ops.constant(0.0)
        self.rep_loss = tl.ops.constant(0.0)
        self.input_pattern = input_pattern
        self.max_steps = max_steps
        self.x_tilde = None  # most recent model output, cached by loss_fn
        data_format = 'channels_last'
        self.points_axis = 0 if data_format == 'channels_last' else 1
        # load_points also normalizes the point cloud.
        self.data = utils.pc_io.load_points([self.input_pattern])
        # Keep xyz + rgb only; points.shape is (1, num_points, 6).
        points = np.array(list(y[0][:, :6] for y in self.data))
        GEO_DIM = 3
        self.x = points[0, :, :GEO_DIM].astype(np.float32)       # geometry, shape (N, 3)
        self.ori_colors = points[0, :, GEO_DIM:].astype(np.float32)  # colors, shape (N, 3)
        self.n = len(points[0])  # number of points in the cloud
        # With the 'auto' spec, parse_grid_steps picks steps so that
        # steps[0] * steps[1] <= n (floor rounding), e.g. [218, 353, 1];
        # the factor rescales the x/y steps only, z stays 1.
        grid_steps = grid.parse_grid_steps(grid_steps, self.n) * np.array(
            [grid_steps_factor, grid_steps_factor, 1])
        self.grid_steps = grid_steps.astype(np.uint32)
        # grid_values: (steps[0] * steps[1], 3) array of evenly spaced
        # coordinates in [0, 1] — size depends on the input cloud.
        grid_values = grid.get_grid(self.grid_steps).astype(np.float32)
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_path = os.path.join(checkpoint_dir, 'best_model.npz')
        if train_flag:
            self.logger = self.getlogger(self.checkpoint_dir)
            self.logger.info(f'Grid steps: {self.grid_steps}')

        # importlib is required because module names starting with a digit
        # cannot be imported with a plain import statement.
        Model = getattr(importlib.import_module(model), 'Model')
        self.model = Model(points, grid_values)
        print("trainable_weights:")
        for w in self.model.trainable_weights:
            print(w.name, w.shape)
        if os.path.isfile(self.checkpoint_path):  # resume if a checkpoint exists
            tl.files.load_and_assign_npz(name=self.checkpoint_path, network=self.model)
            if train_flag:
                self.logger.info('Loaded checkpoint from ' + self.checkpoint_path)

    def loss_fn(self, x_tilde, ori_x):
        """Training loss: chamfer reconstruction term + repulsion term.

        Side effects: caches ``x_tilde`` and the two loss terms on ``self`` so
        the training loop can log them.
        """
        # cdist is the (symmetric) chamfer distance; k=0 uses the nearest point.
        cdist, AB_dists, idx_fwd, BA_dists, _ = chamfer_dist(ori_x, x_tilde, k=0)
        # One-directional distances of each output point to its second-nearest
        # neighbor in the output itself (k=1 skips the point itself).
        _, AB_dists2, _, _, _ = chamfer_dist(x_tilde, x_tilde, k=1)
        self.x_tilde = x_tilde
        self.rec_loss = tl.ops.reduce_mean(cdist)
        # Penalizing the variance of nearest-neighbor spacing encourages a
        # uniform point distribution (no dense/sparse clusters).
        self.rep_loss = tl.ops.reduce_variance(AB_dists2)
        loss = self.rec_loss + self.rep_loss
        return loss

    def getlogger(self, logdir):
        """Return a logger writing INFO+ records to <logdir>/log.txt and stdout.

        Handlers are attached only once: ``logging.getLogger(__name__)`` returns
        a shared module-level logger, so unconditionally adding handlers on
        every call would duplicate every record when a second Trainer is
        constructed in the same process.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(level=logging.INFO)  # NOTSET < DEBUG < INFO < WARNING < ERROR < CRITICAL
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        if not logger.handlers:  # fix: avoid duplicate handlers on repeated calls
            # Formatter decides the final record layout.
            formatter = logging.Formatter('%(asctime)s: %(message)s',
                                          datefmt='%Y-%m-%d %H:%M:%S')
            # File handler: sends records to <logdir>/log.txt.
            handler = logging.FileHandler(os.path.join(logdir, 'log.txt'))
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            # Console handler: mirrors records to the terminal.
            console = logging.StreamHandler()
            console.setLevel(logging.INFO)
            console.setFormatter(formatter)
            logger.addHandler(handler)
            logger.addHandler(console)

        return logger

    def eval_loss(self, x_tilde):
        """Evaluate (no-grad) color fidelity of the reconstruction.

        Colors are mapped from the original cloud onto ``x_tilde`` via
        nearest-neighbor correspondence and back; the YUV-weighted mean squared
        difference against the original colors measures matching quality.
        """
        # idx_fwd: for each original point, index of its nearest point in x_tilde.
        _, _, idx_fwd, _, _ = chamfer_dist(self.x, x_tilde, k=0)
        # Map colors forward through the coordinate correspondence.
        num_points = x_tilde.shape[self.points_axis]
        colors_tilde = tl.ops.unsorted_segment_mean(self.ori_colors, idx_fwd, num_points)
        # Map colors backward. Ideally this reproduces self.ori_colors exactly;
        # many-to-one matches make it differ, and a larger difference means a
        # worse correspondence.
        ori_colors_tilde = tl.ops.gather(colors_tilde, idx_fwd, axis=self.points_axis)
        bt709_rgb_to_yuv_m = tl.ops.constant(color_space.bt709_rgb_to_yuv_m, dtype=tl.float32)
        ori_colors_yuv = self.ori_colors @ bt709_rgb_to_yuv_m
        ori_colors_tilde_yuv = ori_colors_tilde @ bt709_rgb_to_yuv_m

        # Y (luma) weighted more heavily than the two chroma channels.
        loss_weights = tl.ops.constant([0.8, 0.1, 0.1], dtype=tl.float32)
        # Per-channel mean squared difference; result shape [3].
        col_losses = tl.ops.reduce_mean(
            tl.ops.squared_difference(ori_colors_yuv, ori_colors_tilde_yuv),
            axis=self.points_axis)
        col_loss = tf.tensordot(col_losses, loss_weights, 1)
        return col_loss

    def train(self):
        """Run the training loop with periodic logging, checkpointing and
        early stopping (no improvement for 500 steps)."""
        learning_rate = 1e-3
        optimizer = tl.optimizers.Adam(lr=learning_rate)
        net_with_loss = tl.models.WithLoss(self.model, self.loss_fn)
        train_weights = self.model.trainable_weights
        train_one_step = tl.models.TrainOneStep(net_with_loss, optimizer, train_weights)

        self.logger.info('Starting session')
        self.logger.info('Init session')

        step_val = 0
        first_step_val = step_val
        pbar = tqdm(total=self.max_steps)
        self.logger.info(f'Starting training with {self.input_pattern}')
        best_loss = float('inf')
        best_loss_step = step_val
        self.model.set_train()
        while step_val <= self.max_steps:
            pbar.update(step_val - pbar.n)
            save_interval = 1000
            summary_interval = 25*4
            # Never save on the very first step; otherwise every save_interval.
            save_model = step_val != first_step_val and step_val % save_interval == 0
            get_summary = step_val % summary_interval == 1 or save_model
            train_loss = train_one_step(self.x, self.x)

            save_path = os.path.join(self.checkpoint_dir, 'cur_model.npz')
            # Summaries
            if get_summary:
                col_loss = self.eval_loss(self.x_tilde)
                self.logger.info('step_val:' + str(step_val))
                self.logger.info('rec_loss:' + str(self.rec_loss.numpy()))
                self.logger.info('rep_loss:' + str(self.rep_loss.numpy()))
                self.logger.info('col_loss:' + str(col_loss.numpy()))
                self.logger.info('train_loss:' + str(train_loss.numpy()))
                pbar.set_description(f"rec_loss: {self.rec_loss.numpy():.3E}, rep_loss: {self.rep_loss.numpy():.3E},"
                                     + f" col_loss (no_grad): {col_loss.numpy():.3E}")
            # Early stopping
            if train_loss < best_loss:
                best_loss_step = step_val
                best_loss = train_loss
                save_path = os.path.join(self.checkpoint_dir, 'best_model.npz')
                save_model = True
            elif step_val - best_loss_step >= 250*2:
                # No improvement for 500 steps: save the current weights and stop.
                tl.files.save_npz(self.model.all_weights, name=save_path)
                self.logger.info(f'Early stopping: model saved to {save_path}')
                break
            step_val += 1
            if save_model:
                tl.files.save_npz(self.model.all_weights, name=save_path)
                self.logger.info(f'Model saved to {save_path}')
        pbar.close()
-
-
if __name__ == '__main__':
    # CLI entry point: parse arguments and launch training.
    parser = argparse.ArgumentParser(
        prog='11_train.py',
        description='Train network',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_pattern', help='Input pattern.')
    parser.add_argument('checkpoint_dir', help='Directory where to save/load model checkpoints.')
    parser.add_argument('--model', help='Model module.', required=True)
    parser.add_argument('--max_steps', type=int, default=100000, help='Train up to this number of steps.')
    parser.add_argument('--profiling', default=False, action='store_true', help='Enable profiling')
    parser.add_argument('--grid_steps', default='512,512,1')
    args = parser.parse_args()

    # Fix: os.path.split(dir)[0] is '' when checkpoint_dir has no parent
    # component (e.g. 'ckpt'), and os.makedirs('') raises FileNotFoundError.
    # Create the checkpoint directory itself; Trainer.getlogger expects it to
    # exist anyway, so this is backward-compatible.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    trainer = Trainer(args.checkpoint_dir, args.input_pattern, args.grid_steps,
                      args.model, args.max_steps, train_flag=True)
    trainer.train()
|