|
- import argparse
- import math
- import os
-
- import torch
- from torch import optim
- from torch.nn import functional as F
- from torchvision import transforms
- from PIL import Image
- from tqdm import tqdm
-
- import lpips
- from model import Generator
-
-
def noise_regularize(noises):
    """Return a penalty on neighbour-to-neighbour correlation in the noise maps.

    For every map, the map is multiplied element-wise by a copy of itself
    rolled by one position along dim 3 and along dim 2; the squared mean of
    each product is added to the penalty. The map is then averaged over 2x2
    blocks and the measurement repeated, until the tracked size is 8 or less.

    Args:
        noises: Iterable of 4-D tensors. The downsampling reshape takes the
            size from dim 2 only and hard-codes a channel count of 1, so it
            assumes square, single-channel maps.

    Returns:
        The accumulated penalty: a scalar tensor, or the int ``0`` when
        ``noises`` is empty.
    """
    total = 0

    for level in noises:
        edge = level.shape[2]

        while True:
            along_width = (level * torch.roll(level, shifts=1, dims=3)).mean().pow(2)
            along_height = (level * torch.roll(level, shifts=1, dims=2)).mean().pow(2)
            # Keep the (total + width) + height association of the accumulation.
            total = total + along_width + along_height

            if edge <= 8:
                break

            # Average each 2x2 block to halve the resolution.
            half = edge // 2
            level = level.reshape([-1, 1, half, 2, half, 2]).mean([3, 5])
            edge = half

    return total
-
-
def noise_normalize_(noises):
    """Shift and scale every noise tensor in place using its own mean and std.

    The update is written through ``.data``, so it is not recorded by
    autograd on the tensor itself.

    Args:
        noises: Iterable of tensors to normalize in place.

    Returns:
        None. The tensors in ``noises`` are modified.
    """
    for tensor in noises:
        # Both statistics are taken before any in-place change is applied.
        center = tensor.mean()
        spread = tensor.std()

        raw = tensor.data
        raw.add_(-center)
        raw.div_(spread)
-
-
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate at progress ``t`` of the ramped schedule.

    The rate is ``initial_lr`` scaled by a linear warm-up over the first
    ``rampup`` of the progress and a half-cosine decay over the last
    ``rampdown`` of it.

    Args:
        t: Progress of the optimization; the script passes ``i / step``.
        initial_lr: Peak learning rate.
        rampdown: Length of the final decay, in the same units as ``t``.
            A value <= 0 disables the decay.
        rampup: Length of the initial warm-up, in the same units as ``t``.
            A value <= 0 disables the warm-up.

    Returns:
        The learning rate for this step as a float.
    """
    # A non-positive ramp length means "no ramp"; dividing by it would raise
    # ZeroDivisionError (for 0) or produce a meaningless factor (for < 0).
    decay = min(1, (1 - t) / rampdown) if rampdown > 0 else 1
    decay = 0.5 - 0.5 * math.cos(decay * math.pi)
    warmup = min(1, t / rampup) if rampup > 0 else 1

    return initial_lr * (decay * warmup)
-
-
def latent_noise(latent, strength):
    """Return ``latent`` plus standard-normal noise scaled by ``strength``.

    Args:
        latent: Tensor to perturb; it is not modified.
        strength: Multiplier applied to the sampled noise.

    Returns:
        A new tensor with the same shape as ``latent``.
    """
    jitter = torch.randn_like(latent).mul(strength)
    return latent + jitter
-
-
def make_image(tensor):
    """Convert a batch of images in [-1, 1] to a uint8 NumPy array.

    Values are clamped to [-1, 1], mapped linearly to [0, 255], cast to
    uint8 (the cast truncates the fractional part) and permuted from
    dims (0, 1, 2, 3) to (0, 2, 3, 1).

    Args:
        tensor: 4-D image tensor. It is left unmodified.

    Returns:
        A uint8 ``numpy.ndarray`` on the CPU with dims ordered (0, 2, 3, 1).
    """
    return (
        tensor.detach()
        # Out-of-place clamp: detach() shares storage with ``tensor``, so an
        # in-place clamp_ here would silently overwrite the caller's data.
        .clamp(min=-1, max=1)
        # The clamp result is a fresh tensor, so in-place ops are safe now.
        .add_(1)
        .div_(2)
        .mul_(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
-
-
def _parse_args():
    """Define the projector's command-line options and parse them."""
    parser = argparse.ArgumentParser(
        description="Image projector to the generator latent spaces"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="path to the model checkpoint"
    )
    parser.add_argument(
        "--size", type=int, default=256, help="output image sizes of the generator"
    )
    parser.add_argument(
        "--lr_rampup",
        type=float,
        default=0.05,
        help="duration of the learning rate warmup",
    )
    parser.add_argument(
        "--lr_rampdown",
        type=float,
        default=0.25,
        help="duration of the learning rate decay",
    )
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
    parser.add_argument(
        "--noise", type=float, default=0.05, help="strength of the noise level"
    )
    parser.add_argument(
        "--noise_ramp",
        type=float,
        default=0.75,
        help="duration of the noise level decay",
    )
    parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
    parser.add_argument(
        "--noise_regularize",
        type=float,
        default=1e5,
        help="weight of the noise regularization",
    )
    parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
    parser.add_argument(
        "--w_plus",
        action="store_true",
        help="allow to use distinct latent codes to each layers",
    )
    parser.add_argument(
        "files", metavar="FILES", nargs="+", help="path to image files to be projected"
    )

    return parser.parse_args()


