|
- # MIT License
-
- # Copyright (c) 2019 Kim Seonghyeon
-
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
-
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
-
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
- import os
- from glob import glob
- import argparse
- from io import BytesIO
- import multiprocessing
- from functools import partial
-
- from PIL import Image
- import lmdb
- from tqdm import tqdm
- from torchvision.transforms import functional as trans_fn
-
-
def resize_and_convert(img, size, resample, quality=100):
    """Resize and center-crop ``img`` to ``size``, then JPEG-encode it.

    Args:
        img: Source image handed to torchvision's functional transforms.
        size: Target size passed to both ``trans_fn.resize`` and
            ``trans_fn.center_crop`` (with an int, torchvision matches the
            shorter edge on resize and takes a square crop).
        resample: Interpolation passed through to ``trans_fn.resize``.
        quality: JPEG quality passed to ``Image.save``.

    Returns:
        The encoded JPEG image as ``bytes``.
    """
    scaled = trans_fn.resize(img, size, resample)
    cropped = trans_fn.center_crop(scaled, size)

    # Encode in memory; the bytes are what ends up stored in the dataset.
    with BytesIO() as stream:
        cropped.save(stream, format="jpeg", quality=quality)
        return stream.getvalue()
-
-
def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
):
    """Encode ``img`` once per entry of ``sizes``.

    Args:
        img: Source image.
        sizes: Iterable of target sizes, each forwarded to
            ``resize_and_convert``.
        resample: Interpolation used for every resize.
        quality: JPEG quality used for every encoding.

    Returns:
        A list of JPEG byte strings, in the same order as ``sizes``.
    """
    return [resize_and_convert(img, target, resample, quality) for target in sizes]
-
-
def resize_worker(img_file, sizes, resample):
    """Load one image and encode it at every requested size.

    This is the per-item function ``prepare`` maps over its file list with a
    multiprocessing pool.

    Args:
        img_file: ``(index, path)`` pair identifying the image.
        sizes: Target sizes forwarded to ``resize_multiple``.
        resample: Interpolation forwarded to ``resize_multiple``.

    Returns:
        ``(index, encoded)`` where ``encoded`` is the list returned by
        ``resize_multiple`` (one JPEG byte string per size).

    Raises:
        ValueError: If the file cannot be opened or converted to RGB. The
            message names the offending path and the original error is
            chained as the cause.
    """
    index, path = img_file
    try:
        # The context manager closes the underlying file handle; convert()
        # returns a new, fully loaded image that does not depend on it.
        with Image.open(path) as src:
            img = src.convert("RGB")
    except Exception as err:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate. Put the path in the message because this error is
        # re-raised in the parent process by the pool.
        raise ValueError(f"failed to load image {path!r}: {err}") from err

    out = resize_multiple(img, sizes=sizes, resample=resample)
    return index, out
-
# Recognised image file extensions (without the dot), matched case-insensitively.
_IMAGE_EXTENSIONS = frozenset({"jpg", "png", "jpeg", "webp"})


def find_images(path):
    """Recursively collect image files under ``path``.

    Files are selected by extension (jpg, jpeg, png, webp) regardless of
    letter case, so ``a.JPG`` and ``b.Jpeg`` are both found. As with ``glob``,
    hidden files and hidden directories are not searched.

    Args:
        path: Root directory to search; sub-directories are included.

    Returns:
        A list of ``(index, filepath)`` tuples. The file paths are sorted
        lexicographically, each appears once, and indices run 0..n-1 in
        that order.
    """
    # Do a single recursive scan and filter on the lower-cased suffix. The
    # previous approach globbed once per extension and case variant, which
    # had two problems:
    #   - it missed mixed-case names such as ".Jpg";
    #   - where glob matching is case-insensitive (e.g. Windows), it returned
    #     the same file once per variant.
    candidates = glob(f"{path}/**/*", recursive=True)
    files = sorted(
        f
        for f in candidates
        if os.path.splitext(f)[1][1:].lower() in _IMAGE_EXTENSIONS
    )
    return list(enumerate(files))
-
def prepare(
    env, files, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    """Resize ``files`` in parallel and store the JPEG bytes in an LMDB env.

    Each encoded image is stored under the key
    ``f"{size}-{str(index).zfill(5)}"`` (UTF-8 encoded). The number of
    processed images is stored under ``b"length"``.

    Args:
        env: Open LMDB environment to write into.
        files: Iterable of ``(index, path)`` pairs, e.g. from ``find_images``.
            When it has a length, it is used for the progress bar total.
        n_worker: Number of worker processes in the pool.
        sizes: Target sizes to encode each image at.
        resample: Interpolation used when resizing.
    """
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
    n_files = len(files) if hasattr(files, "__len__") else None

    total = 0
    with multiprocessing.Pool(n_worker) as pool:
        results = pool.imap_unordered(resize_fn, files)
        for i, imgs in tqdm(results, total=n_files):
            # Use one write transaction per image rather than per size. All
            # sizes of an image are committed together, and the number of
            # commits drops by a factor of len(sizes).
            with env.begin(write=True) as txn:
                for size, img in zip(sizes, imgs):
                    key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
                    txn.put(key, img)

            total += 1

    with env.begin(write=True) as txn:
        txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
-
-
if __name__ == "__main__":
    # Command-line entry point: build an LMDB dataset of JPEG-encoded images
    # at one or more resolutions.
    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}

    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    parser.add_argument(
        "--out",
        type=str,
        # There is no usable default output path for lmdb.open().
        required=True,
        help="filename of the result lmdb dataset",
    )
    parser.add_argument(
        "--size",
        type=str,
        default="128,256,512,1024",
        help="resolutions of images for the dataset",
    )
    parser.add_argument(
        "--n_worker",
        type=int,
        default=8,
        help="number of workers for preparing dataset",
    )
    parser.add_argument(
        "--resample",
        type=str,
        default="lanczos",
        # Reject unknown methods at parse time instead of failing later with
        # a bare KeyError.
        choices=sorted(resample_map),
        help="resampling methods for resizing images",
    )
    parser.add_argument("path", type=str, help="path to the image dataset")

    args = parser.parse_args()

    resample = resample_map[args.resample]

    # Skip empty entries, such as the one produced by a trailing comma.
    sizes = [int(s.strip()) for s in args.size.split(",") if s.strip()]
    if not sizes:
        parser.error("--size must list at least one resolution")

    print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    if os.path.isdir(args.path):
        files = find_images(args.path)
    else:
        # Otherwise ``path`` is a text file listing one image path per line.
        # Blank lines are skipped so that indices stay contiguous and every
        # index refers to a real file.
        with open(args.path, "r") as f:
            paths = [line.strip() for line in f]
        files = list(enumerate(p for p in paths if p))
    print(f"Number of images: {len(files)}")
    with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, files, args.n_worker, sizes=sizes, resample=resample)
|