|
- import argparse
- from io import BytesIO
- import multiprocessing
- from functools import partial
-
- from PIL import Image
- import lmdb
- from tqdm import tqdm
- from torchvision import datasets
- from torchvision.transforms import functional as trans_fn
-
-
def resize_and_convert(img, size, resample, quality=100):
    """Resize, center-crop and JPEG-encode a single image.

    Args:
        img: Image to process; it is handed to ``trans_fn.resize`` and the
            result must support ``save`` (a PIL image when called through
            ``resize_worker``).
        size: Target size, passed unchanged to both ``trans_fn.resize`` and
            ``trans_fn.center_crop``.
        resample: Resampling filter forwarded to ``trans_fn.resize``.
        quality: JPEG quality setting forwarded to ``save``.

    Returns:
        The encoded JPEG data as ``bytes``.
    """
    resized = trans_fn.resize(img, size, resample)
    cropped = trans_fn.center_crop(resized, size)

    # Encode in memory so the caller receives raw bytes, not a file on disk.
    jpeg_stream = BytesIO()
    cropped.save(jpeg_stream, format="jpeg", quality=quality)
    return jpeg_stream.getvalue()
-
-
def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    """Encode one image at every requested resolution.

    Args:
        img: Source image, forwarded to ``resize_and_convert`` for each size.
        sizes: Iterable of target sizes.
        resample: Resampling filter used for every resize.
        quality: JPEG quality used for every encoding.

    Returns:
        A list of JPEG ``bytes`` objects, one per entry of ``sizes`` and in
        the same order.
    """
    return [resize_and_convert(img, edge, resample, quality) for edge in sizes]
-
-
def resize_worker(img_file, sizes, resample):
    """Load one image file and encode it at every requested size.

    Used as the per-item function of the multiprocessing pool in ``prepare``.

    Args:
        img_file: ``(index, path)`` pair identifying the image to process.
        sizes: Target sizes forwarded to ``resize_multiple``.
        resample: Resampling filter forwarded to ``resize_multiple``.

    Returns:
        ``(index, encoded)`` where ``index`` is the value received in
        ``img_file`` and ``encoded`` is the list returned by
        ``resize_multiple``.
    """
    index, path = img_file

    # Close the file handle deterministically, even if decoding fails.
    # ``convert`` produces an image that no longer depends on the open file,
    # so it is safe to use after the ``with`` block ends.
    with Image.open(path) as source:
        img = source.convert("RGB")

    encoded = resize_multiple(img, sizes=sizes, resample=resample)
    return index, encoded
-
-
def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    """Resize every image of ``dataset`` and store the results in ``env``.

    Each image is stored once per size under the key
    ``"{size}-{index zero-padded to 5 digits}"``. After all images are
    written, the number of processed images is stored under ``"length"``.

    Args:
        env: Open LMDB environment that accepts write transactions.
        dataset: Object exposing ``imgs``, an iterable of ``(path, label)``
            pairs; only the paths are used.
        n_worker: Number of processes in the worker pool.
        sizes: Target sizes produced for every image.
        resample: Resampling filter forwarded to ``resize_worker``.
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    # Sort by path so the index assigned to each image is reproducible.
    paths = sorted(path for path, _label in dataset.imgs)
    jobs = list(enumerate(paths))
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        results = pool.imap_unordered(resize_fn, jobs)

        # The iterator has no length, so pass it explicitly to get a real
        # progress bar with percentage and ETA.
        for i, imgs in tqdm(results, total=len(jobs)):
            # One write transaction per image instead of one per record:
            # fewer commits, and all sizes of an image land atomically.
            with env.begin(write=True) as txn:
                for size, img in zip(sizes, imgs):
                    key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
                    txn.put(key, img)

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
-
-
if __name__ == "__main__":
    # Command-line entry point: parse the options, then convert the image
    # folder at ``path`` into an LMDB dataset at ``--out``.

    # Defined before the parser so --resample can be validated against it.
    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}

    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    # Required: without it lmdb.open() would be called with None and fail
    # with an unhelpful error only after the dataset has been scanned.
    parser.add_argument(
        "--out",
        type=str,
        required=True,
        help="filename of the result lmdb dataset",
    )
    parser.add_argument(
        "--size",
        type=str,
        default="128,256,512,1024",
        help="resolutions of images for the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing dataset",
    )
    # ``choices`` turns an unknown filter name into a normal usage error
    # instead of a bare KeyError traceback.
    parser.add_argument(
        "--resample",
        type=str,
        default="lanczos",
        choices=sorted(resample_map),
        help="resampling methods for resizing images",
    )
    parser.add_argument("path", type=str, help="path to the image dataset")

    args = parser.parse_args()

    resample = resample_map[args.resample]

    # "--size" is a comma-separated list such as "128,256".
    sizes = [int(s.strip()) for s in args.size.split(",")]

    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.path)

    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
|