import argparse
import glob
import importlib
import os
import random
import sys
import time

import h5py
import numpy as np
import torch

sys.path.append("/code/pytorch/3djcg/3djcg")

# NOTE: seeding without an argument uses system entropy; it does not make the
# random.sample calls (commented out below) deterministic.
random.seed()

from data_loader import PCDataset, make_data_loader
from trainer import Trainer
from pcc_model import PCCModel

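# A minimal reproducibility sketch (not part of the original training setup;
# the seed value 42 is an arbitrary illustration):
# random.seed(42)
# np.random.seed(42)
# torch.manual_seed(42)
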
def parse_args():  # revised
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--dataset", type=str, default='/dataset')
    parser.add_argument(
        "--dataset_num", type=int, default=int(2.8e5),
        help="Number of dataset files to use; try a smaller value first "
             "(the full set has about 300,001 files).")
    parser.add_argument(
        "--alpha", type=float, default=10, dest="alpha",  # 6
        help="Weight for distortion.")
    parser.add_argument(
        "--beta", type=float, default=3., dest="beta",
        help="Weight for empty position.")
    parser.add_argument(
        "--gamma", type=float, default=1.3, dest="gamma",
        help="Weight for hyper likelihoods.")
    parser.add_argument(
        "--delta", type=float, default=3., dest="delta",
        help="Weight for latent likelihoods.")
    parser.add_argument(
        "--lr", type=float, default=2e-4, dest="lr",
        help="Learning rate.")
    parser.add_argument("--epoch", type=int, default=30)
    # parser.add_argument(
    #     "--num_iteration", type=int, default=3e5, dest="num_iteration",
    #     help="Number of iterations.")
    parser.add_argument(
        "--prefix", type=str, default='hyper_mgpu4', dest="prefix",
        help="Prefix of checkpoints/logger.")
    parser.add_argument(
        "--init_ckpt", type=str, default='', dest="init_ckpt",  # e.g. /userhome/PCGCv1/pytorch/ckpts/hyper_/epoch_35.pth
        help="Initial checkpoint directory.")
    # parser.add_argument(
    #     "--reset_optimizer", type=int, default=0, dest="reset_optimizer",
    #     help="Reset optimizer (1) or not.")
    parser.add_argument(
        "--lower_bound", type=float, default=1e-9, dest="lower_bound",
        help="Lower bound of scale (1e-5 or 1e-9).")
    parser.add_argument(
        "--batch_size", type=int, default=16, dest="batch_size",  # 48 runs out of GPU memory
        help="Batch size.")

    args = parser.parse_args()

    return args

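# Example invocation (assuming this file is saved as train.py; flag values and
# paths are illustrative only):
#   python train.py --dataset /dataset --alpha 10 --lr 2e-4 --batch_size 16 \
#       --prefix hyper_mgpu4
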
class TrainingConfig():  # revised
    def __init__(self, logdir, ckptdir, init_ckpt, alpha, beta, gamma, delta, lr):
        self.logdir = logdir
        os.makedirs(self.logdir, exist_ok=True)
        self.ckptdir = ckptdir
        os.makedirs(self.ckptdir, exist_ok=True)
        self.init_ckpt = init_ckpt
        self.alpha = alpha
        self.beta = beta
        self.lr = lr
        self.gamma = gamma  # weight of hyper prior
        self.delta = delta  # weight of latent representation

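# For reference, these weights typically combine into a rate-distortion loss of
# the following shape (a sketch inferred from the argument help strings; the
# actual formula lives in Trainer, which is not shown here):
#   loss = alpha * distortion + beta * empty_position_loss
#          + gamma * bits(hyper prior) + delta * bits(latents)
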
# revised
if __name__ == '__main__':
    # log
    args = parse_args()
    # Define parameters.
    RATIO_EVAL = 9  # only used by the commented-out train/eval split below
    # NUM_ITERATION = int(args.num_iteration)
    # print('lower bound of scale:', lower_bound)
    # reset_optimizer = bool(args.reset_optimizer)
    # print('reset_optimizer:::', reset_optimizer)

    training_config = TrainingConfig(
        logdir=os.path.join('/model', args.prefix),
        ckptdir=os.path.join('/model', args.prefix),  # where checkpoints of this run are saved
        init_ckpt=args.init_ckpt,  # checkpoint used to initialize the model
        alpha=args.alpha,
        beta=args.beta,
        gamma=args.gamma,
        delta=args.delta,
        lr=args.lr)
    # model
    model = PCCModel(lower_bound=args.lower_bound)
    # trainer
    trainer = Trainer(config=training_config, model=model)

    # dataset
    # filedirs = sorted(glob.glob(args.dataset + '*.h5'))[:int(args.dataset_num)]
    train_src = np.loadtxt('/code/pytorch/scannetv2_train.txt', dtype=str)
    train_datalist = [f'/dataset/scannet/scans/{scene}/{scene}_vh_clean_2.ply'
                      for scene in train_src]
    val_src = np.loadtxt('/code/pytorch/scannetv2_val.txt', dtype=str)
    val_datalist = [f'/dataset/scannet/scans/{scene}/{scene}_vh_clean_2.ply'
                    for scene in val_src]
    # print("all files len(filedirs):", len(filedirs))  # enabling this line makes startup too slow
    # training
    print('=====Begin Training=====')
    for epoch in range(0, args.epoch):
        if epoch > 0:
            # Halve the learning rate each epoch, floored at 1e-5. Note this
            # halves relative to trainer.config.lr, so the decay only compounds
            # if update_lr() writes the new value back to the config.
            trainer.update_lr(lr=max(trainer.config.lr / 2, 1e-5))
        # train_list = random.sample(filedirs[len(filedirs) // RATIO_EVAL:],
        #                            1000 * args.batch_size)  # 5000: iterations per epoch; resampled each epoch, so occasional duplicates occur
        train_dataset = PCDataset(train_datalist)
        train_dataloader = make_data_loader(dataset=train_dataset, batch_size=args.batch_size,
                                            shuffle=True, num_workers=6, repeat=False)
        trainer.train(train_dataloader)
        # eval_list = random.sample(filedirs[:len(filedirs) // RATIO_EVAL], 10 * args.batch_size)
        test_dataset = PCDataset(val_datalist)
        test_dataloader = make_data_loader(dataset=test_dataset, batch_size=args.batch_size,
                                           shuffle=False, num_workers=3, repeat=False)
        trainer.test(test_dataloader, 'Test')
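        # Sketch of per-epoch checkpointing (an assumption: Trainer may already
        # checkpoint internally; the save path mirrors init_ckpt's naming):
        # torch.save({'model': model.state_dict()},
        #            os.path.join(training_config.ckptdir, f'epoch_{epoch}.pth'))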