|
|
@@ -30,6 +30,16 @@ from distributed import ( |
|
|
|
from op import conv2d_gradfix |
|
|
|
from non_leaking import augment, AdaptiveAugment |
|
|
|
|
|
|
|
def makedirR(c_path, is_dir=True): |
|
|
|
if is_dir and not os.path.exists(c_path): |
|
|
|
os.mkdir(c_path) |
|
|
|
elif not is_dir and not os.path.exists(c_path): # 文件新建上一级目录 |
|
|
|
if platform.system().lower() == 'windows': |
|
|
|
tmp = '\\'.join(c_path.split('\\')[:-1]) |
|
|
|
elif platform.system().lower() == 'linux': |
|
|
|
tmp = '/'.join(c_path.split('/')[:-1]) |
|
|
|
if not os.path.exists(tmp): |
|
|
|
os.mkdir(tmp) |
|
|
|
|
|
|
|
def data_sampler(dataset, shuffle, distributed): |
|
|
|
if distributed: |
|
|
@@ -314,9 +324,10 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic |
|
|
|
with torch.no_grad(): |
|
|
|
g_ema.eval() |
|
|
|
sample, _ = g_ema([sample_z]) |
|
|
|
|
|
|
|
utils.save_image( |
|
|
|
sample, |
|
|
|
f"sample/{str(i).zfill(6)}.png", |
|
|
|
f"/model/sample/{str(i).zfill(6)}.png", |
|
|
|
nrow=int(args.n_sample ** 0.5), |
|
|
|
normalize=True, |
|
|
|
range=(-1, 1), |
|
|
@@ -333,19 +344,20 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic |
|
|
|
"args": args, |
|
|
|
"ada_aug_p": ada_aug_p, |
|
|
|
}, |
|
|
|
f"checkpoint/{str(i).zfill(6)}.pt", |
|
|
|
f"/model/checkpoint/{str(i).zfill(6)}.pt", |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
device = "cuda" |
|
|
|
|
|
|
|
makedirR("/model/sample/") |
|
|
|
makedirR("/model/checkpoint/") |
|
|
|
parser = argparse.ArgumentParser(description="StyleGAN2 trainer") |
|
|
|
|
|
|
|
parser.add_argument("--path", type=str,default="/dataset", help="path to the lmdb dataset") |
|
|
|
parser.add_argument('--arch', type=str, default='stylegan2', help='model architectures (stylegan2 | swagan)') |
|
|
|
parser.add_argument( |
|
|
|
"--iter", type=int, default=1200000, help="total training iterations" |
|
|
|
"--iter", type=int, default=800000, help="total training iterations" |
|
|
|
) |
|
|
|
parser.add_argument( |
|
|
|
"--batch", type=int, default=16, help="batch sizes for each gpus" |
|
|
@@ -458,9 +470,8 @@ if __name__ == "__main__": |
|
|
|
|
|
|
|
if args.arch == 'stylegan2': |
|
|
|
# from model import Generator, Discriminator |
|
|
|
# from conv3_model import Generator, Discriminator |
|
|
|
from model_forward import Generator, Discriminator |
|
|
|
# from unet_model import Generator, Discriminator |
|
|
|
from model_forward import Generator,Discriminator |
|
|
|
# from unet_model_3 import Generator, Discriminator |
|
|
|
# from unet_model_2 import Generator,Discriminator |
|
|
|
elif args.arch == 'swagan': |
|
|
|