|
- import argparse
- import math
- import os
- import torch
- from torchvision import utils
-
- from basicsr.archs.stylegan2_arch import StyleGAN2Generator
- from basicsr.utils import set_random_seed
-
-
def generate(args, g_ema, device, mean_latent, randomize_noise):
    """Sample images from a StyleGAN2 generator and save them as PNG grids.

    For each of ``args.pics`` iterations, draws ``args.sample`` latent codes
    from a standard normal distribution, runs them through ``g_ema``, and saves
    the resulting batch as one grid image at ``samples/<i zero-padded to 6>.png``.
    The ``samples`` directory is not created here.

    Args:
        args: Parsed options; ``pics``, ``sample``, ``latent`` and
            ``truncation`` are read.
        g_ema: Generator, switched to eval mode here. Called as
            ``g_ema([z], truncation=..., randomize_noise=...,
            truncation_latent=...)`` and expected to return a pair whose
            first element is the image batch.
        device: Device on which the latent codes are allocated.
        mean_latent: Forwarded as ``truncation_latent``; may be ``None``
            (the script entry point passes ``None`` when ``--truncation >= 1``).
        randomize_noise: Forwarded to the generator unchanged.
    """
    import inspect  # local: only needed for the torchvision compatibility check

    # torchvision renamed make_grid's ``range`` keyword to ``value_range`` and
    # later dropped ``range`` entirely, so ``range=`` raises TypeError on current
    # releases. Use whichever keyword the installed version accepts.
    if 'value_range' in inspect.signature(utils.make_grid).parameters:
        grid_kwargs = {'value_range': (-1, 1)}
    else:
        grid_kwargs = {'range': (-1, 1)}
    # Roughly square grid: number of images per row.
    nrow = int(math.sqrt(args.sample))

    with torch.no_grad():
        g_ema.eval()
        for i in range(args.pics):
            sample_z = torch.randn(args.sample, args.latent, device=device)

            sample, _ = g_ema([sample_z],
                              truncation=args.truncation,
                              randomize_noise=randomize_noise,
                              truncation_latent=mean_latent)

            # Images are normalized from the (-1, 1) range before being written.
            utils.save_image(
                sample,
                f'samples/{i:06d}.png',
                nrow=nrow,
                normalize=True,
                **grid_kwargs,
            )
-
-
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no`` or ``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def main():
    """Parse command-line options, load a StyleGAN2 checkpoint and generate samples.

    Builds a ``StyleGAN2Generator`` (``latent=512``, ``n_mlp=8``), loads the
    ``params_ema`` entry of ``--ckpt`` into it, optionally computes a mean
    latent for truncation, and calls ``generate`` to write PNG grids under
    ``samples/``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--sample', type=int, default=1)
    parser.add_argument('--pics', type=int, default=1)
    parser.add_argument('--truncation', type=float, default=1)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument(
        '--ckpt',
        type=str,
        default='experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-3ab41b38.pth',  # noqa: E501
    )
    parser.add_argument('--channel_multiplier', type=int, default=2)
    # ``type=bool`` would turn any non-empty string (even "False") into True;
    # a bare ``--randomize_noise`` means True.
    parser.add_argument('--randomize_noise', type=str2bool, nargs='?', const=True, default=True)
    parser.add_argument('--seed', type=int, default=2020)

    args = parser.parse_args()

    # Fixed architecture hyper-parameters; generate() reads args.latent.
    args.latent = 512
    args.n_mlp = 8
    os.makedirs('samples', exist_ok=True)
    set_random_seed(args.seed)

    g_ema = StyleGAN2Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
    # map_location lets a checkpoint saved with CUDA tensors load on the CPU fallback.
    checkpoint = torch.load(args.ckpt, map_location=device)['params_ema']

    g_ema.load_state_dict(checkpoint)

    # A mean latent is only needed when truncation is actually applied.
    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent, args.randomize_noise)


if __name__ == '__main__':
    main()
|