import os
import sys
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', '-e', type=str, default='pugan2', help='experiment name')
parser.add_argument('--project_dir', type=str, default="/userhome/zyc/PUGAN-pytorch-master")
parser.add_argument('--model_save_dir', type=str, default="/userhome/zyc/PUGAN-pytorch-master/checkpoints")
parser.add_argument('--dataset_dir', type=str, default="/userhome/zyc/PUGAN-pytorch-master/data/train/PUGAN_poisson_256_poisson_1024.h5")
parser.add_argument('--batch_size', type=int, default=28)
parser.add_argument('--nepoch', type=int, default=100)
parser.add_argument('--model_save_interval', type=int, default=10)
parser.add_argument('--up_ratio', type=int, default=4)
parser.add_argument('--patch_num_point', type=int, default=256)
# Learning rates and loss weights must parse as floats (type=int would reject
# values like 1e-4 or 0.5 on the command line).
parser.add_argument('--lr_D', type=float, default=1e-4)
parser.add_argument('--lr_G', type=float, default=1e-3)
parser.add_argument('--emd_w', type=float, default=100.0)
parser.add_argument('--uniform_w', type=float, default=10.0)
parser.add_argument('--gan_w', type=float, default=0.5)
parser.add_argument('--repulsion_w', type=float, default=5.0)
# NOTE: argparse's type=bool treats any non-empty string (including 'False')
# as True; the flag is kept for compatibility but is not checked below.
parser.add_argument('--use_gan', type=bool, default=True)
parser.add_argument('--gpu', type=str, default='0')

args = parser.parse_args()

# CUDA_VISIBLE_DEVICES must be set before torch initializes CUDA, which is
# why torch is imported only after this point.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
sys.path.append('../')

import time
import datetime
import torch
import torch.nn as nn
from torch.utils import data
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR

from model import Generator, Discriminator
from data_loader import PUGAN_Dataset
from loss import Loss_fc as Loss

def xavier_init(m):
    """Xavier-initialize Conv/Linear weights; constant-initialize BatchNorm.

    Intended to be passed to torch.nn.Module.apply().
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.xavier_normal_(m.weight)  # xavier_normal (no underscore) is deprecated
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
    elif classname.find('BatchNorm') != -1:
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


def train(args):
    start_t = time.time()

    log_dir = os.path.join(args.model_save_dir, args.exp_name)
    os.makedirs(log_dir, exist_ok=True)

    train_dataset = PUGAN_Dataset(h5_file_path=args.dataset_dir)

    num_workers = 4
    train_data_loader = data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=num_workers, pin_memory=True, drop_last=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G_model = Generator()
    G_model.apply(xavier_init)
    G_model = torch.nn.DataParallel(G_model).to(device)
    D_model = Discriminator(in_channels=3)
    D_model.apply(xavier_init)
    D_model = torch.nn.DataParallel(D_model).to(device)

    G_model.train()
    D_model.train()

    optimizer_D = Adam(D_model.parameters(), lr=args.lr_D, betas=(0.9, 0.999))
    optimizer_G = Adam(G_model.parameters(), lr=args.lr_G, betas=(0.9, 0.999))

    # Decay both learning rates by a factor of 0.2 at epochs 60 and 80.
    D_scheduler = MultiStepLR(optimizer_D, [60, 80], gamma=0.2)
    G_scheduler = MultiStepLR(optimizer_G, [60, 80], gamma=0.2)

    Loss_fn = Loss()

    print("preparation time is %fs" % (time.time() - start_t))
    for e in range(args.nepoch):
        for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            # The loader yields (B, N, 3) point patches; the networks expect (B, 3, N).
            input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().to(device)
            gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().to(device)

            start_t_batch = time.time()
            output_point_cloud = G_model(input_data)

            # The geometric losses take points in (B, N, 3) layout.
            output_bn3 = output_point_cloud.permute(0, 2, 1)
            repulsion_loss = Loss_fn.get_repulsion_loss(output_bn3)
            uniform_loss = Loss_fn.get_uniform_loss(output_bn3)
            emd_loss = Loss_fn.get_emd_loss(output_bn3, gt_data.permute(0, 2, 1))

            # Discriminator update: score the (detached) fake patches and the
            # ground truth, accumulate both losses, then take a single step.
            fake_pred = D_model(output_point_cloud.detach())
            d_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred, label=False)
            d_loss_fake.backward()

            real_pred = D_model(gt_data)
            d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred, label=True)
            d_loss_real.backward()
            optimizer_D.step()

            d_loss = d_loss_real + d_loss_fake

            # Generator update: adversarial loss from the freshly updated D.
            fake_pred = D_model(output_point_cloud)
            g_loss = Loss_fn.get_generator_loss(fake_pred)

            total_G_loss = args.uniform_w * uniform_loss + args.emd_w * emd_loss + \
                           args.repulsion_w * repulsion_loss + args.gan_w * g_loss

            total_G_loss.backward()
            optimizer_G.step()

            current_lr_D = optimizer_D.param_groups[0]['lr']
            current_lr_G = optimizer_G.param_groups[0]['lr']

            msg = "{:0>8}, epoch: {}, [{}/{}], total_G_loss: {:.6f}, d_loss: {:.6f}, iter time: {:.3f}s, lr_D: {:.8f}, lr_G: {:.8f}".format(
                str(datetime.timedelta(seconds=round(time.time() - start_t))),
                e,
                batch_id + 1,
                len(train_data_loader),
                total_G_loss.item(),
                d_loss.item(),
                time.time() - start_t_batch,
                current_lr_D,
                current_lr_G,
            )
            print(msg)

        # Step the LR schedulers once per epoch, after the optimizer updates
        # (scheduler.step() is expected to follow optimizer.step() in
        # PyTorch >= 1.1).
        D_scheduler.step()
        G_scheduler.step()

        if (e + 1) % args.model_save_interval == 0 and e > 0:
            # log_dir was already created at the top of train().
            D_ckpt_model_filename = "D_iter_%d.pth" % e
            G_ckpt_model_filename = "G_iter_%d.pth" % e
            D_model_save_path = os.path.join(log_dir, D_ckpt_model_filename)
            G_model_save_path = os.path.join(log_dir, G_ckpt_model_filename)
            # Unwrap DataParallel (.module) so the checkpoints can be loaded
            # into a plain, single-GPU model.
            torch.save(D_model.module.state_dict(), D_model_save_path)
            torch.save(G_model.module.state_dict(), G_model_save_path)
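            # A sketch of restoring a snapshot later (path is illustrative):
            #   G = Generator()
            #   G.load_state_dict(torch.load(G_model_save_path))
            #   G.eval()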


if __name__ == "__main__":
    train(args)
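
# Example invocation, assuming this script is saved as train.py at the repo
# root (the filename is illustrative; path arguments default to the values
# declared above):
#   python train.py --exp_name pugan2 --gpu 0 --batch_size 28 --nepoch 100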