- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- @Author: Yue Wang
- @Contact: yuewangx@mit.edu
- @File: main_cls.py
- @Time: 2018/10/13 10:39 PM
-
- Modified by
- @Author: An Tao
- @Contact: ta19@mails.tsinghua.edu.cn
- @Time: 2019/12/30 9:32 PM
-
- Modified by
- @Author: Dinghao Yang
- @Contact: dinghaoyang@gmail.com
- @Time: 2020/9/28 7:29 PM
-
- Modified by
- @Author: Yu Deng
- @Contact: dengy02@pcl.ac.cn
- @Time: 2022/7/8 14:20 PM
- """
-
-
- from __future__ import print_function
- import os
- import sys
- import argparse
- import numpy as np
- import sklearn.metrics as metrics
- import platform
- import random
- import time
- import datetime
- import tensorflow as tf
-
- from data import ModelNet40, ModelNet40_LLE
- from model import PointManifold_LLE, PointManifold_NNML
- from util import cal_loss, IOStream, Logger
-
- import logging
- import pynvml # pip install nvidia-ml-py
-
- ## single gpu
- # os.environ['CUDA_VISIBLE_DEVICES'] = "1"
- # multi GPUs
- pynvml.nvmlInit() # initialize NVML
- # print(pynvml.nvmlDeviceGetCount())
- nvml_count = pynvml.nvmlDeviceGetCount() # number of GPUs on this machine
- print("nvml_count", nvml_count)
- nvml_count_str = ",".join([str(i) for i in range(nvml_count)])
- os.environ['CUDA_VISIBLE_DEVICES'] = nvml_count_str
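- # Optional sketch (not used above): pynvml can also filter GPUs by free memory before exposing
- # them; `min_free_bytes` and `visible` are hypothetical names, not part of this script.
- #   handle = pynvml.nvmlDeviceGetHandleByIndex(i)
- #   mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
- #   if mem.free > min_free_bytes:
- #       visible.append(str(i))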
-
-
- def log(args, log_dir):
- # The logging module is built from four parts: logger, handler, filter, and formatter.
- # Get a logger object
- logger = logging.getLogger("Pointmanifold log")
- # Set the logging level
- logger.setLevel(logging.DEBUG)
- # Create a file handler (`num` is a global set in __main__ before log() is called)
- os.makedirs(f"checkpoints/{args.model}", exist_ok=True)
- if not args.eval:
- f_handler = logging.FileHandler(
- f"{log_dir}/Pointmanifold_train&val_{num}.log")
- else:
- f_handler = logging.FileHandler(
- f"{log_dir}/Pointmanifold_test_{num}.log")
- f_handler.setLevel(logging.INFO)
- # Create a console handler
- c_handler = logging.StreamHandler()
- c_handler.setLevel(logging.DEBUG)
- # Set the log output format
- fmt = logging.Formatter("%(asctime)s-%(name)s-%(levelname)s-%(message)s")
- # Bind the formatter to each handler
- f_handler.setFormatter(fmt)
- c_handler.setFormatter(fmt)
- # Attach the handlers to the logger
- logger.addHandler(f_handler)
- logger.addHandler(c_handler)
-
- return logger
-
-
- def set_seed(seed_value):
- random.seed(seed_value)
- np.random.seed(seed_value)
- tf.random.set_seed(seed_value)
- tf.experimental.numpy.random.seed(seed_value)
- # tf.keras.utils.set_random_seed(seed_value)
-
-
- def cal_loss_avg_train(labels, logits, smoothing=True):
- losses = cal_loss(labels=labels, logits=logits, smoothing=smoothing)
- print("losses", losses)
- # The 1. factor is not strictly necessary; it just keeps the loss on a familiar scale
- loss = 1. * losses / args.batch_size # 1. * loss / batch_size_train
- return loss # tf.nn.compute_average_loss(loss, global_batch_size=batch_size_train)
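- # A minimal sketch (not wired into this script) of the standard tf.distribute alternative to the
- # manual division above, assuming `labels`/`logits` as produced in train_step:
- #   per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
- #   loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=batch_size_train)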
-
-
- def train_step(model, data, labels, train_accuracy):
- # input dtypes: data: float32, labels: int64
- labels = tf.squeeze(labels) # labels.squeeze()
- labels = tf.cast(labels, tf.int32)
- # print("labels", labels)
- # print("type labels", type(labels))
- data = tf.transpose(data, [0, 2, 1]) # data.permute(0, 2, 1)
- batch_size = data.shape[0]
- with tf.GradientTape() as tape:
- logits = model(data)
- logits = tf.cast(logits, dtype='float32')
- loss = cal_loss_avg_train(labels=labels, logits=logits, smoothing=True)
- # Multiply the loss by the loss scale
- scaled_loss = opt.get_scaled_loss(loss)
-
- ## get_scaled_loss / get_unscaled_gradients must be used to keep the gradients from underflowing.
- ## LossScaleOptimizer.apply_gradients then applies the gradients only if none of them contain Inf or NaN values.
- ## It also updates the loss scale, halving it if the gradients contain Inf or NaN values and potentially increasing it otherwise.
- scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)
- # Take the scaled gradients and divide each one by the loss scale to recover the unscaled values
- gradients = opt.get_unscaled_gradients(scaled_gradients)
- opt.apply_gradients(zip(gradients, model.trainable_variables))
-
- # grads = tape.gradient(loss, model.trainable_variables) # model.trainable_weights
- # # clip_grads = [tf.clip_by_value(grad, -1.0, 1.0) for grad in grads] # clip gradients to speed up convergence
- # ## Run one step of gradient descent by updating the value of the variables to minimize the loss.
- # opt.apply_gradients(zip(grads, model.trainable_variables)) # model.trainable_weights
-
- labels_list = strategy.experimental_local_results(labels)
- print("labels_list", labels_list)
- preds = tf.argmax(logits, axis=1)
- print("preds", preds)
- print("preds", strategy.experimental_local_results(preds))
- train_accuracy.update_state(y_true=labels, y_pred=preds)
- print("train_accuracy in train step", train_accuracy)
-
- return loss
-
-
- @tf.function
- def distributed_strategy_train(model, data, labels, train_accuracy):
- # per_replica_loss, per_replica_preds, per_replica_labels
- losses = strategy.run(train_step, args=(model, data, labels, train_accuracy))
- distributed_losses = strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None) # "SUM"
- print("loss, train_accuracy: ", losses, train_accuracy)
-
- return distributed_losses
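- # Note (sketch, not used here): if each replica returned a plain per-replica mean loss instead of a
- # loss already scaled by batch size, averaging across replicas would use ReduceOp.MEAN, e.g.
- #   strategy.reduce(tf.distribute.ReduceOp.MEAN, losses, axis=None)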
-
-
- def train(args, model, opt, ckpt_dir):
- """
- train and val.
- """
- # default: resume training after an interruption
- if args.resume and os.listdir(ckpt_dir):
- # update checkpoint
- print("resume train")
- step = tf.Variable(0, name="step")
- checkpoint = tf.train.Checkpoint(step=step, net=model, opt=opt)
- latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
- ## restore returns a status object with optional assertions.
- ## All objects created in the new Checkpoint have been restored, so status.assert_existing_objects_matched passes.
- status = checkpoint.restore(latest_ckpt)
- ## The checkpoint contains many unmatched objects, including layer kernels and optimizer variables.
- # status.assert_consumed() only passes when the checkpoint and the program match exactly, and would raise here.
- ## Silence warnings about the incomplete checkpoint restore; otherwise, when the Checkpoint object is deleted (usually at program exit), warnings are printed for the unused parts of the checkpoint file.
- status.expect_partial()
- if latest_ckpt:
- logger.info('loading checkpoint from '+ latest_ckpt)
- else:
- step = tf.Variable(0, name="step")
- checkpoint = tf.train.Checkpoint(step=step, net=model, opt=opt)
- ## Retrieve the model's initial weights so training can start from scratch by re-loading them.
- initial_weights = model.get_weights()
- ## Load the initial weights so the model is retrained from scratch
- model.set_weights(initial_weights)
- # logger.info(str(initial_weights))
-
- train_accuracy = tf.keras.metrics.Accuracy() # labels and preds are both class indices, so plain Accuracy is the appropriate metric
- train_dist_dataset = strategy.experimental_distribute_dataset(train_loader.data)
- # train_dist_dataset = iter(train_dist_dataset)
- # print("train_dist_dataset:", len(next(train_dist_dataset)))
-
- best_test_acc = 0
- print("Staring train")
- for epoch in range(args.epochs):
- ####################
- # Training
- ####################
- # init parameters
- epoch_train_start_time = time.time()
- train_loss = 0.0
- count = 0.0
- # train_pred = []
- # train_true = []
- iters_train = 0.
- iters_train_time = 0.
-
- # train_dist_dataset = iter(train_dist_dataset)
- for iters, (data, labels) in enumerate(train_dist_dataset):
- iter_train_start_time = time.time()
-
- loss = distributed_strategy_train(model, data, labels, train_accuracy)
-
- print(f"Epoch: {epoch} \t iters: {iters} \t train loss: {loss} \t train accuracy: {train_accuracy.result().numpy()}")
-
- # preds = tf.argmax(logits, axis=1)
- count += args.batch_size
- train_loss += loss.numpy() * args.batch_size
- # train_true.append(labels.numpy())
- # train_pred.append(preds.numpy())
-
- iters_train += 1
- iters_train_time += time.time() - iter_train_start_time
-
- # break
- logger.info('Epoch {}: avg time over {} training iters is {:.3f}'.format(epoch, iters_train, iters_train_time/(iters_train+0.001)))
-
- # train_true = np.concatenate(train_true)
- # train_pred = np.concatenate(train_pred)
-
- outstr = 'Train epoch: %d, \
- Loss: %.6f, \
- Train acc: %.6f' % (epoch,
- train_loss*1.0/count,
- float(train_accuracy.result().numpy())
- )
-
- logger.info(outstr)
- epoch_train_time = time.time() - epoch_train_start_time
- logger.info('Epoch {} train use time is: {:.3f}'.format(epoch, epoch_train_time))
- print(outstr)
-
- # tensorboard
- with train_summary_writer.as_default():
- tf.summary.scalar('train loss', train_loss*1.0/count, step=epoch)
- tf.summary.scalar('train accuracy', train_accuracy.result().numpy(), step=epoch)
-
- train_accuracy.reset_states()
-
- # break
-
- ####################
- # testing
- ####################
- # init parameters
- test_loss = 0.0
- count = 0.0
- # model.eval()
- test_pred = []
- test_true = []
- iters_val = 0.
- iters_val_time = 0.
- epoch_val_start_time = time.time()
- for iters, (data, labels) in enumerate(test_loader.data):
- iter_val_start_time = time.time()
- labels = tf.squeeze(labels) # labels.squeeze()
- data = tf.transpose(data, [0, 2, 1]) # data.permute(0, 2, 1)
- batch_size = data.shape[0]
-
- logits = model(data)
- logits = tf.cast(logits, dtype='float32')
- loss = cal_loss(labels=labels, logits=logits, smoothing=True)
-
- print(f"Epoch: {epoch} \t iters: {iters} \t test loss: {loss}")
-
- preds = tf.argmax(logits, axis=1) # preds = logits.max(dim=1)[1]
- count += batch_size
- test_loss += loss.numpy() * batch_size
- test_true.append(labels.numpy())
- test_pred.append(preds.numpy())
-
- iters_val += 1
- iters_val_time += time.time() - iter_val_start_time
-
- logger.info('Epoch {}: avg time over {} val iters is {:.3f}'.format(epoch, iters_val, iters_val_time/(iters_val+0.001)))
-
- test_true = np.concatenate(test_true)
- test_pred = np.concatenate(test_pred)
-
- test_acc = metrics.accuracy_score(test_true, test_pred)
- avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
- outstr = 'Test epoch: %d, \
- loss: %.6f, \
- test acc: %.6f, \
- test avg acc: %.6f' % (epoch,
- test_loss*1.0/count,
- test_acc,
- avg_per_class_acc)
-
- epoch_val_time = time.time() - epoch_val_start_time
- logger.info('Epoch {} val use time is: {:.3f}'.format(epoch, epoch_val_time))
- logger.info(outstr)
- print(outstr)
-
- # tensorboard
- with test_summary_writer.as_default():
- tf.summary.scalar('test loss', test_loss*1.0/count, step=epoch)
- tf.summary.scalar('test acc', test_acc, step=epoch)
- tf.summary.scalar('test avg acc', avg_per_class_acc, step=epoch)
-
- # checkpoint.assert_consumed()
- # checkpoint.expect_partial()
- if test_acc >= best_test_acc:
- best_test_acc = test_acc
- checkpoint.save(f"{ckpt_dir}{args.model}_weights") # save_format="h5"
- model.save_weights(f"{ckpt_dir}{args.model}_weights.ckpt", save_format="h5") # default: save_format="h5"
-
-
- def test(model, ckpt_dir):
- # update checkpoint
- step = tf.Variable(0, name="step")
- checkpoint = tf.train.Checkpoint(step=step, net=model, opt=opt)
- # if os.listdir(ckpt_dir) != []:
- latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
- status = checkpoint.restore(latest_ckpt)
- status.expect_partial()
- if latest_ckpt:
- logger.info('loading checkpoint from '+ latest_ckpt)
- # model.load_weights(f"{ckpt_dir}{args.model}_weights.ckpt")
-
- test_acc = 0.0
- test_true = []
- test_pred = []
-
- iters_test = 0
- iters_test_time = 0.
- for data, labels in test_loader.data:
- iter_test_start_time = time.time()
- labels = tf.squeeze(labels) # labels.squeeze()
- data = tf.transpose(data, [0, 2, 1]) # data.permute(0, 2, 1)
- logits = model(data)
- logits = tf.cast(logits, dtype='float32')
- preds = tf.argmax(logits, axis=1) # preds = logits.max(dim=1)[1]
- test_true.append(labels.numpy())
- test_pred.append(preds.numpy())
-
- iters_test += 1
- iters_test_time += time.time() - iter_test_start_time
-
- logger.info('The avg time per test iter is: {:.3f}'.format(iters_test_time/(iters_test+0.001)))
-
- test_true = np.concatenate(test_true)
- test_pred = np.concatenate(test_pred)
- test_acc = metrics.accuracy_score(test_true, test_pred)
- avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
- outstr = 'Test:: test acc: %.6f, test avg acc: %.6f'%(test_acc, avg_per_class_acc)
- logger.info(outstr)
- print(outstr)
-
-
- if __name__ == "__main__":
- # Training settings
- parser = argparse.ArgumentParser(description='Point Cloud Recognition')
- parser.add_argument('--exp_name', type=str, default='pm', metavar='N',
- help='Name of the experiment')
- parser.add_argument('--model', type=str, default='pointmanifold_nnml', metavar='N',
- choices=['pointmanifold_lle', 'pointmanifold_nnml'],
- help='Model to use, [pointmanifold_lle, pointmanifold_nnml]')
- parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N',
- choices=['modelnet40'])
- parser.add_argument('--batch_size', type=int, default=16, metavar='batch_size',
- help='Size of batch') # 32
- parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
- help='Size of batch')
- parser.add_argument('--epochs', type=int, default=350, metavar='N',
- help='number of epochs to train') # 300, 250, 350
- parser.add_argument('--use_sgd', type=bool, default=True,
- help='Use SGD')
- parser.add_argument('--opt', type=str, default="sgd", metavar='N',
- choices=['sgd', 'adam', 'nadam'],
- help='optimizer to use [sgd, adam, nadam]')
- parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
- help='learning rate (default: 0.001, 0.1 if using sgd)')
- parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
- help='SGD momentum (default: 0.9)')
- parser.add_argument('--scheduler', type=str, default='cos', metavar='N',
- choices=['cos', 'cos_restarts', 'piecewise', 'inverse', 'exponential', 'natural_exponential', 'polynomial'],
- help='Scheduler to use, [cos, cos_restarts, piecewise, inverse, exponential, natural_exponential, polynomial]')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--eval', type=bool, default=False,
- help='evaluate the model')
- parser.add_argument('--num_points', type=int, default=1024,
- help='num of points to use')
- parser.add_argument('--dropout', type=float, default=0.5,
- help='initial dropout rate')
- parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
- help='Dimension of embeddings')
- parser.add_argument('--k', type=int, default=20, metavar='N',
- help='Num of nearest neighbors to use')
- parser.add_argument('--model_path', type=str, default='', metavar='N',
- help='Pretrained model path')
- parser.add_argument('--hyper_times', type=int, default=1, metavar='N',
- help='The time of model size, in paper it means t')
- parser.add_argument('--resume', type=bool, default=True,
- help='Continue training after interruption')
- args = parser.parse_args()
-
- # Mixed precision
- policy = tf.keras.mixed_precision.Policy('mixed_float16')
- tf.keras.mixed_precision.set_global_policy(policy)
- # Equivalent to the two lines above
- # tf.keras.mixed_precision.set_global_policy('mixed_float16')
- print('Compute dtype: %s' % policy.compute_dtype)
- print('Variable dtype: %s' % policy.variable_dtype)
-
- # multiple training runs: you can change train_nums, e.g. 1, 2, 3, ...
- num = f'{args.batch_size}_{args.opt}_{args.scheduler}_train_nums{1}'
- ckpt_dir = 'checkpoints/%s/%s/' % (args.model, num)
- log_dir = 'logs/%s/%s' % (args.model, num)
- os.makedirs(ckpt_dir, exist_ok=True)
- os.makedirs(log_dir, exist_ok=True)
-
- # create file writer for tensorboard
- current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
- train_log_dir = f'{log_dir}/TensorBoard/{current_time}/train'
- test_log_dir = f'{log_dir}/TensorBoard/{current_time}/test'
- train_summary_writer = tf.summary.create_file_writer(train_log_dir)
- test_summary_writer = tf.summary.create_file_writer(test_log_dir)
-
- # set seed
- set_seed(args.seed)
-
- # set log
- # io = IOStream('checkpoints/' + args.model + '/log.log')
- # io.cprint(str(args))
- logger = log(args, log_dir)
- logger.info(str(args))
- if not args.eval:
- sys.stdout = Logger(os.path.join(log_dir, 'log_train.txt'))
- else:
- sys.stdout = Logger(os.path.join(log_dir, 'log_test.txt'))
- print("==========\nArgs:{}\n==========".format(args))
-
- cuda = tf.test.is_built_with_cuda()
- if cuda:
- logger.info('GPU is available !')
- ## Let GPU memory grow on demand instead of pre-allocating all of it
- # tf.config.list_physical_devices("GPU")
- physical_devices = tf.config.list_physical_devices('GPU') # tf.config.experimental.list_physical_devices('GPU')
- assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
- try:
- for i in range(nvml_count):
- config = tf.config.experimental.set_memory_growth(physical_devices[i], True)
- except:
- # Invalid device or cannot modify virtual devices once initialized.
- pass
- else:
- logger.info('Using CPU')
-
- # Single-machine, multi-GPU MirroredStrategy
- strategy = tf.distribute.MirroredStrategy()
- ## To pin specific devices:
- # strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0'])
-
- batch_size_per_replica_train = args.batch_size
- # batch_size_per_replica_test = args.test_batch_size
-
- print("strategy.num_replicas_in_sync", strategy.num_replicas_in_sync)
-
- print("gpu number:{}".format(strategy.num_replicas_in_sync))
- batch_size_train = batch_size_per_replica_train * strategy.num_replicas_in_sync
- # batch_size_test = batch_size_per_replica_test* strategy.num_replicas_in_sync
- batch_size_test = args.test_batch_size
-
- # Use the strategy scope
- with strategy.scope():
- # load dataset ModelNet40 and model
- if args.model == 'pointmanifold_lle':
- # load dataset ModelNet40
- train_loader = ModelNet40_LLE(partition='train', batch_size=batch_size_train, num_points=args.num_points)
- test_loader = ModelNet40_LLE(partition='test', batch_size=batch_size_test, num_points=args.num_points)
- # load model
- model = PointManifold_LLE(args)
- # train model build
- if not args.eval:
- # print model structure
- model.build(input_shape=(batch_size_train, 5, args.num_points)) # e.g. [4, 5, 1024]
- else:
- # print model structure
- model.build(input_shape=(batch_size_test, 5, args.num_points)) # e.g. [4, 5, 1024]
- # model.summary()
- elif args.model == 'pointmanifold_nnml':
- # load dataset ModelNet40
- train_loader = ModelNet40(partition='train', batch_size=batch_size_train, num_points=args.num_points)
- test_loader = ModelNet40(partition='test', batch_size=batch_size_test, num_points=args.num_points)
- # load model
- model = PointManifold_NNML(args)
- # train model build
- if not args.eval:
- # print model structure
- model.build(input_shape=(batch_size_train, 3, args.num_points)) # e.g. [4, 3, 1024]
- else:
- model.build(input_shape=(batch_size_test, 3, args.num_points)) # e.g. [4, 3, 1024]
- # model.summary()
- else:
- raise RuntimeError('Wrong arch !')
-
- # initial_learning_rate
- initial_learning_rate = args.lr*100 if args.opt == 'sgd' else args.lr
- # learning rate schedule
- # 'cos', 'cos_restarts', 'piecewise', 'inverse', 'exponential', 'natural_exponential', 'polynomial'
- # Cosine annealing
- if args.scheduler == 'cos':
- decay_steps = args.epochs if args.epochs <= 200 else 200
- lr_schedule = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=initial_learning_rate,
- decay_steps=decay_steps,
- alpha=1e-5, # 1e-3
- name="cos")
- print("lr_schedule is cos : ", lr_schedule.get_config())
- # Cosine annealing with warm restarts
- if args.scheduler == 'cos_restarts':
- lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
- initial_learning_rate=initial_learning_rate,
- first_decay_steps=20,
- t_mul=20.0,
- m_mul=2.0,
- alpha=1e-5,
- name='cos_restarts')
- # Piecewise constant decay
- if args.scheduler == 'piecewise':
- boundaries=[30, 60, 100, 150] # segments 0-30, 30-60, 60-100, 100-150, 150-inf (or 0-20, 20-40, 40-60, 60-80, 80-inf)
- values=[0.01, 0.001, 0.0001, 0.00001, 0.000001] # learning rate used in each segment
- lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
- boundaries=boundaries, values=values, name='piecewise')
- print("lr_schedule is step: ", lr_schedule.get_config())
- # Inverse time decay
- if args.scheduler == 'inverse':
- lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(initial_learning_rate=initial_learning_rate,
- decay_steps=20,
- decay_rate=0.7,
- staircase=False,
- name="inverse")
- print("lr_schedule is step: ", lr_schedule.get_config())
- # Exponential decay
- if args.scheduler == 'exponential':
- lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
- initial_learning_rate=initial_learning_rate, decay_steps=20, decay_rate=0.96, name="exponential")
- # Natural exponential decay
- if args.scheduler == 'natural_exponential':
- class NaturalExpDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
- def __init__(self, initial_learning_rate=0.1, decay_steps=10., decay_rate=0.05, name=None):
- super().__init__()
- self.initial_learning_rate = tf.cast(initial_learning_rate, dtype=tf.float32)
- self.decay_steps = tf.cast(decay_steps, dtype=tf.float32)
- self.decay_rate = tf.cast(decay_rate, dtype=tf.float32)
- self.name = name
-
- def __call__(self, step):
- with tf.name_scope(self.name or "NaturalExpDecay") as name:
- initial_learning_rate = tf.convert_to_tensor(
- self.initial_learning_rate, name="initial_learning_rate")
- dtype = initial_learning_rate.dtype
- decay_steps = tf.cast(self.decay_steps, dtype)
- decay_rate = tf.cast(self.decay_rate, dtype)
- return initial_learning_rate * tf.math.exp(-decay_rate * (step / decay_steps), name=name)
-
- def get_config(self):
- return {
- "initial_learning_rate": self.initial_learning_rate,
- "decay_steps": self.decay_steps,
- "decay_rate": self.decay_rate,
- "name": self.name
- }
-
- lr_schedule = NaturalExpDecay(initial_learning_rate=initial_learning_rate,
- decay_steps=20,
- decay_rate=0.05,
- name='natural_exponential') # decay_rate=0.05
- # Polynomial decay
- if args.scheduler == "polynomial":
- lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
- initial_learning_rate=initial_learning_rate,
- decay_steps=20,
- end_learning_rate=0.0001,
- power=1.0,
- cycle=False,
- name="polynomial")
-
- # set optimizers
- if args.opt == 'sgd':
- opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=args.momentum, nesterov=False, name='SGD')
- print("SGD get config:", opt.get_config())
- elif args.opt == 'adam':
- opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False, name='Adam')
- print("Adam get config:", opt.get_config())
- elif args.opt == 'nadam':
- opt = tf.keras.optimizers.Nadam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, name='Nadam')
- print("Nadam get config:", opt.get_config())
- else:
- assert "Your optimizer is not default, please create it first!"
-
- # """
- # 在训练的开始阶段,LossScaleOptimizer 可能会跳过前几个步骤。
- # 先使用非常大的损失标度,以便快速确定最佳值。
- # 经过几个步骤后,损失标度将稳定下来,这时跳过的步骤将会很少。
- # 这一过程是自动执行的,不会影响训练质量。
- # """
- # 如果使用 mixed_float16,则明确使用损失放大
- # 会封装优化器并应用损失放大, 默认情况下,它会动态地确定损失放大
- opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
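- # Hedged usage note: the wrapped optimizer keeps the current (dynamic) loss scale, which can be
- # inspected for debugging, e.g.
- #   print("current loss scale:", opt.loss_scale)
- # and this wrapped `opt` is the one whose get_scaled_loss / get_unscaled_gradients are used in train_step.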
-
- # train or test
- if not args.eval:
- train(args, model, opt, ckpt_dir)
- else:
- test(model, ckpt_dir)
-
- # python main.py --model=pointmanifold_nnml --num_points=1024 --k=20 --emb_dims=1024 --hyper_times=4
- # python main.py --model=pointmanifold_lle --num_points=1024 --k=20 --emb_dims=1024 --hyper_times=1