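"""Entry point for the GAN training stage.

Builds a ReplicaVolumeDataset (which wraps a TensoRF loader), constructs a
Trainer, and runs the adversarial training loop with periodic checkpointing.
Supports single-GPU and distributed (NCCL) execution; see main().
"""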
import os

import torch
import torch.distributed as dist
from tqdm import tqdm

from GAN.datasets.tensorf.replica_dataset import ReplicaVolumeDataset
from GAN.train.trainers.trainer import Trainer

from opt_GAN import config_parser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def test(cfg):
    """Evaluation entry point (not yet implemented).

    Planned behavior, per the original notes: for each epoch, load the
    volume, then for each item in the dataset render one image and
    compute its loss.
    """
    pass

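# A minimal sketch of the evaluation loop planned above, assuming the same
# ReplicaVolumeDataset/Trainer interfaces used in train(); trainer.val() and
# num_eval_epochs are hypothetical names, not part of the current codebase.
#
#     loader = ReplicaVolumeDataset(cfg)
#     input_dim = loader.tensorf_loader.train_dataset.num_valid_semantic_class
#     trainer = Trainer(cfg, input_dim, loader.tensorf_loader.tensorf)
#     for epoch in range(num_eval_epochs):  # for each epoch
#         loader.update(epoch)              # load the volume
#         trainer.val(epoch, loader)        # render one image per item, compute its loss
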
def train(cfg):
    begin_epoch = 0
    GAN_epochs = 200

    # Checkpoints and intermediate renderings are written under basedir/expname.
    save_dir = os.path.join(cfg.basedir, cfg.expname)
    os.makedirs(save_dir, exist_ok=True)

    # The volume dataset wraps a TensoRF loader; the number of valid semantic
    # classes in its training set fixes the generator's input dimension.
    train_loader = ReplicaVolumeDataset(cfg)
    input_dim = train_loader.tensorf_loader.train_dataset.num_valid_semantic_class
    trainer = Trainer(cfg, input_dim, train_loader.tensorf_loader.tensorf)

    save_img_path = os.path.join(save_dir, 'train_img')
    os.makedirs(save_img_path, exist_ok=True)

    for epoch in tqdm(range(begin_epoch, GAN_epochs)):
        if epoch > 0:
            train_loader.update(epoch)  # refresh the loaded volume for this epoch
        trainer.train(epoch, train_loader, save_path=save_img_path)
        if epoch % 100 == 0:
            trainer.save_GAN(save_dir, epoch)  # periodic checkpoint

    # Final checkpoint, flagged as the last one.
    trainer.save_GAN(save_dir, GAN_epochs, True)

def synchronize():
    """Barrier across all processes when using distributed training."""
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()

def main(cfg):
    if cfg.distributed:
        # Map the global rank to a GPU on this node; the env:// init method
        # expects RANK, WORLD_SIZE, and MASTER_ADDR/MASTER_PORT to be provided
        # by the launcher (e.g. torchrun).
        cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count()
        torch.cuda.set_device(cfg.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    if cfg.training:
        train(cfg)
    else:
        test(cfg)

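# Typical invocation (the script filename and flag names depend on opt_GAN's
# config_parser and are assumptions here):
#   single GPU:  python train_GAN.py --training
#   multi GPU:   torchrun --nproc_per_node=N train_GAN.py --training --distributed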
if __name__ == "__main__":
    torch.cuda.empty_cache()
    args = config_parser()
    print(args)
    main(args)