|
- import os, sys, time, logging
- import numpy as np
- from dataprocess.inout_points import load_points, save_points, points2voxels, select_voxels #与框架无关
- from loss import get_bce_loss, get_classify_metrics #已改并测试
- import tensorlayer as tl
- #import tensorflow as tf
-
class Trainer():
    """Training/evaluation driver for a point-cloud compression model.

    Wraps a TensorLayer model with checkpoint load/save, file+console
    logging, rate (bits-per-point) / IoU bookkeeping, and the train/test
    loops. Checkpoints are stored as .npz weight files; mid-epoch saves
    only happen when the observed bpp improves on the best seen so far.
    """

    def __init__(self, config, model, train_dataloader_len, test_dataloader_len):
        """Set up logging, load the initial checkpoint and build the train step.

        Args:
            config: namespace-like object; reads logdir, ckptdir, init_ckpt,
                lr, alpha, beta, gamma, delta (and writes init_ckpt/lr back).
            model: TensorLayer model to train/evaluate.
            train_dataloader_len: number of batches in one training epoch.
            test_dataloader_len: number of batches in the test set.
        """
        self.config = config
        # FileHandler opens log.txt in append mode, so logs from repeated
        # runs accumulate instead of overwriting the previous file.
        self.logger = self.getlogger(config.logdir)
        self.start = time.time()
        self.model = model
        self.best_bpp = None  # best (lowest) bits-per-point seen so far
        self.train_one_step = self.reset(self.model, self.config.init_ckpt, self.config.lr)
        self.logger.info(self.model)
        self.epoch = 0
        self.DISPLAY_STEP = 50  # record/checkpoint every this many batches
        self.record_set = {'bpp_ae': [], 'bpp_hyper': [], 'bpp': [], 'IoU': []}
        # Scalars captured by the most recent loss_fn call during training.
        self.train_info = {'train_zeros': 0, 'train_ones': 0, 'train_distortion': 0,
                           'train_loss': 0, 'train_bpp_ae': 0, 'train_bpp_hyper': 0}
        self.x_tilde = None  # reconstruction from the latest forward pass (numpy)
        self.train_dataloader_len = train_dataloader_len
        self.test_dataloader_len = test_dataloader_len

    def getlogger(self, logdir):
        """Return an INFO-level logger writing to <logdir>/log.txt and the console.

        Args:
            logdir: directory that must already exist; log.txt is created/appended there.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(level=logging.INFO)  # NOTSET < DEBUG < INFO < WARNING < ERROR < CRITICAL
        # logging.getLogger returns the same logger object on every call, so
        # without this guard each Trainer construction would attach another
        # pair of handlers and every message would be emitted multiple times.
        if not logger.handlers:
            formatter = logging.Formatter('%(asctime)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
            handler = logging.FileHandler(os.path.join(logdir, 'log.txt'))
            handler.setLevel(logging.INFO)
            handler.setFormatter(formatter)
            console = logging.StreamHandler()  # mirrors the file output on the console
            console.setLevel(logging.INFO)
            console.setFormatter(formatter)
            logger.addHandler(handler)
            logger.addHandler(console)
        return logger

    def save_model(self, bpp=0.0, global_step=None):
        """Save model weights as .npz; mid-epoch saves require a bpp improvement.

        Args:
            bpp: current bits-per-point, compared against self.best_bpp.
            global_step: None means an unconditional end-of-epoch save;
                otherwise a mid-epoch save that only stores when bpp improves.

        Side effects: updates self.best_bpp and points config.init_ckpt at
        the newly written checkpoint so a later reset() reloads it.
        """
        store_flag = False
        if global_step is None:  # unconditional save at the end of an epoch
            save_dir = os.path.join(self.config.ckptdir, 'epoch_' + str(self.epoch) + '.npz')
            store_flag = True
            self.logger.info('Saved model on epoch: %i, for epoch finished.' % (self.epoch))
        else:
            save_dir = os.path.join(self.config.ckptdir, 'epoch_' + str(self.epoch) + '_' + str(int(global_step)) + '.npz')
            if self.best_bpp is None or self.best_bpp > bpp:
                self.best_bpp = bpp
                store_flag = True
                self.logger.info('Saved model on epoch: %i, global_step: %i, for better bpp.' % (self.epoch, global_step))
        if store_flag:
            tl.files.save_npz(self.model.all_weights, name=save_dir)
            self.config.init_ckpt = save_dir
        return

    def record(self, main_tag, global_step):
        """Log every metric list in record_set (rounded to 4 decimals), then clear it.

        Args:
            main_tag: label for the log header (e.g. 'Train' or 'Test').
            global_step: step counter printed in the header.
        """
        self.logger.info('='*10 + main_tag + ' Epoch ' + str(self.epoch) + ' Step: ' + str(global_step))
        for k, v in self.record_set.items():
            self.logger.info(k + ': ' + str(np.round(v, 4).tolist()))
        # Reset the accumulators so the next reporting window starts empty.
        for k in self.record_set.keys():
            self.record_set[k] = []
        return

    def test(self, dataloader, main_tag='Test'):
        """Evaluate the model: average bpp (AE + hyper) and IoU over the test set.

        Args:
            dataloader: iterable yielding (points, _) batches; points is
                assumed to already be a voxelized tensor, e.g.
                (batch, 64, 64, 64, 1) per the original notes -- TODO confirm.
            main_tag: label used when the results are recorded.
        """
        bpps_ae = 0.
        bpps_hyper = 0.
        IoUs = 0.
        self.model.set_eval()
        self.logger.info('Testing Files length:' + str(self.test_dataloader_len))
        for _, (points, _) in enumerate(dataloader):
            x = points
            # Forward pass (no gradient bookkeeping needed at eval time).
            out_set = self.model(x, training=False)
            # Count occupied voxels; bpp below is normalized per occupied point.
            num_points = tl.ops.ReduceSum()(tl.ops.cast(tl.ops.greater(tl.ops.ReduceSum(-1)(x), 0), tl.float32))
            # -sum(log2 p) / num_points == sum(log p) / (-ln 2 * num_points)
            train_bpp_ae = tl.ops.ReduceSum()(tl.ops.log(out_set['likelihoods'])) / (-np.log(2) * num_points)
            train_bpp_hyper = tl.ops.ReduceSum()(tl.ops.log(out_set['likelihoods_hyper'])) / (-np.log(2) * num_points)

            points_nums = tl.ops.cast(tl.ops.ReduceSum(axis=(1,2,3,4))(x), 'int32')
            x_tilde = out_set['x_tilde'].numpy()
            # Hard-classify the reconstruction (returns a numpy array).
            output = select_voxels(x_tilde, points_nums, 1.0)
            _, _, IoU = get_classify_metrics(output, x)

            bpps_ae = bpps_ae + train_bpp_ae.numpy()        # rate term (main AE)
            bpps_hyper = bpps_hyper + train_bpp_hyper.numpy()  # rate term (hyperprior)
            IoUs = IoUs + IoU.numpy()                        # reconstruction accuracy

        bpps_ae = bpps_ae / self.test_dataloader_len
        bpps_hyper = bpps_hyper / self.test_dataloader_len
        IoUs = IoUs / self.test_dataloader_len
        # Record the epoch-level averages.
        self.record_set['bpp_ae'].append(bpps_ae)
        self.record_set['bpp_hyper'].append(bpps_hyper)
        self.record_set['bpp'].append(bpps_ae + bpps_hyper)
        self.record_set['IoU'].append(IoUs)
        self.record(main_tag=main_tag, global_step=0)
        return

    def loss_fn(self, out_set, ori_x):
        """Rate-distortion loss: alpha*distortion + delta*bpp_ae + gamma*bpp_hyper.

        Args:
            out_set: model output dict with 'likelihoods', 'likelihoods_hyper'
                and the reconstruction 'x_tilde'.
            ori_x: the original voxel tensor the model was fed.

        Returns:
            Scalar training loss tensor. Also stashes the individual scalar
            terms in self.train_info and the reconstruction in self.x_tilde
            so the training loop can report/post-process them.
        """
        num_points = tl.ops.ReduceSum()(tl.ops.cast(tl.ops.greater(tl.ops.ReduceSum(-1)(ori_x), 0), tl.float32))
        train_bpp_ae = tl.ops.ReduceSum()(tl.ops.log(out_set['likelihoods'])) / (-np.log(2) * num_points)
        train_bpp_hyper = tl.ops.ReduceSum()(tl.ops.log(out_set['likelihoods_hyper'])) / (-np.log(2) * num_points)
        # BCE split into the empty-voxel and occupied-voxel terms; beta
        # re-weights the (far more numerous) empty voxels.
        train_zeros, train_ones = get_bce_loss(out_set['x_tilde'], ori_x)
        train_distortion = self.config.beta * train_zeros + 1.0 * train_ones
        train_loss = self.config.alpha * train_distortion + self.config.delta * train_bpp_ae + self.config.gamma * train_bpp_hyper
        self.train_info['train_zeros'] = train_zeros.numpy()
        self.train_info['train_ones'] = train_ones.numpy()
        self.train_info['train_distortion'] = train_distortion.numpy()
        self.train_info['train_loss'] = train_loss.numpy()
        self.train_info['train_bpp_ae'] = train_bpp_ae.numpy()
        self.train_info['train_bpp_hyper'] = train_bpp_hyper.numpy()
        self.x_tilde = out_set['x_tilde'].numpy()  # kept for IoU post-processing
        return train_loss

    def reset(self, model, ckpt_dir, learning_rate):
        """(Re)load weights, rebuild the optimizer and return a TrainOneStep callable.

        Args:
            model: the model whose weights are (re)loaded and trained.
            ckpt_dir: path to an .npz checkpoint, or '' to train from scratch.
            learning_rate: learning rate for the fresh Adam optimizer.
        """
        # Original used `ckpt_dir is not ''` -- identity comparison with a
        # string literal, which is unreliable and a SyntaxWarning; use !=.
        if ckpt_dir != '':
            tl.files.load_and_assign_npz(name=ckpt_dir, network=model)
            self.logger.info('Loaded checkpoint from ' + ckpt_dir)
        else:
            self.logger.info('Training from scratch')
        self.model = model
        optimizer = tl.optimizers.Adam(lr=learning_rate)
        self.config.lr = learning_rate
        net_with_loss = tl.models.WithLoss(self.model, self.loss_fn)
        train_weights = self.model.trainable_weights
        train_one_step = tl.models.TrainOneStep(net_with_loss, optimizer, train_weights)
        if self.best_bpp is not None:
            if self.best_bpp < 0.47:
                # Floor the best-bpp threshold so an early, unusually small
                # bpp cannot prevent all later checkpoints from being saved.
                self.best_bpp = 0.47
        return train_one_step

    def train(self, dataloader):
        """Run one training epoch, reporting and checkpointing every DISPLAY_STEP batches.

        Args:
            dataloader: iterable yielding (points, _) batches; points is
                assumed to already be a voxelized tensor -- TODO confirm.
        """
        self.logger.info('='*40 + '\n' + 'Training Epoch: ' + str(self.epoch))
        self.logger.info('alpha:' + str(round(self.config.alpha, 2)) + '\tbeta:' + str(round(self.config.beta, 2)))
        self.logger.info('LR:' + str(np.round(self.config.lr, 6).tolist()))
        self.logger.info('Training Files length:' + str(self.train_dataloader_len))
        # Window accumulators for the periodic report.
        train_bpp_ae_sum = 0.
        train_bpp_hyper_sum = 0.
        train_IoU_sum = 0.
        num = 0.
        self.model.set_train()

        for batch_step, (points, _) in enumerate(dataloader):
            x = points
            # Forward + backward + optimizer update; loss_fn fills
            # self.train_info and self.x_tilde as a side effect.
            train_loss = self.train_one_step(x, x)

            # Post-process: hard-classify the reconstruction for IoU.
            points_nums = tl.ops.cast(tl.ops.ReduceSum(axis=(1,2,3,4))(x), 'int32')
            output = select_voxels(self.x_tilde, points_nums, 1.0)
            train_bpp_ae_sum += self.train_info['train_bpp_ae']
            train_bpp_hyper_sum += self.train_info['train_bpp_hyper']
            _, _, IoU = get_classify_metrics(output, x)
            train_IoU_sum += IoU.numpy()
            num += 1

            # Periodic report + conditional checkpoint.
            if (batch_step + 1) % self.DISPLAY_STEP == 0:
                train_bpp_ae_sum /= num
                train_bpp_hyper_sum /= num
                train_IoU_sum /= num
                # Mid-epoch save: only stored if the windowed bpp improves.
                self.save_model(bpp=train_bpp_ae_sum + train_bpp_hyper_sum, global_step=self.epoch*self.train_dataloader_len + batch_step)

                print("Iteration:{0:}".format(batch_step))
                print("Bpps: {0:.4f} + {1:.4f}".format(train_bpp_ae_sum, train_bpp_hyper_sum))
                print("IoU: ", train_IoU_sum)
                self.logger.info('Running time:(mins): ' + str(round((time.time()-self.start)/60.)))

                self.record_set['bpp_ae'].append(train_bpp_ae_sum)
                self.record_set['bpp_hyper'].append(train_bpp_hyper_sum)
                self.record_set['bpp'].append(train_bpp_ae_sum + train_bpp_hyper_sum)
                self.record_set['IoU'].append(train_IoU_sum)
                self.record(main_tag='Train', global_step=self.epoch*self.train_dataloader_len + batch_step)

                self.logger.info('train_zeros:' + str(self.train_info['train_zeros']))
                self.logger.info('train_ones:' + str(self.train_info['train_ones']))
                self.logger.info('train_distortion:' + str(self.train_info['train_distortion']))
                self.logger.info('train_loss:' + str(self.train_info['train_loss']))

                # Start a fresh reporting window.
                num = 0.
                train_bpp_ae_sum = 0.
                train_bpp_hyper_sum = 0.
                train_IoU_sum = 0.
        self.save_model()  # unconditional end-of-epoch save
        self.epoch += 1
        return
|