
Add 'ppl_forward.py'

main
mohenghui 1 month ago
parent commit 7f3ba83a4d
1 changed file with 130 additions and 0 deletions
ppl_forward.py

@@ -0,0 +1,130 @@
import argparse

import torch
from torch.nn import functional as F
import numpy as np
from tqdm import tqdm

import lpips
from model_forward import Generator


def normalize(x):
    # project onto the unit hypersphere along the last dimension
    return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))


def slerp(a, b, t):
    # spherical interpolation between a and b at fraction t
    a = normalize(a)
    b = normalize(b)
    d = (a * b).sum(-1, keepdim=True)  # cosine of the angle between a and b
    p = t * torch.acos(d)  # interpolated angle
    c = normalize(b - d * a)  # unit vector orthogonal to a in the a-b plane
    d = a * torch.cos(p) + c * torch.sin(p)

    return normalize(d)


def lerp(a, b, t):
    # linear interpolation between a and b at fraction t
    return a + (b - a) * t


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")

    parser.add_argument(
        "--space", default="w", choices=["z", "w"], help="latent space in which PPL is calculated"
    )
    parser.add_argument(
        "--batch", type=int, default=128, help="batch size for the models"
    )
    parser.add_argument(
        "--n_sample",
        type=int,
        default=217038,
        help="number of samples for calculating PPL",
    )
    parser.add_argument(
        "--size", type=int, default=128, help="output image size of the generator"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
    )
    parser.add_argument(
        "--crop", action="store_true", help="apply center crop to the images"
    )
    parser.add_argument(
        "--sampling",
        default="end",
        choices=["end", "full"],
        help="set endpoint sampling method",
    )
    parser.add_argument(
        "--ckpt",
        default="./checkpoint/330000.pt",
        metavar="CHECKPOINT",
        help="path to the model checkpoint",
    )

    args = parser.parse_args()

    latent_dim = 512

    ckpt = torch.load(args.ckpt)

    g = Generator(args.size, latent_dim, 8).to(device)
    g.load_state_dict(ckpt["g_ema"])
    g.eval()

    # expects the PerceptualLoss API from richzhang/PerceptualSimilarity
    # (as vendored in stylegan2-pytorch), not the pip "lpips" package
    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    distances = []

    n_batch = args.n_sample // args.batch
    resid = args.n_sample - (n_batch * args.batch)
    # skip the trailing batch when n_sample divides evenly, so no empty batch is run
    batch_sizes = [args.batch] * n_batch + ([resid] if resid > 0 else [])

    with torch.no_grad():
        for batch in tqdm(batch_sizes):
            noise, forward_noise = g.make_noise()

            # samples 2*i and 2*i + 1 form the endpoints of one interpolation path
            inputs = torch.randn([batch * 2, latent_dim], device=device)
            if args.sampling == "full":
                lerp_t = torch.rand(batch, device=device)
            else:
                lerp_t = torch.zeros(batch, device=device)

            if args.space == "w":
                latent = g.get_latent(inputs)
                latent_t0, latent_t1 = latent[::2], latent[1::2]
                latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
            else:
                # the "z" choice was accepted but never handled, leaving latent_e
                # undefined; interpolate on the sphere as in the original PPL metric
                latent_t0, latent_t1 = inputs[::2], inputs[1::2]
                latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
                latent_e1 = slerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
                latent_e = torch.stack([latent_e0, latent_e1], 1).view(*inputs.shape)

            image, _ = g(
                [latent_e],
                input_is_latent=(args.space == "w"),
                noise=noise,
                forward_noise=forward_noise,
            )

            if args.crop:
                c = image.shape[2] // 8
                image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]

            factor = image.shape[2] // 256

            if factor > 1:
                image = F.interpolate(
                    image, size=(256, 256), mode="bilinear", align_corners=False
                )

            # LPIPS distance between the two perturbed endpoints of each path,
            # scaled by 1 / eps^2 to approximate the squared path derivative
            dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
                args.eps ** 2
            )
            distances.append(dist.to("cpu").numpy())

    distances = np.concatenate(distances, 0)

    # drop the extreme 1% tails before averaging, as in the StyleGAN PPL metric
    # (note: NumPy >= 1.22 renames the interpolation keyword to method)
    lo = np.percentile(distances, 1, interpolation="lower")
    hi = np.percentile(distances, 99, interpolation="higher")
    filtered_dist = np.extract(
        np.logical_and(lo <= distances, distances <= hi), distances
    )

    print("ppl:", filtered_dist.mean())
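
For reference, the quantity the script estimates is the perceptual path length from StyleGAN: endpoints are drawn in the chosen latent space, the interpolation parameter t is perturbed by eps, and the LPIPS distance between the two generated images is scaled by 1/eps^2 before averaging over the 1%-99% filtered samples:

\[
\mathrm{PPL} = \mathbb{E}\left[ \frac{1}{\epsilon^{2}}\, d_{\mathrm{LPIPS}}\big( G(\ell(t)),\; G(\ell(t + \epsilon)) \big) \right]
\]

where \(\ell\) interpolates between the endpoint latents (lerp in w space, slerp in z space), and t is 0 for --sampling end or uniform on [0, 1) for --sampling full.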
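A minimal invocation sketch, assuming model_forward.py and the vendored lpips module are importable and a checkpoint exists at the script's default path (all flags below are defined by the script's argparse setup):

python ppl_forward.py --ckpt ./checkpoint/330000.pt --size 128 --space w --sampling end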
