Browse Source

更新 'train.py'

main
mohenghui 1 month ago
parent
commit
e1b813d8b7
1 changed files with 17 additions and 6 deletions
  1. +17
    -6
      train.py

+ 17
- 6
train.py View File

@@ -30,6 +30,16 @@ from distributed import (
from op import conv2d_gradfix
from non_leaking import augment, AdaptiveAugment

def makedirR(c_path, is_dir=True):
if is_dir and not os.path.exists(c_path):
os.mkdir(c_path)
elif not is_dir and not os.path.exists(c_path): # 文件新建上一级目录
if platform.system().lower() == 'windows':
tmp = '\\'.join(c_path.split('\\')[:-1])
elif platform.system().lower() == 'linux':
tmp = '/'.join(c_path.split('/')[:-1])
if not os.path.exists(tmp):
os.mkdir(tmp)

def data_sampler(dataset, shuffle, distributed):
if distributed:
@@ -314,9 +324,10 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
with torch.no_grad():
g_ema.eval()
sample, _ = g_ema([sample_z])
utils.save_image(
sample,
f"sample/{str(i).zfill(6)}.png",
f"/model/sample/{str(i).zfill(6)}.png",
nrow=int(args.n_sample ** 0.5),
normalize=True,
range=(-1, 1),
@@ -333,19 +344,20 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic
"args": args,
"ada_aug_p": ada_aug_p,
},
f"checkpoint/{str(i).zfill(6)}.pt",
f"/model/checkpoint/{str(i).zfill(6)}.pt",
)


if __name__ == "__main__":
device = "cuda"
makedirR("/model/sample/")
makedirR("/model/checkpoint/")
parser = argparse.ArgumentParser(description="StyleGAN2 trainer")

parser.add_argument("--path", type=str,default="/dataset", help="path to the lmdb dataset")
parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)')
parser.add_argument(
"--iter", type=int, default=1200000, help="total training iterations"
"--iter", type=int, default=800000, help="total training iterations"
)
parser.add_argument(
"--batch", type=int, default=16, help="batch sizes for each gpus"
@@ -458,9 +470,8 @@ if __name__ == "__main__":

if args.arch == 'stylegan2':
# from model import Generator, Discriminator
# from conv3_model import Generator, Discriminator
from model_forward import Generator, Discriminator
# from unet_model import Generator, Discriminator
from model_forward import Generator,Discriminator
# from unet_model_3 import Generator, Discriminator
# from unet_model_2 import Generator,Discriminator
elif args.arch == 'swagan':


Loading…
Cancel
Save