@@ -0,0 +1,130 @@
import argparse

import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm

import lpips
from model_forward import Generator
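

# L2-normalize the last axis so latent vectors lie on the unit hypersphere.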
def normalize(x):
    return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
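

# Spherical linear interpolation: move along the great circle between two
# unit-normalized latents; used for endpoints sampled in Z space.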
def slerp(a, b, t):
    a = normalize(a)
    b = normalize(b)
    d = (a * b).sum(-1, keepdim=True)  # cosine of the angle between a and b
    p = t * torch.acos(d)  # angle swept after a fraction t of the arc
    c = normalize(b - d * a)  # direction orthogonal to a in the a-b plane
    d = a * torch.cos(p) + c * torch.sin(p)

    return normalize(d)
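

# Linear interpolation; used for endpoints sampled in W space.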
def lerp(a, b, t):
    return a + (b - a) * t


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")

    parser.add_argument(
        "--space", default="w", choices=["z", "w"], help="space in which PPL is calculated"
    )
    parser.add_argument(
        "--batch", type=int, default=128, help="batch size for the models"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=217038,
        help="number of samples for calculating PPL",
    )
    parser.add_argument(
        "--size", type=int, default=128, help="output image size of the generator"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
    )
    parser.add_argument(
        "--crop", action="store_true", help="apply center crop to the images"
    )
    parser.add_argument(
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument(
        "--ckpt",
        default="./checkpoint/330000.pt",
        metavar="CHECKPOINT",
        help="path to the model checkpoint",
    )

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()
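
    # LPIPS perceptual distance; PerceptualLoss with model="net-lin" is the
    # interface of the original PerceptualSimilarity package (not the newer
    # lpips.LPIPS API), so this assumes that variant of the lpips module.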
    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    distances = []

    # Split n_sample into full batches plus an optional remainder; skip the
    # remainder when it is zero so no empty batch reaches the generator.
    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    batch_sizes = [args.batch] * n_batch + ([resid] if resid > 0 else [])
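
    # PPL core loop: for each latent pair, render images at interpolation
    # positions t and t + eps and measure their LPIPS distance scaled by
    # 1 / eps^2, a finite-difference estimate of the perceptual path length.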
    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise, forward_noise = g.make_noise()

            inputs = torch.randn([batch * 2, latent_dim], device=device)
            if args.sampling == "full":
                lerp_t = torch.rand(batch, device=device)
            else:
                lerp_t = torch.zeros(batch, device=device)

            if args.space == "w":
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
            else:
                # Z space: interpolate the raw latent pairs on the hypersphere
                # via slerp, as in the standard StyleGAN PPL setup.
                latent_t0, latent_t1 = inputs[::2], inputs[1::2]
                latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = slerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*inputs.shape)

            image, _ = g(
                [latent_e],
                input_is_latent=args.space == "w",
                noise=noise,
                forward_noise=forward_noise,
            )
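
            # Optional center crop, presumably to focus the metric on the face
            # region of face-model outputs.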
            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
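
            # Downsample outputs larger than 256px before measuring distance,
            # keeping the LPIPS input at a moderate resolution.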
            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode="bilinear", align_corners=False
                )

            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to("cpu").numpy())

    distances = np.concatenate(distances, 0)
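
    # Trim the top and bottom 1% of distances before averaging, following the
    # StyleGAN PPL protocol, so a few extreme pairs do not dominate the score.
    # (NumPy 1.22 renamed the interpolation= keyword to method=.)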
    lo = np.percentile(distances, 1, interpolation="lower")
    hi = np.percentile(distances, 99, interpolation="higher")
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print("ppl:", filtered_dist.mean())