|
- import argparse
- import pickle
- import os
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from torch.utils.data import DataLoader
- from torchvision import transforms
- from torchvision.models import inception_v3, Inception3
- import numpy as np
- from tqdm import tqdm
-
- from inception import InceptionV3
- from dataset import MultiResolutionDataset
-
-
- class Inception3Feature(Inception3):
- def forward(self, x):
- if x.shape[2] != 299 or x.shape[3] != 299:
- x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)
-
- x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
- x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
- x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
- x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
-
- x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
- x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
- x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
-
- x = self.Mixed_5b(x) # 35 x 35 x 192
- x = self.Mixed_5c(x) # 35 x 35 x 256
- x = self.Mixed_5d(x) # 35 x 35 x 288
-
- x = self.Mixed_6a(x) # 35 x 35 x 288
- x = self.Mixed_6b(x) # 17 x 17 x 768
- x = self.Mixed_6c(x) # 17 x 17 x 768
- x = self.Mixed_6d(x) # 17 x 17 x 768
- x = self.Mixed_6e(x) # 17 x 17 x 768
-
- x = self.Mixed_7a(x) # 17 x 17 x 768
- x = self.Mixed_7b(x) # 8 x 8 x 1280
- x = self.Mixed_7c(x) # 8 x 8 x 2048
-
- x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
-
- return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
-
-
- def load_patched_inception_v3():
- # inception = inception_v3(pretrained=True)
- # inception_feat = Inception3Feature()
- # inception_feat.load_state_dict(inception.state_dict())
- inception_feat = InceptionV3([3], normalize_input=False)
-
- return inception_feat
-
-
- @torch.no_grad()
- def extract_features(loader, inception, device):
- pbar = tqdm(loader)
-
- feature_list = []
-
- for img in pbar:
- img = img.to(device)
- feature = inception(img)[0].view(img.shape[0], -1)
- feature_list.append(feature.to("cpu"))
-
- features = torch.cat(feature_list, 0)
-
- return features
-
-
- if __name__ == "__main__":
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- parser = argparse.ArgumentParser(
- description="Calculate Inception v3 features for datasets"
- )
- parser.add_argument(
- "--size",
- type=int,
- default=256,
- help="image sizes used for embedding calculation",
- )
- parser.add_argument(
- "--batch", default=64, type=int, help="batch size for inception networks"
- )
- parser.add_argument(
- "--n_sample",
- type=int,
- default=50000,
- help="number of samples used for embedding calculation",
- )
- parser.add_argument(
- "--flip", action="store_true", help="apply random flipping to real images"
- )
- parser.add_argument("path", metavar="PATH", help="path to datset lmdb file")
-
- args = parser.parse_args()
-
- inception = load_patched_inception_v3()
- inception = nn.DataParallel(inception).eval().to(device)
-
- transform = transforms.Compose(
- [
- transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
- transforms.ToTensor(),
- transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
- ]
- )
-
- dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
- loader = DataLoader(dset, batch_size=args.batch, num_workers=4)
-
- features = extract_features(loader, inception, device).numpy()
-
- features = features[: args.n_sample]
-
- print(f"extracted {features.shape[0]} features")
-
- mean = np.mean(features, 0)
- cov = np.cov(features, rowvar=False)
-
- name = os.path.splitext(os.path.basename(args.path))[0]
-
- with open(f"inception_{name}.pkl", "wb") as f:
- pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)
|