def main():
    """Project the given image files into the generator's latent space.

    Steps:
      1. Load and normalize the target images (resized/cropped to at most
         256 px).
      2. Load the ``g_ema`` generator weights from the checkpoint.
      3. Estimate the mean and spread of the latent distribution from
         10000 samples passed through ``g_ema.style``.
      4. Optimize one latent per image (one per layer with ``--w_plus``)
         together with the noise maps using Adam, against a perceptual
         loss plus weighted noise-regularization and MSE terms.
      5. Write ``<name>-project.png`` per input and a single ``.pt`` file,
         named after the first input, holding image, latent and noise
         for each input.
    """
    device = "cuda"

    args = _parse_args()

    n_mean_latent = 10000

    # Targets are never compared at more than 256 px; larger generator
    # outputs are averaged down to match inside the optimization loop.
    resize = min(args.size, 256)

    transform = transforms.Compose(
        [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    imgs = []

    for imgfile in args.files:
        # Context manager releases the file handle as soon as it is decoded.
        with Image.open(imgfile) as source:
            imgs.append(transform(source.convert("RGB")))

    imgs = torch.stack(imgs, 0).to(device)
    n_images = imgs.shape[0]

    g_ema = Generator(args.size, 512, 8)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
    g_ema.eval()
    g_ema = g_ema.to(device)

    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, 512, device=device)
        latent_out = g_ema.style(noise_sample)

        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    # One independently sampled set of noise maps per target image.
    noises = [
        noise.repeat(n_images, 1, 1, 1).normal_() for noise in g_ema.make_noise()
    ]

    # Every image starts from the mean latent.
    latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(n_images, 1)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    for noise in noises:
        noise.requires_grad = True

    optimizer = optim.Adam([latent_in] + noises, lr=args.lr)

    pbar = tqdm(range(args.step))

    for i in pbar:
        t = i / args.step
        # Honour the --lr_rampdown / --lr_rampup options (previously they
        # were parsed but the schedule always used its built-in defaults).
        lr = get_lr(t, args.lr, rampdown=args.lr_rampdown, rampup=args.lr_rampup)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

        batch, channel, height, width = img_gen.shape

        if height > 256:
            factor = height // 256

            # Average factor x factor blocks to bring the output to 256 px.
            img_gen = img_gen.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            img_gen = img_gen.mean([3, 5])

        p_loss = percept(img_gen, imgs).sum()
        n_loss = noise_regularize(noises)
        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )

    # Render from the final optimized latent, i.e. the same one that is saved
    # below. The previous code used a snapshot taken every 100 steps, which
    # was stale when --step is not a multiple of 100 and raised IndexError
    # when --step < 100.
    with torch.no_grad():
        img_gen, _ = g_ema([latent_in], input_is_latent=True, noise=noises)

    filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"

    img_ar = make_image(img_gen)

    result_file = {}
    for i, input_name in enumerate(args.files):
        # Keep the batch dimension on each per-image noise slice.
        noise_single = [noise[i : i + 1] for noise in noises]

        result_file[input_name] = {
            "img": img_gen[i],
            "latent": latent_in[i],
            "noise": noise_single,
        }

        img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
        pil_img = Image.fromarray(img_ar[i])
        pil_img.save(img_name)

    torch.save(result_file, filename)


if __name__ == "__main__":
    main()
|