diff --git a/ms_adapter/pytorch/utils/__init__.py b/ms_adapter/pytorch/utils/__init__.py index 49f9febe..316b9626 100644 --- a/ms_adapter/pytorch/utils/__init__.py +++ b/ms_adapter/pytorch/utils/__init__.py @@ -1 +1 @@ -# from ms_adapter.pytorch.utils import data +from ms_adapter.pytorch.utils import data diff --git a/ms_adapter/torchvision/__init__.py b/ms_adapter/torchvision/__init__.py index caa0d46a..bf29e4f2 100644 --- a/ms_adapter/torchvision/__init__.py +++ b/ms_adapter/torchvision/__init__.py @@ -2,12 +2,12 @@ import os import warnings # import torch -# from torchvision import datasets -# from torchvision import io -# from torchvision import models -# from torchvision import ops +from ms_adapter.torchvision import datasets +from ms_adapter.torchvision import io +# from ms_adapter.torchvision import models +# from ms_adapter.torchvision import ops from ms_adapter.torchvision import transforms -# from torchvision import utils +# from ms_adapter.torchvision import utils from .extension import _HAS_OPS diff --git a/ms_adapter/torchvision/_internally_replaced_utils.py b/ms_adapter/torchvision/_internally_replaced_utils.py index 58b93585..0c5cefd0 100644 --- a/ms_adapter/torchvision/_internally_replaced_utils.py +++ b/ms_adapter/torchvision/_internally_replaced_utils.py @@ -1,58 +1,68 @@ -# import importlib.machinery -# import os -# -# from torch.hub import _get_torch_home -# -# -# _HOME = os.path.join(_get_torch_home(), "datasets", "vision") -# _USE_SHARDED_DATASETS = False -# -# -# def _download_file_from_remote_location(fpath: str, url: str) -> None: -# pass -# -# -# def _is_remote_location_available() -> bool: -# return False -# -# +import importlib.machinery +import os + + + +ENV_TORCH_HOME = 'TORCH_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' +def _get_torch_home(): + torch_home = os.path.expanduser( + os.getenv(ENV_TORCH_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, + DEFAULT_CACHE_DIR), 'torch'))) + return torch_home + +_HOME = os.path.join(_get_torch_home(), "datasets", "vision") +_USE_SHARDED_DATASETS = False + + + +def _download_file_from_remote_location(fpath: str, url: str) -> None: + pass + + +def _is_remote_location_available() -> bool: + return False + + # try: # from torch.hub import load_state_dict_from_url # noqa: 401 # except ImportError: # from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: 401 -# -# -# def _get_extension_path(lib_name): -# -# lib_dir = os.path.dirname(__file__) -# if os.name == "nt": -# # Register the main torchvision library location on the default DLL path -# import ctypes -# import sys -# -# kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) -# with_load_library_flags = hasattr(kernel32, "AddDllDirectory") -# prev_error_mode = kernel32.SetErrorMode(0x0001) -# -# if with_load_library_flags: -# kernel32.AddDllDirectory.restype = ctypes.c_void_p -# -# if sys.version_info >= (3, 8): -# os.add_dll_directory(lib_dir) -# elif with_load_library_flags: -# res = kernel32.AddDllDirectory(lib_dir) -# if res is None: -# err = ctypes.WinError(ctypes.get_last_error()) -# err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 
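Reviewer note on the `_get_torch_home` replacement above: it re-implements `torch.hub`'s cache-directory lookup so the adapter no longer imports torch. A standalone sketch of the resolution order (names mirror the patch; the demo values are illustrative):

```python
import os

# Lookup order mirrored from the patch: $TORCH_HOME wins, then
# $XDG_CACHE_HOME/torch, finally ~/.cache/torch.
def torch_home() -> str:
    cache = os.getenv("XDG_CACHE_HOME", "~/.cache")
    return os.path.expanduser(os.getenv("TORCH_HOME", os.path.join(cache, "torch")))

# Vision datasets then land under <torch_home>/datasets/vision, matching _HOME.
print(os.path.join(torch_home(), "datasets", "vision"))
```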
-# raise err -# -# kernel32.SetErrorMode(prev_error_mode) -# -# loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) -# -# extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) -# ext_specs = extfinder.find_spec(lib_name) -# if ext_specs is None: -# raise ImportError -# -# return ext_specs.origin + + +def _get_extension_path(lib_name): + + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec(lib_name) + if ext_specs is None: + raise ImportError + + return ext_specs.origin diff --git a/ms_adapter/torchvision/datasets/_optical_flow.py b/ms_adapter/torchvision/datasets/_optical_flow.py index ce38c27c..4017d493 100644 --- a/ms_adapter/torchvision/datasets/_optical_flow.py +++ b/ms_adapter/torchvision/datasets/_optical_flow.py @@ -6,7 +6,8 @@ from glob import glob from pathlib import Path import numpy as np -import torch +# import torch +import ms_adapter.pytorch as torch from PIL import Image from ..io.image import _read_png_16 @@ -27,12 +28,12 @@ class FlowDataset(ABC, VisionDataset): # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow), # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. - _has_builtin_flow_mask = False def __init__(self, root, transforms=None): super().__init__(root=root) self.transforms = transforms + self._has_builtin_flow_mask = False self._flow_list = [] self._image_list = [] @@ -115,8 +116,6 @@ class Sintel(FlowDataset): details on the different passes. transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. - ``valid_flow_mask`` is expected for consistency with other datasets which - return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ def __init__(self, root, split="train", pass_name="clean", transforms=None): @@ -179,13 +178,12 @@ class KittiFlow(FlowDataset): ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. 
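A note on the `_has_builtin_flow_mask` change that runs through this file: the flag moves from a class attribute to an instance attribute assigned in `__init__`. A minimal sketch of the resulting pattern (toy classes, not the real constructors):

```python
class FlowDataset:
    def __init__(self):
        # per-instance default instead of the former class attribute
        self._has_builtin_flow_mask = False

class KittiFlow(FlowDataset):
    def __init__(self):
        super().__init__()
        # Kitti ships a built-in valid-flow mask, so the subclass flips the flag
        self._has_builtin_flow_mask = True

assert KittiFlow()._has_builtin_flow_mask
assert not FlowDataset()._has_builtin_flow_mask
```

One consequence: subclasses must set the flag only after calling `super().__init__()`, as KittiFlow and HD1K now do.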
""" - _has_builtin_flow_mask = True def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) - + self._has_builtin_flow_mask = True root = Path(root) / "KittiFlow" / (split + "ing") images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) @@ -242,8 +240,6 @@ class FlyingChairs(FlowDataset): split (string, optional): The dataset split, either "train" (default) or "val" transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. - ``valid_flow_mask`` is expected for consistency with other datasets which - return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ def __init__(self, root, split="train", transforms=None): @@ -313,8 +309,6 @@ class FlyingThings3D(FlowDataset): camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both". transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. - ``valid_flow_mask`` is expected for consistency with other datasets which - return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): @@ -400,10 +394,10 @@ class HD1K(FlowDataset): ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. """ - _has_builtin_flow_mask = True def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) + self._has_builtin_flow_mask = True verify_str_arg(split, "split", valid_values=("train", "test")) diff --git a/ms_adapter/torchvision/datasets/caltech.py b/ms_adapter/torchvision/datasets/caltech.py index e95043ce..fab85299 100644 --- a/ms_adapter/torchvision/datasets/caltech.py +++ b/ms_adapter/torchvision/datasets/caltech.py @@ -1,5 +1,6 @@ import os import os.path +import warnings from typing import Any, Callable, List, Optional, Union, Tuple from PIL import Image @@ -127,7 +128,9 @@ class Caltech101(VisionDataset): if self._check_integrity(): print("Files already downloaded and verified") return - + warnings.warn("The dataset is stored on google drive, if you can't download it from google drive, " + "please download it from the official website." 
+ "Caltech 101 ") download_and_extract_archive( "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp", self.root, @@ -229,9 +232,16 @@ class Caltech256(VisionDataset): print("Files already downloaded and verified") return + # download_and_extract_archive( + # "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK", + # self.root, + # filename="256_ObjectCategories.tar", + # md5="67b4f42ca05d46448c6bb8ecd2220f6d", + # ) download_and_extract_archive( - "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK", + "https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar?download=1", self.root, filename="256_ObjectCategories.tar", md5="67b4f42ca05d46448c6bb8ecd2220f6d", ) + diff --git a/ms_adapter/torchvision/datasets/celeba.py b/ms_adapter/torchvision/datasets/celeba.py index e9dd883b..cd6066b0 100644 --- a/ms_adapter/torchvision/datasets/celeba.py +++ b/ms_adapter/torchvision/datasets/celeba.py @@ -1,11 +1,12 @@ import csv import os +import warnings from collections import namedtuple from typing import Any, Callable, List, Optional, Union, Tuple import PIL -import torch - +import ms_adapter.pytorch as torch +import mindspore as ms from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive from .vision import VisionDataset @@ -40,22 +41,9 @@ class CelebA(VisionDataset): downloaded again. """ - base_folder = "celeba" # There currently does not appear to be a easy way to extract 7z in python (without introducing additional # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available # right now. - file_list = [ - # File ID MD5 Hash Filename - ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), - # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), - # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), - ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), - ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), - ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), - ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), - # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), - ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), - ] def __init__( self, @@ -67,6 +55,19 @@ class CelebA(VisionDataset): download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) + self.file_list = [ + # File ID MD5 Hash Filename + ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), + # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc","b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), + # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), + ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), + ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), + ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), + ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), + # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", 
"063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), + ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), + ] + self.base_folder = "celeba" self.split = split if isinstance(target_type, list): self.target_type = target_type @@ -94,17 +95,28 @@ class CelebA(VisionDataset): bbox = self._load_csv("list_bbox_celeba.txt", header=1) landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1) attr = self._load_csv("list_attr_celeba.txt", header=1) - + self.old_identity = identity + self.old_attr = attr mask = slice(None) if split_ is None else (splits.data == split_).squeeze() - if mask == slice(None): # if split == "all" self.filename = splits.index else: self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))] - self.identity = identity.data[mask] - self.bbox = bbox.data[mask] - self.landmarks_align = landmarks_align.data[mask] - self.attr = attr.data[mask] + if mask != slice(None): + self.identity = torch.tensor(ms.ops.masked_select(identity.data, ms.ops.expand_dims(mask, -1))).reshape(-1,identity.data.shape[1]) + self.bbox = torch.tensor(ms.ops.masked_select(bbox.data,ms.ops.expand_dims(mask, -1))).reshape(-1, bbox.data.shape[1]) + self.landmarks_align = torch.tensor(ms.ops.masked_select(landmarks_align.data, ms.ops.expand_dims(mask, -1))).reshape(-1,landmarks_align.data.shape[1]) + self.attr = torch.tensor(ms.ops.masked_select(attr.data, ms.ops.expand_dims(mask, -1))).reshape(-1,attr.data.shape[1]) + else: + self.identity = torch.tensor(identity.data) + self.bbox = torch.tensor(bbox.data) + self.landmarks_align = torch.tensor(landmarks_align.data) + self.attr = torch.tensor(attr.data) + # self.identity = identity.data[mask] + # self.bbox = bbox.data[mask] + # self.landmarks_align = landmarks_align.data[mask] + # self.attr = attr.data[mask] + # map from {-1, 1} to {0, 1} self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor") self.attr_names = attr.header @@ -145,6 +157,9 @@ class CelebA(VisionDataset): if self._check_integrity(): print("Files already downloaded and verified") return + warnings.warn("The dataset is stored on google drive, if you can't download it from google drive, " + "please download it from the official website." 
+ "Large-scale CelebFaces Attributes (CelebA) Dataset ") for (file_id, md5, filename) in self.file_list: download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) @@ -153,7 +168,12 @@ class CelebA(VisionDataset): def __getitem__(self, index: int) -> Tuple[Any, Any]: X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) - + # print("self.landmarks_align", self.landmarks_align.shape) + # print("self.old_identity", self.old_identity.data.shape) + # print("self.bbox", self.bbox.shape) + # print("mask", self.mask.shape) + # print("old_attr.data", self.old_attr.data.shape) + # print("self.attr", self.attr.shape) target: Any = [] for t in self.target_type: if t == "attr": diff --git a/ms_adapter/torchvision/datasets/cifar.py b/ms_adapter/torchvision/datasets/cifar.py index adfb7437..9aa86020 100644 --- a/ms_adapter/torchvision/datasets/cifar.py +++ b/ms_adapter/torchvision/datasets/cifar.py @@ -78,8 +78,11 @@ class CIFAR10(VisionDataset): # now load the picked numpy arrays for file_name, checksum in downloaded_list: file_path = os.path.join(self.root, self.base_folder, file_name) + print("file_path", file_path) with open(file_path, "rb") as f: + print("==============================5") entry = pickle.load(f, encoding="latin1") + print("==============================6") self.data.append(entry["data"]) if "labels" in entry: self.targets.extend(entry["labels"]) diff --git a/ms_adapter/torchvision/datasets/fakedata.py b/ms_adapter/torchvision/datasets/fakedata.py index 244af634..69101972 100644 --- a/ms_adapter/torchvision/datasets/fakedata.py +++ b/ms_adapter/torchvision/datasets/fakedata.py @@ -1,9 +1,9 @@ from typing import Any, Callable, Optional, Tuple +import numpy as np +import ms_adapter.pytorch as torch -import torch - -from .. 
import transforms -from .vision import VisionDataset +from ms_adapter.torchvision import transforms +from ms_adapter.torchvision.datasets.vision import VisionDataset class FakeData(VisionDataset): @@ -25,7 +25,7 @@ class FakeData(VisionDataset): def __init__( self, size: int = 1000, - image_size: Tuple[int, int, int] = (3, 224, 224), + image_size: Tuple[int, int, int] = (224, 224, 3), num_classes: int = 10, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -48,14 +48,15 @@ class FakeData(VisionDataset): # create random image that is consistent with the index id if index >= len(self): raise IndexError(f"{self.__class__.__name__} index out of range") - rng_state = torch.get_rng_state() - torch.manual_seed(index + self.random_offset) - img = torch.randn(*self.image_size) - target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0] - torch.set_rng_state(rng_state) - + rng = np.random.RandomState() + rng_state = rng.get_state() + rng.seed(index + self.random_offset) + img = np.asarray(rng.randn(*self.image_size), dtype=np.uint8) + target = rng.randint(0, self.num_classes, size=(1,), dtype=np.long)[0] + rng.set_state(rng_state) # convert to PIL Image img = transforms.ToPILImage()(img) + if self.transform is not None: img = self.transform(img) if self.target_transform is not None: diff --git a/ms_adapter/torchvision/datasets/fer2013.py b/ms_adapter/torchvision/datasets/fer2013.py index 60cbfd9b..b80e29f1 100644 --- a/ms_adapter/torchvision/datasets/fer2013.py +++ b/ms_adapter/torchvision/datasets/fer2013.py @@ -2,7 +2,7 @@ import csv import pathlib from typing import Any, Callable, Optional, Tuple -import torch +import ms_adapter.pytorch as torch from PIL import Image from .utils import verify_str_arg, check_integrity diff --git a/ms_adapter/torchvision/datasets/hmdb51.py b/ms_adapter/torchvision/datasets/hmdb51.py index f7341f4a..e2f011b7 100644 --- a/ms_adapter/torchvision/datasets/hmdb51.py +++ b/ms_adapter/torchvision/datasets/hmdb51.py @@ -2,7 +2,6 @@ import glob import os from typing import Optional, Callable, Tuple, Dict, Any, List -from torch import Tensor from .folder import find_classes, make_dataset from .video_utils import VideoClips @@ -140,7 +139,7 @@ class HMDB51(VisionDataset): def __len__(self) -> int: return self.video_clips.num_clips() - def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: + def __getitem__(self, idx: int): video, audio, _, video_idx = self.video_clips.get_clip(idx) sample_index = self.indices[video_idx] _, class_index = self.samples[sample_index] diff --git a/ms_adapter/torchvision/datasets/imagenet.py b/ms_adapter/torchvision/datasets/imagenet.py index a272bb86..326153ed 100644 --- a/ms_adapter/torchvision/datasets/imagenet.py +++ b/ms_adapter/torchvision/datasets/imagenet.py @@ -1,10 +1,10 @@ import os import shutil import tempfile +import pickle from contextlib import contextmanager from typing import Any, Dict, List, Iterator, Optional, Tuple -import torch from .folder import ImageFolder from .utils import check_integrity, extract_archive, verify_str_arg @@ -43,9 +43,10 @@ class ImageNet(ImageFolder): root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) - self.parse_archives() + self.parse_archives() #TODO save dataset into MindRecord wnid_to_classes = load_meta_file(self.root)[0] + # wnid_to_classes, self.val_wnid= parse_devkit_archive(self.root) super().__init__(self.split_folder, **kwargs) self.root = root @@ -78,7 +79,9 @@ def 
load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str file = os.path.join(root, file) if check_integrity(file): - return torch.load(file) + with open(file, "rb") as f: + data = pickle.load(f) + return data else: msg = ( "The meta file {} is not present in the root directory or is corrupted. " @@ -96,7 +99,7 @@ def _verify_archive(root: str, file: str, md5: str) -> None: raise RuntimeError(msg.format(file, root)) -def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: +def parse_devkit_archive(root: str, file: Optional[str] = None): """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a binary file. @@ -147,8 +150,10 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: val_idcs = parse_val_groundtruth_txt(devkit_root) val_wnids = [idx_to_wnid[idx] for idx in val_idcs] - torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) - + with open(os.path.join(root, META_FILE), "wb") as f: + pickle.dump((wnid_to_classes, val_wnids), f) + # torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) + # return wnid_to_classes, val_wnids def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None: """Parse the train images archive of the ImageNet2012 classification dataset and diff --git a/ms_adapter/torchvision/datasets/kinetics.py b/ms_adapter/torchvision/datasets/kinetics.py index 2ba5e508..6dde3804 100644 --- a/ms_adapter/torchvision/datasets/kinetics.py +++ b/ms_adapter/torchvision/datasets/kinetics.py @@ -8,8 +8,6 @@ from multiprocessing import Pool from os import path from typing import Any, Callable, Dict, Optional, Tuple -from torch import Tensor - from .folder import find_classes, make_dataset from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity from .video_utils import VideoClips @@ -238,7 +236,7 @@ class Kinetics(VisionDataset): def __len__(self) -> int: return self.video_clips.num_clips() - def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: + def __getitem__(self, idx: int): video, audio, info, video_idx = self.video_clips.get_clip(idx) label = self.samples[video_idx][1] diff --git a/ms_adapter/torchvision/datasets/mnist.py b/ms_adapter/torchvision/datasets/mnist.py index 9f9ec457..bad4a351 100644 --- a/ms_adapter/torchvision/datasets/mnist.py +++ b/ms_adapter/torchvision/datasets/mnist.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.error import URLError import numpy as np -import torch +import ms_adapter.pytorch as torch from PIL import Image from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity @@ -91,9 +91,9 @@ class MNIST(VisionDataset): super().__init__(root, transform=transform, target_transform=target_transform) self.train = train # training set or test set - if self._check_legacy_exist(): - self.data, self.targets = self._load_legacy_data() - return + # if self._check_legacy_exist(): + # self.data, self.targets = self._load_legacy_data() + # return if download: self.download() @@ -103,20 +103,20 @@ class MNIST(VisionDataset): self.data, self.targets = self._load_data() - def _check_legacy_exist(self): - processed_folder_exists = os.path.exists(self.processed_folder) - if not processed_folder_exists: - return False - - return all( - check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) - ) - - def 
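On the `torch.save`/`torch.load` swap in imagenet.py above: the META_FILE tuple is now written with the standard `pickle` module. A quick sketch of the round trip (temporary path and sample values are illustrative):

```python
import os
import pickle
import tempfile

# Store (wnid_to_classes, val_wnids) with plain pickle, as the patch does,
# so reading the meta file back no longer requires torch.
meta = ({"n01440764": ("tench",)}, ["n01440764"])
path = os.path.join(tempfile.mkdtemp(), "meta.bin")
with open(path, "wb") as f:
    pickle.dump(meta, f)
with open(path, "rb") as f:
    wnid_to_classes, val_wnids = pickle.load(f)
assert val_wnids == ["n01440764"]
```

Meta files previously produced by `torch.save` will likely need to be regenerated, since torch's serialization format is not plain pickle.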
_load_legacy_data(self): - # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data - # directly. - data_file = self.training_file if self.train else self.test_file - return torch.load(os.path.join(self.processed_folder, data_file)) + # def _check_legacy_exist(self): + # processed_folder_exists = os.path.exists(self.processed_folder) + # if not processed_folder_exists: + # return False + # + # return all( + # check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) + # ) + + # def _load_legacy_data(self): + # # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data + # # directly. + # data_file = self.training_file if self.train else self.test_file + # return torch.load(os.path.join(self.processed_folder, data_file)) def _load_data(self): image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" diff --git a/ms_adapter/torchvision/datasets/pcam.py b/ms_adapter/torchvision/datasets/pcam.py index 4f124674..4adf5a34 100644 --- a/ms_adapter/torchvision/datasets/pcam.py +++ b/ms_adapter/torchvision/datasets/pcam.py @@ -1,4 +1,5 @@ import pathlib +import warnings from typing import Any, Callable, Optional, Tuple from PIL import Image @@ -123,7 +124,9 @@ class PCAM(VisionDataset): def _download(self) -> None: if self._check_exists(): return - + warnings.warn("The dataset is stored on google drive, if you can't download it from google drive, " + "please download it from the official website." + "PCAM Dataset ") for file_name, file_id, md5 in self._FILES[self._split].values(): archive_name = file_name + ".gz" download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5) diff --git a/ms_adapter/torchvision/datasets/phototour.py b/ms_adapter/torchvision/datasets/phototour.py index edf1d2ee..511c9665 100644 --- a/ms_adapter/torchvision/datasets/phototour.py +++ b/ms_adapter/torchvision/datasets/phototour.py @@ -2,7 +2,7 @@ import os from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np -import torch +import ms_adapter.pytorch as torch from PIL import Image from .utils import download_url @@ -93,7 +93,7 @@ class PhotoTour(VisionDataset): self.name = name self.data_dir = os.path.join(self.root, name) self.data_down = os.path.join(self.root, f"{name}.zip") - self.data_file = os.path.join(self.root, f"{name}.pt") + # self.data_file = os.path.join(self.root, f"{name}.pt") self.train = train self.mean = self.means[name] @@ -102,11 +102,11 @@ class PhotoTour(VisionDataset): if download: self.download() - if not self._check_datafile_exists(): - self.cache() - + # if not self._check_datafile_exists(): + # self.cache() + self.data, self.labels, self.matches = self.cache() # load the serialized data - self.data, self.labels, self.matches = torch.load(self.data_file) + # self.data, self.labels, self.matches = torch.load(self.data_file) def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]: """ @@ -131,16 +131,16 @@ class PhotoTour(VisionDataset): def __len__(self) -> int: return len(self.data if self.train else self.matches) - def _check_datafile_exists(self) -> bool: - return os.path.exists(self.data_file) + # def _check_datafile_exists(self) -> bool: + # return os.path.exists(self.data_file) def _check_downloaded(self) -> bool: return os.path.exists(self.data_dir) def download(self) -> None: - if self._check_datafile_exists(): - print(f"# Found cached data 
{self.data_file}") - return + # if self._check_datafile_exists(): + # print(f"# Found cached data {self.data_file}") + # return if not self._check_downloaded(): # download files @@ -160,18 +160,21 @@ class PhotoTour(VisionDataset): os.unlink(fpath) - def cache(self) -> None: + def cache(self): # process and save as torch files - print(f"# Caching data {self.data_file}") - - dataset = ( - read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), - read_info_file(self.data_dir, self.info_file), - read_matches_files(self.data_dir, self.matches_files), - ) - - with open(self.data_file, "wb") as f: - torch.save(dataset, f) + # print(f"# Caching data {self.data_file}") + + # dataset = ( + # read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), + # read_info_file(self.data_dir, self.info_file), + # read_matches_files(self.data_dir, self.matches_files), + # ) + data = read_image_file(self.data_dir, self.image_ext, self.lens[self.name]) + info = read_info_file(self.data_dir, self.info_file) + matches = read_matches_files(self.data_dir, self.matches_files) + return data, info, matches + # with open(self.data_file, "wb") as f: + # torch.save(dataset, f) def extra_repr(self) -> str: split = "Train" if self.train is True else "Test" diff --git a/ms_adapter/torchvision/datasets/samplers/clip_sampler.py b/ms_adapter/torchvision/datasets/samplers/clip_sampler.py index cb31919a..9737a8a7 100644 --- a/ms_adapter/torchvision/datasets/samplers/clip_sampler.py +++ b/ms_adapter/torchvision/datasets/samplers/clip_sampler.py @@ -88,7 +88,7 @@ class DistributedSampler(Sampler): # subsample indices = indices[self.rank : total_group_size : self.num_replicas, :] - indices = torch.reshape(indices, (-1,)).tolist() + indices = torch.reshape(indices, (-1,)).numpy().tolist() assert len(indices) == self.num_samples if isinstance(self.dataset, Sampler): @@ -134,7 +134,7 @@ class UniformClipSampler(Sampler): sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64) s += length idxs.append(sampled) - return iter(cast(List[int], torch.cat(idxs).tolist())) + return iter(cast(List[int], torch.cat(idxs).numpy().tolist())) def __len__(self) -> int: return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0) diff --git a/ms_adapter/torchvision/datasets/video_utils.py b/ms_adapter/torchvision/datasets/video_utils.py index 3fdd50d1..057a5fa4 100644 --- a/ms_adapter/torchvision/datasets/video_utils.py +++ b/ms_adapter/torchvision/datasets/video_utils.py @@ -4,10 +4,10 @@ import warnings from fractions import Fraction from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast -import torch -from torchvision.io import ( - _probe_video_from_file, - _read_video_from_file, +import ms_adapter.pytorch as torch +from ms_adapter.torchvision.io import ( + # _probe_video_from_file, + # _read_video_from_file, read_video, read_video_timestamps, ) @@ -145,9 +145,9 @@ class VideoClips: # strategy: use a DataLoader to parallelize read_video_timestamps # so need to create a dummy dataset first - import torch.utils.data + import ms_adapter.pytorch.utils.data - dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader( + dl: ms_adapter.pytorch.utils.data.DataLoader = ms_adapter.pytorch.utils.data.DataLoader( _VideoTimestampsDataset(self.video_paths), # type: ignore[arg-type] batch_size=16, num_workers=self.num_workers, @@ -252,7 +252,7 @@ class VideoClips: self.clips.append(clips) self.resampling_idxs.append(idxs) clip_lengths = 
torch.as_tensor([len(v) for v in self.clips]) - self.cumulative_sizes = clip_lengths.cumsum(0).tolist() + self.cumulative_sizes = clip_lengths.cumsum(0).asnumpy().tolist() def __len__(self) -> int: return self.num_clips() @@ -286,7 +286,7 @@ class VideoClips: # advanced indexing step = int(step) return slice(None, None, step) - idxs = torch.arange(num_frames, dtype=torch.float32) * step + idxs = torch.arange(start=0, end = num_frames, dtype=torch.float32) * step idxs = idxs.floor().to(torch.int64) return idxs @@ -330,39 +330,39 @@ class VideoClips: start_pts = clip_pts[0].item() end_pts = clip_pts[-1].item() video, audio, info = read_video(video_path, start_pts, end_pts) - else: - _info = _probe_video_from_file(video_path) - video_fps = _info.video_fps - audio_fps = None - - video_start_pts = cast(int, clip_pts[0].item()) - video_end_pts = cast(int, clip_pts[-1].item()) - - audio_start_pts, audio_end_pts = 0, -1 - audio_timebase = Fraction(0, 1) - video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator) - if _info.has_audio: - audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator) - audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) - audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) - audio_fps = _info.audio_sample_rate - video, audio, _ = _read_video_from_file( - video_path, - video_width=self._video_width, - video_height=self._video_height, - video_min_dimension=self._video_min_dimension, - video_max_dimension=self._video_max_dimension, - video_pts_range=(video_start_pts, video_end_pts), - video_timebase=video_timebase, - audio_samples=self._audio_samples, - audio_channels=self._audio_channels, - audio_pts_range=(audio_start_pts, audio_end_pts), - audio_timebase=audio_timebase, - ) - - info = {"video_fps": video_fps} - if audio_fps is not None: - info["audio_fps"] = audio_fps + # else: #TODO just support pyav + # _info = _probe_video_from_file(video_path) + # video_fps = _info.video_fps + # audio_fps = None + # + # video_start_pts = cast(int, clip_pts[0].item()) + # video_end_pts = cast(int, clip_pts[-1].item()) + # + # audio_start_pts, audio_end_pts = 0, -1 + # audio_timebase = Fraction(0, 1) + # video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator) + # if _info.has_audio: + # audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator) + # audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) + # audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) + # audio_fps = _info.audio_sample_rate + # video, audio, _ = _read_video_from_file( + # video_path, + # video_width=self._video_width, + # video_height=self._video_height, + # video_min_dimension=self._video_min_dimension, + # video_max_dimension=self._video_max_dimension, + # video_pts_range=(video_start_pts, video_end_pts), + # video_timebase=video_timebase, + # audio_samples=self._audio_samples, + # audio_channels=self._audio_channels, + # audio_pts_range=(audio_start_pts, audio_end_pts), + # audio_timebase=audio_timebase, + # ) + # + # info = {"video_fps": video_fps} + # if audio_fps is not None: + # info["audio_fps"] = audio_fps if self.frame_rate is not None: resampling_idx = self.resampling_idxs[video_idx][clip_idx] diff --git a/ms_adapter/torchvision/extension.py b/ms_adapter/torchvision/extension.py index 1db23ca4..55168341 100644 --- 
a/ms_adapter/torchvision/extension.py +++ b/ms_adapter/torchvision/extension.py @@ -2,10 +2,7 @@ import ctypes import os import sys from warnings import warn -# -# import torch -# -# from ._internally_replaced_utils import _get_extension_path +from ._internally_replaced_utils import _get_extension_path # # _HAS_OPS = False diff --git a/ms_adapter/torchvision/io/__init__.py b/ms_adapter/torchvision/io/__init__.py index 22788cef..58339283 100644 --- a/ms_adapter/torchvision/io/__init__.py +++ b/ms_adapter/torchvision/io/__init__.py @@ -1,9 +1,4 @@ from typing import Any, Dict, Iterator - -import torch - -from ..utils import _log_api_usage_once - try: from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER except ModuleNotFoundError: @@ -37,7 +32,7 @@ from .video import ( read_video_timestamps, write_video, ) -from .video_reader import VideoReader +# from .video_reader import VideoReader __all__ = [ @@ -52,8 +47,8 @@ __all__ = [ "_probe_video_from_memory", "_HAS_VIDEO_OPT", "_HAS_GPU_VIDEO_DECODER", - "_read_video_clip_from_memory", - "_read_video_meta_data", + # "_read_video_clip_from_memory", # TODO depend on torch.ops + # "_read_video_meta_data",# TODO depend on torch.ops "VideoMetaData", "Timebase", "ImageReadMode", @@ -67,6 +62,6 @@ __all__ = [ "write_file", "write_jpeg", "write_png", - "Video", - "VideoReader", + # "Video", + # "VideoReader", # TODO depend on torch.ops ] diff --git a/ms_adapter/torchvision/io/_load_gpu_decoder.py b/ms_adapter/torchvision/io/_load_gpu_decoder.py index f7869f0a..1225199a 100644 --- a/ms_adapter/torchvision/io/_load_gpu_decoder.py +++ b/ms_adapter/torchvision/io/_load_gpu_decoder.py @@ -1,8 +1,9 @@ -from ..extension import _load_library +# from ..extension import _load_library -try: - _load_library("Decoder") - _HAS_GPU_VIDEO_DECODER = True -except (ImportError, OSError): - _HAS_GPU_VIDEO_DECODER = False +# try: +# _load_library("Decoder") +# _HAS_GPU_VIDEO_DECODER = True +# except (ImportError, OSError): +# _HAS_GPU_VIDEO_DECODER = False +_HAS_GPU_VIDEO_DECODER = False \ No newline at end of file diff --git a/ms_adapter/torchvision/io/_video_opt.py b/ms_adapter/torchvision/io/_video_opt.py index 055b195a..1ce1dd70 100644 --- a/ms_adapter/torchvision/io/_video_opt.py +++ b/ms_adapter/torchvision/io/_video_opt.py @@ -3,17 +3,17 @@ import warnings from fractions import Fraction from typing import List, Tuple, Dict, Optional, Union -import torch -from ..extension import _load_library +import ms_adapter.pytorch as torch +# from ..extension import _load_library -try: - _load_library("video_reader") - _HAS_VIDEO_OPT = True -except (ImportError, OSError): - _HAS_VIDEO_OPT = False - +# try: +# _load_library("video_reader") +# _HAS_VIDEO_OPT = True +# except (ImportError, OSError): +# _HAS_VIDEO_OPT = False +_HAS_VIDEO_OPT = False default_timebase = Fraction(0, 1) diff --git a/ms_adapter/torchvision/io/image.py b/ms_adapter/torchvision/io/image.py index 17482375..e1b489b4 100644 --- a/ms_adapter/torchvision/io/image.py +++ b/ms_adapter/torchvision/io/image.py @@ -1,16 +1,13 @@ from enum import Enum from warnings import warn -import torch +import ms_adapter.pytorch as torch +from mindspore.dataset import vision -from ..extension import _load_library -from ..utils import _log_api_usage_once - - -try: - _load_library("image") -except (ImportError, OSError) as e: - warn(f"Failed to load image Python extension: {e}") +# try: +# _load_library("image") +# except (ImportError, OSError) as e: +# warn(f"Failed to load image Python extension: {e}") class ImageReadMode(Enum): @@ 
-42,9 +39,9 @@ def read_file(path: str) -> torch.Tensor: Returns: data (Tensor) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_file) - data = torch.ops.image.read_file(path) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(read_file) + data = vision.read_file(path) return data @@ -57,9 +54,9 @@ def write_file(filename: str, data: torch.Tensor) -> None: filename (str): the path to the file to be written data (Tensor): the contents to be written to the output file """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(write_file) - torch.ops.image.write_file(filename, data) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(write_file) + vision.write_file(filename, data) def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: @@ -79,9 +76,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE Returns: output (Tensor[image_channels, image_height, image_width]) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(decode_png) - output = torch.ops.image.decode_png(input, mode.value, False) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(decode_png) + # output = vision.decode_png(input, mode.value, False) + output = vision.Decode()(input) + output = torch.tensor(output, dtype=torch.uint8) + output = output.permute(2, 0, 1) return output @@ -100,9 +100,9 @@ def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor: Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the PNG file. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(encode_png) - output = torch.ops.image.encode_png(input, compression_level) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(encode_png) + output = vision.encode_png(input, compression_level) return output @@ -118,8 +118,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): compression_level (int): Compression factor for the resulting file, it must be a number between 0 and 9. 
Default: 6 """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(write_png) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(write_png) output = encode_png(input, compression_level) write_file(filename, output) @@ -154,13 +154,16 @@ def decode_jpeg( Returns: output (Tensor[image_channels, image_height, image_width]) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(decode_jpeg) - device = torch.device(device) - if device.type == "cuda": - output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) - else: - output = torch.ops.image.decode_jpeg(input, mode.value) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(decode_jpeg) + # device = torch.device(device) + # if device.type == "cuda": + # output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) + # else: + # output = torch.ops.image.decode_jpeg(input, mode.value) + output = vision.Decode()(input) + output = torch.tensor(output, dtype=torch.uint8) + output = output.permute(2, 0, 1) return output @@ -179,12 +182,12 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the JPEG file. """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(encode_jpeg) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(encode_jpeg) if quality < 1 or quality > 100: raise ValueError("Image quality should be a positive number between 1 and 100") - output = torch.ops.image.encode_jpeg(input, quality) + output = vision.encode_jpeg(input, quality) return output @@ -199,8 +202,8 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): quality (int): Quality of the resulting JPEG file, it must be a number between 1 and 100. 
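The rewritten `decode_png`, `decode_jpeg`, and `decode_image` all converge on the same fallback: `mindspore.dataset.vision.Decode()` followed by an HWC-to-CHW permute. A standalone sketch of that path (the file name is illustrative):

```python
import numpy as np
import mindspore.dataset.vision as vision
import ms_adapter.pytorch as torch

# Read raw encoded bytes, decode to an HWC uint8 array, then move the
# channel axis to the front to match torchvision's CHW convention.
raw = np.fromfile("sample.jpg", dtype=np.uint8)  # illustrative path
hwc = vision.Decode()(raw)
chw = torch.tensor(hwc, dtype=torch.uint8).permute(2, 0, 1)
print(chw.shape)  # (C, H, W)
```

Two behavioural differences worth flagging for review: `Decode()` ignores the requested `ImageReadMode`, and `_read_png_16` now returns uint8 data rather than 16-bit samples.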
Default: 75 """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(write_jpeg) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(write_jpeg) output = encode_jpeg(input, quality) write_file(filename, output) @@ -224,9 +227,12 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN Returns: output (Tensor[image_channels, image_height, image_width]) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(decode_image) - output = torch.ops.image.decode_image(input, mode.value) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(decode_image) + # output = vision.decode_image(input, mode.value) + output = vision.Decode()(input) + output = torch.tensor(output, dtype=torch.uint8) + output = output.permute(2, 0, 1) return output @@ -246,12 +252,16 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc Returns: output (Tensor[image_channels, image_height, image_width]) """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(read_image) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(read_image) data = read_file(path) return decode_image(data, mode) def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: data = read_file(path) - return torch.ops.image.decode_png(data, mode.value, True) + # return vision.decode_png(data, mode.value, True) + output = vision.Decode()(data) + output = torch.tensor(output, dtype=torch.uint8) + output = output.permute(2, 0, 1) + return output \ No newline at end of file diff --git a/ms_adapter/torchvision/io/video.py b/ms_adapter/torchvision/io/video.py index ceb20fe5..34f8bdd0 100644 --- a/ms_adapter/torchvision/io/video.py +++ b/ms_adapter/torchvision/io/video.py @@ -7,9 +7,8 @@ from fractions import Fraction from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -import torch +import ms_adapter.pytorch as torch -from ..utils import _log_api_usage_once from . import _video_opt @@ -78,8 +77,8 @@ def write_video( audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc. audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(write_video) + # if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + # _log_api_usage_once(write_video) _check_av_available() video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() @@ -260,14 +259,14 @@ def read_video( aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points info (Dict): metadata for the video and audio. 
Can contain the fields video_fps (float) and audio_fps (int)
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(read_video)
+    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+    #     _log_api_usage_once(read_video)
 
     output_format = output_format.upper()
     if output_format not in ("THWC", "TCHW"):
         raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
 
-    from torchvision import get_video_backend
+    from ms_adapter.torchvision import get_video_backend
 
     if not os.path.exists(filename):
         raise RuntimeError(f"File not found: {filename}")
@@ -381,9 +380,9 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in
         video_fps (float, optional): the frame rate for the video
 
     """
-    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-        _log_api_usage_once(read_video_timestamps)
-    from torchvision import get_video_backend
+    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+    #     _log_api_usage_once(read_video_timestamps)
+    from ms_adapter.torchvision import get_video_backend
 
     if get_video_backend() != "pyav":
         return _video_opt._read_video_timestamps(filename, pts_unit)
diff --git a/ms_adapter/torchvision/io/video_reader.py b/ms_adapter/torchvision/io/video_reader.py
index 881b9d75..f4dbb09f 100644
--- a/ms_adapter/torchvision/io/video_reader.py
+++ b/ms_adapter/torchvision/io/video_reader.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, Iterator
 
 import torch
 
-from ..utils import _log_api_usage_once
+
 
 try:
     from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
@@ -91,7 +91,7 @@ class VideoReader:
     """
 
     def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
-        _log_api_usage_once(self)
+        # _log_api_usage_once(self)
         self.is_cuda = False
         device = torch.device(device)
         if device.type == "cuda":
diff --git a/testing/ut/torchvision/common_utils.py b/testing/ut/torchvision/common_utils.py
index 8e608cce..e29567e2 100644
--- a/testing/ut/torchvision/common_utils.py
+++ b/testing/ut/torchvision/common_utils.py
@@ -10,7 +10,7 @@ import ms_adapter.pytorch as torch
 from ms_adapter.pytorch.tensor import cast_to_adapter_tensor
 from PIL import Image
 
-# from torchvision import io
+from ms_adapter.torchvision import io
 
 import __main__  # noqa: 401
 
@@ -143,23 +143,23 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
 # assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
 
-# def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
-#     names = []
-#     for i in range(num_videos):
-#         if sizes is None:
-#             size = 5 * (i + 1)
-#         else:
-#             size = sizes[i]
-#         if fps is None:
-#             f = 5
-#         else:
-#             f = fps[i]
-#         data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
-#         name = os.path.join(tmpdir, f"{i}.mp4")
-#         names.append(name)
-#         io.write_video(name, data, fps=f)
-#
-#     return names
+def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
+    names = []
+    for i in range(num_videos):
+        if sizes is None:
+            size = 5 * (i + 1)
+        else:
+            size = sizes[i]
+        if fps is None:
+            f = 5
+        else:
+            f = fps[i]
+        data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
+        name = os.path.join(tmpdir, f"{i}.mp4")
+        names.append(name)
+        io.write_video(name, data, fps=f)
+
+    return names
diff --git a/testing/ut/torchvision/datasets_utils.py b/testing/ut/torchvision/datasets_utils.py
new file mode 100644
index 00000000..a83efe4d
--- /dev/null
+++ b/testing/ut/torchvision/datasets_utils.py
@@ -0,0 +1,982 @@
+import contextlib
+import functools
+import importlib
+import inspect
+import itertools
+import os
+import pathlib
+import random
+import shutil
+import string
+import struct
+import tarfile
+import unittest
+import unittest.mock
+import zipfile
+from collections import defaultdict
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+import PIL
+import PIL.Image
+import pytest
+import ms_adapter.pytorch as torch
+from ms_adapter.torchvision.datasets import VisionDataset
+from ms_adapter.torchvision.io import write_video
+from common_utils import disable_console_output, get_tmp_dir
+from ms_adapter.torchvision.transforms.functional import get_dimensions
+
+__all__ = [
+    "UsageError",
+    "lazy_importer",
+    "test_all_configs",
+    "DatasetTestCase",
+    "ImageDatasetTestCase",
+    "VideoDatasetTestCase",
+    "create_image_or_video_tensor",
+    "create_image_file",
+    "create_image_folder",
+    "create_video_file",
+    "create_video_folder",
+    "make_tar",
+    "make_zip",
+    "create_random_string",
+]
+
+
+class UsageError(Exception):
+    """Should be raised in case an error happens in the setup rather than the test."""
+
+
+class LazyImporter:
+    r"""Lazy importer for additional dependencies.
+
+    Some datasets require additional packages that are not direct dependencies of torchvision. Instances of this
+    class provide modules listed in MODULES as attributes. They are only imported when accessed.
+
+    """
+    MODULES = (
+        "av",
+        "lmdb",
+        "pycocotools",
+        "requests",
+        "scipy.io",
+        "scipy.sparse",
+        "h5py",
+    )
+
+    def __init__(self):
+        modules = defaultdict(list)
+        for module in self.MODULES:
+            module, *submodules = module.split(".", 1)
+            if submodules:
+                modules[module].append(submodules[0])
+            else:
+                # This introduces the module so that it is known when we later iterate over the dictionary.
+                modules.__missing__(module)
+
+        for module, submodules in modules.items():
+            # We need the quirky 'module=module' and 'submodules=submodules' arguments to the lambda since otherwise
+            # the lookup for these would happen at runtime rather than at definition. Thus, without it, every
+            # property would try to import the last item in 'modules'.
+            setattr(
+                type(self),
+                module,
+                property(lambda self, module=module, submodules=submodules: LazyImporter._import(module, submodules)),
+            )
+
+    @staticmethod
+    def _import(package, subpackages):
+        try:
+            module = importlib.import_module(package)
+        except ImportError as error:
+            raise UsageError(
+                f"Failed to import module '{package}'. "
+                f"This probably means that the current test case needs '{package}' installed, "
+                f"but it is not a dependency of torchvision. "
+                f"You need to install it manually, for example 'pip install {package}'."
+            ) from error
+
+        for name in subpackages:
+            importlib.import_module(f".{name}", package=package)
+
+        return module
+
+
+lazy_importer = LazyImporter()
+
+
+def requires_lazy_imports(*modules):
+    def outer_wrapper(fn):
+        @functools.wraps(fn)
+        def inner_wrapper(*args, **kwargs):
+            for module in modules:
+                getattr(lazy_importer, module.replace(".", "_"))
+            return fn(*args, **kwargs)
+
+        return inner_wrapper
+
+    return outer_wrapper
+
+
+def test_all_configs(test):
+    """Decorator to run a test against all configurations.
+
+    Add this as decorator to an arbitrary test to run it against all configurations.
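An aside on the `LazyImporter` defined earlier in this file: it attaches one property per optional dependency, so the import only happens on first attribute access. A compressed sketch of the same idea via `__getattr__` (a simplification, not the class's actual property mechanism):

```python
import importlib

class LazyImporterSketch:
    MODULES = ("av", "lmdb", "pycocotools")

    def __getattr__(self, name):
        # import on first access; surface a readable hint for missing deps
        if name in self.MODULES:
            try:
                return importlib.import_module(name)
            except ImportError as error:
                raise RuntimeError(f"this test needs '{name}': pip install {name}") from error
        raise AttributeError(name)

lazy = LazyImporterSketch()
# lazy.av  # would import PyAV only at this point
```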
+    This includes
+    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
+
+    The current configuration is provided as the first parameter for the test:
+
+    .. code-block::
+
+        @test_all_configs()
+        def test_foo(self, config):
+            pass
+
+    .. note::
+
+        This will try to remove duplicate configurations. During this process it will not preserve a potential
+        ordering of the configurations or an inner ordering of a configuration.
+    """
+
+    def maybe_remove_duplicates(configs):
+        try:
+            return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
+        except TypeError:
+            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
+            # removal would be a lot more elaborate and we simply bail out.
+            return configs
+
+    @functools.wraps(test)
+    def wrapper(self):
+        configs = []
+        if self.DEFAULT_CONFIG is not None:
+            configs.append(self.DEFAULT_CONFIG)
+        if self.ADDITIONAL_CONFIGS is not None:
+            configs.extend(self.ADDITIONAL_CONFIGS)
+
+        if not configs:
+            configs = [self._KWARG_DEFAULTS.copy()]
+        else:
+            configs = maybe_remove_duplicates(configs)
+
+        for config in configs:
+            with self.subTest(**config):
+                test(self, config)
+
+    return wrapper
+
+
+def combinations_grid(**kwargs):
+    """Creates a grid of input combinations.
+
+    Each element in the returned sequence is a dictionary containing one possible combination as values.
+
+    Example:
+        >>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
+        [
+            {'foo': 'bar', 'spam': 'eggs'},
+            {'foo': 'bar', 'spam': 'ham'},
+            {'foo': 'baz', 'spam': 'eggs'},
+            {'foo': 'baz', 'spam': 'ham'}
+        ]
+    """
+    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
+
+
+class DatasetTestCase(unittest.TestCase):
+    """Abstract base class for all dataset testcases.
+
+    You have to overwrite the following class attributes:
+
+    - DATASET_CLASS (torchvision.datasets.VisionDataset): Class of dataset to be tested.
+    - FEATURE_TYPES (Sequence[Any]): Types of the elements returned by index access of the dataset. Instead of
+      providing these manually, you can instead subclass ``ImageDatasetTestCase`` or ``VideoDatasetTestCase`` to
+      get a reasonable default, that should work for most cases. Each entry of the sequence may be a tuple,
+      to indicate multiple possible values.
+
+    Optionally, you can overwrite the following class attributes:
+
+    - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
+      keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
+      ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
+      not provide one.
+    - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
+      contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
+      ``transforms``, or ``download``.
+    - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
+      available, the tests are skipped.
+
+    Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely
+    on. The fake data should resemble the original data as closely as necessary, while containing only a few
+    examples. During the creation of the dataset, check-, download-, and extract-functions from
+    ``torchvision.datasets.utils`` are disabled.
+
+    Without further configuration, the testcase will test if
+
+    1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
+       corrupted,
+    2. the dataset inherits from `torchvision.datasets.VisionDataset`,
+    3. the dataset can be turned into a string,
+    4. the feature types of a returned example match ``FEATURE_TYPES``,
+    5. the number of examples matches the injected fake data, and
+    6. the dataset calls ``transform``, ``target_transform``, or ``transforms`` if available when accessing data.
+
+    Cases 3 to 6 are tested against all configurations, i.e. ``DEFAULT_CONFIG`` and ``ADDITIONAL_CONFIGS``.
+
+    To add dataset-specific tests, create a new method that takes no arguments with ``test_`` as a name prefix:
+
+    .. code-block::
+
+        def test_foo(self):
+            pass
+
+    If you want to run the test against all configs, add the ``@test_all_configs`` decorator to the definition and
+    accept a single argument:
+
+    .. code-block::
+
+        @test_all_configs
+        def test_bar(self, config):
+            pass
+
+    Within the test you can use the ``create_dataset()`` method that yields the dataset as well as additional
+    information provided by the ``inject_fake_data()`` method:
+
+    .. code-block::
+
+        def test_baz(self):
+            with self.create_dataset() as (dataset, info):
+                pass
+    """
+
+    DATASET_CLASS = None
+    FEATURE_TYPES = None
+
+    DEFAULT_CONFIG = None
+    ADDITIONAL_CONFIGS = None
+    REQUIRED_PACKAGES = None
+
+    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
+    _TRANSFORM_KWARGS = {
+        "transform",
+        "target_transform",
+        "transforms",
+    }
+    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
+    _SPECIAL_KWARGS = {
+        *_TRANSFORM_KWARGS,
+        "download",
+    }
+
+    # These fields are populated during setupClass() within _populate_private_class_attributes()
+
+    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
+    # the dataset constructor.
+    _KWARG_DEFAULTS = None
+    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
+    _HAS_SPECIAL_KWARG = None
+
+    # These functions are disabled during dataset creation in create_dataset().
+    _CHECK_FUNCTIONS = {
+        "check_md5",
+        "check_integrity",
+    }
+    _DOWNLOAD_EXTRACT_FUNCTIONS = {
+        "download_url",
+        "download_file_from_google_drive",
+        "extract_archive",
+        "download_and_extract_archive",
+    }
+
+    def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
+        """Define positional arguments passed to the dataset.
+
+        .. note::
+
+            The default behavior is only valid if the dataset to be tested has ``root`` as the only required
+            parameter. Otherwise you need to overwrite this method.
+
+        Args:
+            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
+                to be created and in turn also for the fake data injected here.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
+                least fields for all dataset parameters with default values.
+
+        Returns:
+            (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
+        """
+        return (tmpdir,)
+
+    def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
+        """Inject fake data for dataset into a temporary directory.
+
+        During the creation of the dataset the download and extract logic is disabled. Thus, the fake data injected
+        here needs to resemble the raw data, i.e.
+        here needs to resemble the raw data, i.e. the state of the dataset directly after the files are downloaded
+        and potentially extracted.
+
+        Args:
+            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
+                to be created and in turn also for the fake data injected here.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
+                least fields for all dataset parameters with default values.
+
+        Needs to return one of the following:
+
+            1. (int): Number of examples in the dataset to be created, or
+            2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
+               ``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
+        """
+        raise NotImplementedError("You need to provide fake data in order for the tests to run.")
+
+    @contextlib.contextmanager
+    def create_dataset(
+        self,
+        config: Optional[Dict[str, Any]] = None,
+        inject_fake_data: bool = True,
+        patch_checks: Optional[bool] = None,
+        **kwargs: Any,
+    ):
+        r"""Create the dataset in a temporary directory.
+
+        The configuration passed to the dataset is populated to contain at least all parameters with default values.
+        For this the following order of precedence is used:
+
+        1. Parameters in :attr:`kwargs`.
+        2. Configuration in :attr:`config`.
+        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
+        4. Default parameters of the dataset.
+
+        Args:
+            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
+            inject_fake_data (bool): If ``True`` (default), injects the fake data with :meth:`.inject_fake_data`
+                before creating the dataset.
+            patch_checks (Optional[bool]): If ``True``, disables the integrity check logic while creating the
+                dataset. If omitted, defaults to the same value as ``inject_fake_data``.
+            **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case
+                they overlap with ``config``.
+
+        Yields:
+            dataset (torchvision.dataset.VisionDataset): Dataset.
+            info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
+                for details.
+        """
+        if patch_checks is None:
+            patch_checks = inject_fake_data
+
+        special_kwargs, other_kwargs = self._split_kwargs(kwargs)
+
+        complete_config = self._KWARG_DEFAULTS.copy()
+        if self.DEFAULT_CONFIG:
+            complete_config.update(self.DEFAULT_CONFIG)
+        if config:
+            complete_config.update(config)
+        if other_kwargs:
+            complete_config.update(other_kwargs)
+
+        if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
+            # If the dataset accepts a 'download' parameter and a truthy value was passed, override it to False:
+            # the download logic is patched out during testing and the fake data is injected directly.
+            special_kwargs["download"] = False
+
+        patchers = self._patch_download_extract()
+        if patch_checks:
+            patchers.update(self._patch_checks())
+        with get_tmp_dir() as tmpdir:
+            args = self.dataset_args(tmpdir, complete_config)
+            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
+            with self._maybe_apply_patches(patchers), disable_console_output():
+                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
+
+            yield dataset, info
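The precedence rules of `create_dataset()` are easiest to see with concrete values. Below is a minimal sketch, assuming a hypothetical dataset with a `split` keyword; `FooDataset`, `FooTestCase`, and `split` are illustrative names only, not part of this patch:

```python
# Hypothetical sketch only: FooDataset and its `split` parameter are assumptions.
class FooTestCase(DatasetTestCase):
    DATASET_CLASS = FooDataset  # assumed VisionDataset subclass with a `split` kwarg
    FEATURE_TYPES = (PIL.Image.Image, int)
    DEFAULT_CONFIG = dict(split="train")

    def test_split_precedence(self):
        # Keyword arguments beat `config`, which beats DEFAULT_CONFIG,
        # which beats the defaults extracted from the dataset constructor.
        with self.create_dataset(config=dict(split="val"), split="test") as (dataset, _):
            assert dataset.split == "test"
```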
" + "It should contain the class of the dataset to be tested." + ) + if cls.FEATURE_TYPES is None: + raise UsageError( + "The class attribute 'FEATURE_TYPES' needs to be overwritten. " + "It should contain a sequence of types that the dataset returns when accessed by index." + ) + + @classmethod + def _populate_private_class_attributes(cls): + defaults = [] + for cls_ in cls.DATASET_CLASS.__mro__: + if cls_ is VisionDataset: + break + + argspec = inspect.getfullargspec(cls_.__init__) + + if not argspec.defaults: + continue + + defaults.append( + { + kwarg: default + for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults) + if not kwarg.startswith("_") + } + ) + + if not argspec.varkw: + break + + kwarg_defaults = dict() + for config in reversed(defaults): + kwarg_defaults.update(config) + + has_special_kwargs = set() + for name in cls._SPECIAL_KWARGS: + if name not in kwarg_defaults: + continue + + del kwarg_defaults[name] + has_special_kwargs.add(name) + + cls._KWARG_DEFAULTS = kwarg_defaults + cls._HAS_SPECIAL_KWARG = has_special_kwargs + + @classmethod + def _process_optional_public_class_attributes(cls): + def check_config(config, name): + special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config) + if special_kwargs: + raise UsageError( + f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. " + f"These are handled separately by the test case and should not be set here. " + f"If you need to test some custom behavior regarding these parameters, " + f"you need to write a custom test (*not* test case), e.g. test_custom_transform()." + ) + + if cls.DEFAULT_CONFIG is not None: + check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG") + + if cls.ADDITIONAL_CONFIGS is not None: + for idx, config in enumerate(cls.ADDITIONAL_CONFIGS): + check_config(config, f"CONFIGS[{idx}]") + + if cls.REQUIRED_PACKAGES: + missing_pkgs = [] + for pkg in cls.REQUIRED_PACKAGES: + try: + importlib.import_module(pkg) + except ImportError: + missing_pkgs.append(f"'{pkg}'") + + if missing_pkgs: + raise unittest.SkipTest( + f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset " + f"'{cls.DATASET_CLASS.__name__}', but are not installed." + ) + + def _split_kwargs(self, kwargs): + special_kwargs = kwargs.copy() + other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS} + return special_kwargs, other_kwargs + + def _inject_fake_data(self, tmpdir, config): + info = self.inject_fake_data(tmpdir, config) + if info is None: + raise UsageError( + "The method 'inject_fake_data' needs to return at least an integer indicating the number of " + "examples for the current configuration." + ) + elif isinstance(info, int): + info = dict(num_examples=info) + elif not isinstance(info, dict): + raise UsageError( + f"The additional information returned by the method 'inject_fake_data' must be either an " + f"integer indicating the number of examples for the current configuration or a dictionary with " + f"the same content. Got {type(info)} instead." + ) + elif "num_examples" not in info: + raise UsageError( + "The information dictionary returned by the method 'inject_fake_data' must contain a " + "'num_examples' field that holds the number of examples for the current configuration." 
+            )
+        return info
+
+    def _patch_download_extract(self):
+        module = inspect.getmodule(self.DATASET_CLASS).__name__
+        return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
+
+    def _patch_checks(self):
+        module = inspect.getmodule(self.DATASET_CLASS).__name__
+        return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
+
+    @contextlib.contextmanager
+    def _maybe_apply_patches(self, patchers):
+        with contextlib.ExitStack() as stack:
+            mocks = {}
+            for patcher in patchers:
+                with contextlib.suppress(AttributeError):
+                    mocks[patcher.target] = stack.enter_context(patcher)
+            yield mocks
+
+    def test_not_found_or_corrupted(self):
+        with pytest.raises((FileNotFoundError, RuntimeError)):
+            with self.create_dataset(inject_fake_data=False):
+                pass
+
+    def test_smoke(self):
+        with self.create_dataset() as (dataset, _):
+            assert isinstance(dataset, VisionDataset)
+
+    @test_all_configs
+    def test_str_smoke(self, config):
+        with self.create_dataset(config) as (dataset, _):
+            assert isinstance(str(dataset), str)
+
+    @test_all_configs
+    def test_feature_types(self, config):
+        with self.create_dataset(config) as (dataset, _):
+
+            example = dataset[0]
+
+            if len(self.FEATURE_TYPES) > 1:
+                actual = len(example)
+                expected = len(self.FEATURE_TYPES)
+                assert actual == expected, (
+                    "The number of returned features does not match the number of elements in FEATURE_TYPES: "
+                    f"{actual} != {expected}"
+                )
+            else:
+                example = (example,)
+
+            for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
+                with self.subTest(idx=idx):
+                    assert isinstance(feature, expected_feature_type)
+
+    @test_all_configs
+    def test_num_examples(self, config):
+        with self.create_dataset(config) as (dataset, info):
+            assert len(dataset) == info["num_examples"]
+
+    @test_all_configs
+    def test_transforms(self, config):
+        mock = unittest.mock.Mock(wraps=lambda *args: args[0] if len(args) == 1 else args)
+        for kwarg in self._TRANSFORM_KWARGS:
+            if kwarg not in self._HAS_SPECIAL_KWARG:
+                continue
+
+            mock.reset_mock()
+
+            with self.subTest(kwarg=kwarg):
+                with self.create_dataset(config, **{kwarg: mock}) as (dataset, _):
+                    dataset[0]
+
+                mock.assert_called()
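To make the contract above concrete, here is a minimal sketch of how a concrete test case is typically wired up against these helpers; `FakeImageDataset` is a stand-in name, not part of this patch:

```python
class FakeImageDatasetTestCase(datasets_utils.ImageDatasetTestCase):
    # `FakeImageDataset` is illustrative only; any VisionDataset whose sole
    # required parameter is `root` fits this pattern without extra overrides.
    DATASET_CLASS = datasets.FakeImageDataset
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

    def inject_fake_data(self, tmpdir, config):
        num_examples = 3
        datasets_utils.create_image_folder(
            tmpdir, config["split"], lambda idx: f"{idx}.png", num_examples
        )
        # Returning an int is shorthand for dict(num_examples=...).
        return num_examples
```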
+
+
+class ImageDatasetTestCase(DatasetTestCase):
+    """Abstract base class for image dataset testcases.
+
+    - Overwrites the FEATURE_TYPES class attribute to expect a :class:`PIL.Image.Image` and an integer label.
+    """
+
+    FEATURE_TYPES = (PIL.Image.Image, int)
+
+    @contextlib.contextmanager
+    def create_dataset(
+        self,
+        config: Optional[Dict[str, Any]] = None,
+        inject_fake_data: bool = True,
+        patch_checks: Optional[bool] = None,
+        **kwargs: Any,
+    ):
+        with super().create_dataset(
+            config=config,
+            inject_fake_data=inject_fake_data,
+            patch_checks=patch_checks,
+            **kwargs,
+        ) as (dataset, info):
+            # PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
+            # to the pixel data occurs. Trying to delete such a file results in a PermissionError on Windows. Thus,
+            # we force-load opened images.
+            # This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types, open
+            # an image but never use the underlying data. During normal operation it is reasonable to assume that
+            # users want to work with the image they just opened rather than deleting the underlying file.
+            with self._force_load_images():
+                yield dataset, info
+
+    @contextlib.contextmanager
+    def _force_load_images(self):
+        open = PIL.Image.open
+
+        def new(fp, *args, **kwargs):
+            image = open(fp, *args, **kwargs)
+            if isinstance(fp, (str, pathlib.Path)):
+                image.load()
+            return image
+
+        with unittest.mock.patch("PIL.Image.open", new=new):
+            yield
+
+
+class VideoDatasetTestCase(DatasetTestCase):
+    """Abstract base class for video dataset testcases.
+
+    - Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` s for the video and audio as
+      well as an integer label.
+    - Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``).
+    - Adds the 'DEFAULT_FRAMES_PER_CLIP' class attribute. If no 'frames_per_clip' is provided by 'dataset_args()'
+      and it is the last parameter without a default value in the dataset constructor, the value of the
+      'DEFAULT_FRAMES_PER_CLIP' class attribute is appended to the output.
+    """
+
+    FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
+    REQUIRED_PACKAGES = ("av",)
+
+    DEFAULT_FRAMES_PER_CLIP = 1
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)
+
+    def _set_default_frames_per_clip(self, dataset_args):
+        argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
+        args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
+        frames_per_clip_last = args_without_default[-1] == "frames_per_clip"
+
+        @functools.wraps(dataset_args)
+        def wrapper(tmpdir, config):
+            args = dataset_args(tmpdir, config)
+            if frames_per_clip_last and len(args) == len(args_without_default) - 1:
+                args = (*args, self.DEFAULT_FRAMES_PER_CLIP)
+
+            return args
+
+        return wrapper
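The `frames_per_clip` check above can be replayed in isolation. A small sketch under the assumption of a hypothetical video dataset whose constructor is `__init__(self, root, frames_per_clip, transform=None)`; `DummyVideoDataset` and `needs_frames_per_clip` are illustrative names:

```python
import inspect

def needs_frames_per_clip(cls):
    # Mirrors the check in _set_default_frames_per_clip(): `frames_per_clip`
    # must be the last constructor parameter without a default value.
    argspec = inspect.getfullargspec(cls.__init__)
    args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
    return args_without_default[-1] == "frames_per_clip"

class DummyVideoDataset:  # hypothetical stand-in, not part of this patch
    def __init__(self, root, frames_per_clip, transform=None):
        pass

print(needs_frames_per_clip(DummyVideoDataset))  # True, so DEFAULT_FRAMES_PER_CLIP would be appended
```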
+
+
+def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
+    r"""Create a random uint8 tensor.
+
+    Args:
+        size (Sequence[int]): Size of the tensor.
+    """
+    return torch.randint(0, 256, size, dtype=torch.uint8)
+
+
+def create_image_file(
+    root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = 10, **kwargs: Any
+) -> pathlib.Path:
+    """Create an image file from random data.
+
+    Args:
+        root (Union[str, pathlib.Path]): Root directory the image file will be placed in.
+        name (Union[str, pathlib.Path]): Name of the image file.
+        size (Union[Sequence[int], int]): Size of the image that represents the ``(num_channels, height, width)``. If
+            scalar, the value is used for the height and width. If the number of channels is not provided, three
+            channels are assumed.
+        kwargs (Any): Additional parameters passed to :meth:`PIL.Image.Image.save`.
+
+    Returns:
+        pathlib.Path: Path to the created image file.
+    """
+    if isinstance(size, int):
+        size = (size, size)
+    if len(size) == 2:
+        size = (3, *size)
+    if len(size) != 3:
+        raise UsageError(
+            f"The 'size' argument should either be an int or a sequence of length 2 or 3. Got {len(size)} instead"
+        )
+
+    image = create_image_or_video_tensor(size)
+    file = pathlib.Path(root) / name
+
+    # torch (num_channels x height x width) -> PIL (width x height x num_channels)
+    image = image.permute(2, 1, 0)
+    # For grayscale images PIL doesn't use a channel dimension
+    if image.shape[2] == 1:
+        image = torch.squeeze(image, 2)
+    PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
+    return file
+
+
+def create_image_folder(
+    root: Union[pathlib.Path, str],
+    name: Union[pathlib.Path, str],
+    file_name_fn: Callable[[int], str],
+    num_examples: int,
+    size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
+    **kwargs: Any,
+) -> List[pathlib.Path]:
+    """Create a folder of random images.
+
+    Args:
+        root (Union[str, pathlib.Path]): Root directory the image folder will be placed in.
+        name (Union[str, pathlib.Path]): Name of the image folder.
+        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
+        num_examples (int): Number of images to create.
+        size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the images.
+            If callable, will be called with the index of the corresponding file. If omitted, a random height and
+            width between 3 and 10 pixels is selected on a per-image basis.
+        kwargs (Any): Additional parameters passed to :func:`create_image_file`.
+
+    Returns:
+        List[pathlib.Path]: Paths to all created image files.
+
+    .. seealso::
+
+        - :func:`create_image_file`
+    """
+    if size is None:
+
+        def size(idx: int) -> Tuple[int, int, int]:
+            num_channels = 3
+            height, width = torch.randint(3, 11, size=(2,), dtype=torch.int32).tolist()
+            return (num_channels, height, width)
+
+    root = pathlib.Path(root) / name
+    os.makedirs(root, exist_ok=True)
+
+    return [
+        create_image_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
+        for idx in range(num_examples)
+    ]
+
+
+def shape_test_for_stereo(
+    left: PIL.Image.Image,
+    right: PIL.Image.Image,
+    disparity: Optional[np.ndarray] = None,
+    valid_mask: Optional[np.ndarray] = None,
+):
+    left_dims = get_dimensions(left)
+    right_dims = get_dimensions(right)
+    c, h, w = left_dims
+    # check that left and right are the same size
+    assert left_dims == right_dims
+    assert c == 3
+
+    # check that the disparity has the same spatial dimensions as the input
+    if disparity is not None:
+        assert disparity.ndim == 3
+        assert disparity.shape == (1, h, w)
+
+    if valid_mask is not None:
+        # check that the valid mask is the same size as the disparity
+        _, dh, dw = disparity.shape
+        mh, mw = valid_mask.shape
+        assert dh == mh
+        assert dw == mw
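A quick sketch of typical usage of the image helpers, assuming they are imported from this module; paths and names are illustrative:

```python
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    # A single grayscale 10x10 image: size is (num_channels, height, width)
    file = create_image_file(tmpdir, "img.png", size=(1, 10, 10))
    # A class folder with three RGB images of random size between 3 and 10 pixels
    files = create_image_folder(tmpdir, "cat", file_name_fn=lambda idx: f"cat_{idx}.png", num_examples=3)
    assert file.exists() and len(files) == 3
```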
+
+
+@requires_lazy_imports("av")
+def create_video_file(
+    root: Union[pathlib.Path, str],
+    name: Union[pathlib.Path, str],
+    size: Union[Sequence[int], int] = (1, 3, 10, 10),
+    fps: float = 25,
+    **kwargs: Any,
+) -> pathlib.Path:
+    """Create a video file from random data.
+
+    Args:
+        root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
+        name (Union[str, pathlib.Path]): Name of the video file.
+        size (Union[Sequence[int], int]): Size of the video that represents the
+            ``(num_frames, num_channels, height, width)``. If scalar, the value is used for the height and width.
+            If not provided, ``num_frames=1`` and ``num_channels=3`` are assumed.
+        fps (float): Frame rate in frames per second.
+        kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`.
+
+    Returns:
+        pathlib.Path: Path to the created video file.
+
+    Raises:
+        UsageError: If PyAV is not available.
+    """
+    if isinstance(size, int):
+        size = (size, size)
+    if len(size) == 2:
+        size = (3, *size)
+    if len(size) == 3:
+        size = (1, *size)
+    if len(size) != 4:
+        raise UsageError(
+            f"The 'size' argument should either be an int or a sequence of length 2, 3, or 4. Got {len(size)} instead"
+        )
+
+    video = create_image_or_video_tensor(size)
+    file = pathlib.Path(root) / name
+    write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs)
+    return file
+
+
+@requires_lazy_imports("av")
+def create_video_folder(
+    root: Union[str, pathlib.Path],
+    name: Union[str, pathlib.Path],
+    file_name_fn: Callable[[int], str],
+    num_examples: int,
+    size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
+    fps=25,
+    **kwargs,
+) -> List[pathlib.Path]:
+    """Create a folder of random videos.
+
+    Args:
+        root (Union[str, pathlib.Path]): Root directory the video folder will be placed in.
+        name (Union[str, pathlib.Path]): Name of the video folder.
+        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
+        num_examples (int): Number of videos to create.
+        size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the videos.
+            If callable, will be called with the index of the corresponding file. If omitted, a random even height
+            and width between 4 and 10 pixels is selected on a per-video basis.
+        fps (float): Frame rate in frames per second.
+        kwargs (Any): Additional parameters passed to :func:`create_video_file`.
+
+    Returns:
+        List[pathlib.Path]: Paths to all created video files.
+
+    Raises:
+        UsageError: If PyAV is not available.
+
+    .. seealso::
+
+        - :func:`create_video_file`
+    """
+    if size is None:
+
+        def size(idx):
+            num_frames = 1
+            num_channels = 3
+            # The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and
+            # width of the video to be divisible by 2.
+            height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int32) * 2).tolist()
+            return (num_frames, num_channels, height, width)
+
+    root = pathlib.Path(root) / name
+    os.makedirs(root, exist_ok=True)
+
+    return [
+        create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, fps=fps, **kwargs)
+        for idx in range(num_examples)
+    ]
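Analogous usage for the video helper; a sketch assuming PyAV is installed. Note that explicit sizes should keep height and width even because of the libx264 constraint noted above:

```python
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    # 4 frames of 3-channel 10x10 video at the default 25 fps
    video = create_video_file(tmpdir, "clip.avi", size=(4, 3, 10, 10))
    assert video.exists()
```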
+
+
+def _split_files_or_dirs(root, *files_or_dirs):
+    root = pathlib.Path(root)
+    files = set()
+    dirs = set()
+    for file_or_dir in files_or_dirs:
+        path = pathlib.Path(file_or_dir)
+        if not path.is_absolute():
+            path = root / path
+        if path.is_file():
+            files.add(path)
+        else:
+            dirs.add(path)
+            for sub_file_or_dir in path.glob("**/*"):
+                if sub_file_or_dir.is_file():
+                    files.add(sub_file_or_dir)
+                else:
+                    dirs.add(sub_file_or_dir)
+
+    if root in dirs:
+        dirs.remove(root)
+
+    return files, dirs
+
+
+def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
+    archive = pathlib.Path(root) / name
+    if not files_or_dirs:
+        # We need to invoke `Path.with_suffix("")` in a loop, since a single call only removes the last suffix if
+        # multiple suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in
+        # `foo.tar`.
+        file_or_dir = archive
+        for _ in range(len(archive.suffixes)):
+            file_or_dir = file_or_dir.with_suffix("")
+        if file_or_dir.exists():
+            files_or_dirs = (file_or_dir,)
+        else:
+            raise ValueError("No file or dir provided.")
+
+    files, dirs = _split_files_or_dirs(root, *files_or_dirs)
+
+    with opener(archive) as fh:
+        for file in sorted(files):
+            adder(fh, file, file.relative_to(root))
+
+    if remove:
+        for file in files:
+            os.remove(file)
+        for dir in dirs:
+            shutil.rmtree(dir, ignore_errors=True)
+
+    return archive
+
+
+def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
+    # TODO: detect compression from name
+    return _make_archive(
+        root,
+        name,
+        *files_or_dirs,
+        opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
+        adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
+        remove=remove,
+    )
+
+
+def make_zip(root, name, *files_or_dirs, remove=True):
+    return _make_archive(
+        root,
+        name,
+        *files_or_dirs,
+        opener=lambda archive: zipfile.ZipFile(archive, "w"),
+        adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
+        remove=remove,
+    )
+
+
+def create_random_string(length: int, *digits: str) -> str:
+    """Create a random string.
+
+    Args:
+        length (int): Number of characters in the generated string.
+        *digits (str): Characters to sample from. If omitted, defaults to :attr:`string.ascii_lowercase`.
+    """
+    if not digits:
+        digits = string.ascii_lowercase
+    else:
+        digits = "".join(itertools.chain(*digits))
+
+    return "".join(random.choice(digits) for _ in range(length))
+
+
+def make_fake_pfm_file(h, w, file_name):
+    values = list(range(3 * h * w))
+    # Everything is packed little endian: the -1.0 scale in the header signals it, and "<" selects it in struct.pack.
+    content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
+    with open(file_name, "wb") as f:
+        f.write(content)
+
+
+def make_fake_flo_file(h, w, file_name):
+    """Creates a fake flow file in .flo format."""
+    # Everything needs to be in little endian according to
+    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
+    values = list(range(2 * h * w))
+    content = (
+        struct.pack("<4c", *(c.encode() for c in "PIEH"))
+        + struct.pack("<i", w)
+        + struct.pack("<i", h)
+        + struct.pack("<" + "f" * len(values), *values)
+    )
+    with open(file_name, "wb") as f:
+        f.write(content)
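To sanity-check the binary layout written above (little-endian "PIEH" magic, width, height, then 2*h*w float32 values), a small read-back sketch:

```python
import pathlib
import struct
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    path = pathlib.Path(tmpdir) / "fake.flo"
    make_fake_flo_file(h=3, w=4, file_name=str(path))
    raw = path.read_bytes()
    # Header: 4-byte magic, then width and height as little-endian int32
    magic, w, h = raw[:4], *struct.unpack("<2i", raw[4:12])
    values = struct.unpack("<" + "f" * (2 * h * w), raw[12:])
    assert magic == b"PIEH" and (w, h) == (4, 3) and values[0] == 0.0
```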
    ") + self._add_image_captions(fh, num_captions_per_image) + fh.write("
") + + def _add_image_header(self, fh, image=None): + if image: + url = f"http://www.flickr.com/photos/user/{image.name.split('_')[0]}/" + data = f'{url}' + else: + data = "Image Not Found" + fh.write(f"{data}") + + def _add_image_captions(self, fh, num_captions_per_image): + for caption in self._create_captions(num_captions_per_image): + fh.write(f"
+
+    def _create_captions(self, num_captions_per_image):
+        return [str(idx) for idx in range(num_captions_per_image)]
+
+    def test_captions(self):
+        with self.create_dataset() as (dataset, info):
+            _, captions = dataset[0]
+            assert len(captions) == len(info["captions"])
+            assert all([a == b for a, b in zip(captions, info["captions"])])
+
+
+class Flickr30kTestCase(Flickr8kTestCase):
+    DATASET_CLASS = datasets.Flickr30k
+
+    FEATURE_TYPES = (PIL.Image.Image, list)
+
+    _ANNOTATIONS_FILE = "captions.token"
+
+    def _image_file_name(self, idx):
+        return f"{idx}.jpg"
+
+    def _create_annotations_file(self, root, name, images, num_captions_per_image):
+        with open(root / name, "w") as fh:
+            for image, (idx, caption) in itertools.product(
+                images, enumerate(self._create_captions(num_captions_per_image))
+            ):
+                fh.write(f"{image.name}#{idx}\t{caption}\n")
+
+
+class MNISTTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.MNIST
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+
+    _MAGIC_DTYPES = {
+        torch.uint8: 8,
+        torch.int8: 9,
+        torch.int16: 11,
+        torch.int32: 12,
+        torch.float32: 13,
+        torch.float64: 14,
+    }
+
+    _IMAGES_SIZE = (28, 28)
+    _IMAGES_DTYPE = torch.uint8
+
+    _LABELS_SIZE = ()
+    _LABELS_DTYPE = torch.uint8
+
+    def inject_fake_data(self, tmpdir, config):
+        raw_dir = pathlib.Path(tmpdir) / self.DATASET_CLASS.__name__ / "raw"
+        os.makedirs(raw_dir, exist_ok=True)
+
+        num_images = self._num_images(config)
+        self._create_binary_file(
+            raw_dir, self._images_file(config), (num_images, *self._IMAGES_SIZE), self._IMAGES_DTYPE
+        )
+        self._create_binary_file(
+            raw_dir, self._labels_file(config), (num_images, *self._LABELS_SIZE), self._LABELS_DTYPE
+        )
+        return num_images
+
+    def _num_images(self, config):
+        return 2 if config["train"] else 1
+
+    def _images_file(self, config):
+        return f"{self._prefix(config)}-images-idx3-ubyte"
+
+    def _labels_file(self, config):
+        return f"{self._prefix(config)}-labels-idx1-ubyte"
+
+    def _prefix(self, config):
+        return "train" if config["train"] else "t10k"
+
+    def _create_binary_file(self, root, filename, size, dtype):
+        with open(pathlib.Path(root) / filename, "wb") as fh:
+            for meta in (self._magic(dtype, len(size)), *size):
+                fh.write(self._encode(meta))
+
+            # If ever an MNIST variant is added that uses floating point data, this should be adapted.
+            data = torch.randint(0, torch.iinfo(dtype).max + 1, size, dtype=dtype)
+            fh.write(data.numpy().tobytes())
+
+    def _magic(self, dtype, dims):
+        return self._MAGIC_DTYPES[dtype] * 256 + dims
+
+    def _encode(self, v):
+        # The idx format stores integers big endian; reverse the little-endian bytes produced on the test machine.
+        return torch.tensor(v, dtype=torch.int32).numpy().tobytes()[::-1]
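The header logic of `_create_binary_file()` mirrors the classic MNIST idx format: a big-endian int32 magic (dtype code * 256 + number of dimensions, e.g. 2051 for uint8 images with 3 dimensions) followed by one big-endian int32 per dimension. A small decoding sketch; `read_idx_header` is illustrative, not part of this patch:

```python
import struct

def read_idx_header(path, num_dims):
    # Header: magic int32, then one int32 per dimension, all big-endian.
    with open(path, "rb") as fh:
        raw = fh.read(4 * (1 + num_dims))
    magic, *dims = struct.unpack(f">{1 + num_dims}i", raw)
    return magic // 256, magic % 256, dims  # (dtype code, ndims, shape)

# For the fake train images file this would yield (8, 3, [2, 28, 28]):
# dtype code 8 == uint8, 3 dimensions, 2 images of 28x28.
```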
+
+
+class FashionMNISTTestCase(MNISTTestCase):
+    DATASET_CLASS = datasets.FashionMNIST
+
+
+class KMNISTTestCase(MNISTTestCase):
+    DATASET_CLASS = datasets.KMNIST
+
+
+class EMNISTTestCase(MNISTTestCase):
+    DATASET_CLASS = datasets.EMNIST
+
+    DEFAULT_CONFIG = dict(split="byclass")
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"), train=(True, False)
+    )
+
+    def _prefix(self, config):
+        return f"emnist-{config['split']}-{'train' if config['train'] else 'test'}"
+
+
+class QMNISTTestCase(MNISTTestCase):
+    DATASET_CLASS = datasets.QMNIST
+
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist"))
+
+    _LABELS_SIZE = (8,)
+    _LABELS_DTYPE = torch.int32
+
+    def _num_images(self, config):
+        if config["what"] == "nist":
+            return 3
+        elif config["what"] == "train":
+            return 2
+        elif config["what"] == "test50k":
+            # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
+            # more than 10000 images for the dataset to not be empty. Since this takes significantly longer than the
+            # creation of all other splits, this is excluded from the 'ADDITIONAL_CONFIGS' and is tested only once in
+            # 'test_num_examples_test50k'.
+            return 10001
+        else:
+            return 1
+
+    def _labels_file(self, config):
+        return f"{self._prefix(config)}-labels-idx2-int"
+
+    def _prefix(self, config):
+        if config["what"] == "nist":
+            return "xnist"
+
+        if config["what"] is None:
+            what = "train" if config["train"] else "test"
+        elif config["what"].startswith("test"):
+            what = "test"
+        else:
+            what = config["what"]
+
+        return f"qmnist-{what}"
+
+    def test_num_examples_test50k(self):
+        with self.create_dataset(what="test50k") as (dataset, info):
+            # Since the split 'test50k' selects all images beginning from index 10000, we subtract 10000 from the
+            # number of created examples.
+            assert len(dataset) == info["num_examples"] - 10000
+
+
+class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.DatasetFolder
+
+    # The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader
+    # that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
+    FEATURE_TYPES = (str, int)
+
+    _IMAGE_EXTENSIONS = ("jpg", "png")
+    _VIDEO_EXTENSIONS = ("avi", "mp4")
+    _EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)
+
+    # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. Exactly one of them is
+    # required. We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
+    # 'test_is_valid_file()' method.
+    DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
+    ADDITIONAL_CONFIGS = (
+        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
+        dict(extensions=_IMAGE_EXTENSIONS),
+        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
+        dict(extensions=_VIDEO_EXTENSIONS),
+    )
+
+    def dataset_args(self, tmpdir, config):
+        return tmpdir, lambda x: x
+
+    def inject_fake_data(self, tmpdir, config):
+        extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])
+
+        num_examples_total = 0
+        classes = []
+        for ext, cls in zip(self._EXTENSIONS, string.ascii_letters):
+            if ext not in extensions:
+                continue
+
+            create_example_folder = (
+                datasets_utils.create_image_folder
+                if ext in self._IMAGE_EXTENSIONS
+                else datasets_utils.create_video_folder
+            )
+
+            num_examples = torch.randint(1, 3, size=()).item()
+            create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)
+
+            num_examples_total += num_examples
+            classes.append(cls)
+
+        return dict(num_examples=num_examples_total, classes=classes)
+
+    def _file_name_fn(self, cls, ext, idx):
+        return f"{cls}_{idx}.{ext}"
+
+    def _is_valid_file_to_extensions(self, is_valid_file):
+        return {ext for ext in self._EXTENSIONS if is_valid_file(f"foo.{ext}")}
+
+    @datasets_utils.test_all_configs
+    def test_is_valid_file(self, config):
+        extensions = config.pop("extensions")
+        # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the
+        # DEFAULT_CONFIG.
+        with self.create_dataset(
+            config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
+        ) as (dataset, info):
+            assert len(dataset) == info["num_examples"]
+
+    @datasets_utils.test_all_configs
+    def test_classes(self, config):
+        with self.create_dataset(config) as (dataset, info):
+            assert len(dataset.classes) == len(info["classes"])
+            assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
+
+
+class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.ImageFolder
+
+    def inject_fake_data(self, tmpdir, config):
+        num_examples_total = 0
+        classes = ("a", "b")
+        for cls in classes:
+            num_examples = torch.randint(1, 3, size=()).item()
+            num_examples_total += num_examples
+
+            datasets_utils.create_image_folder(tmpdir, cls, lambda idx: f"{cls}_{idx}.png", num_examples)
+
+        return dict(num_examples=num_examples_total, classes=classes)
+
+    @datasets_utils.test_all_configs
+    def test_classes(self, config):
+        with self.create_dataset(config) as (dataset, info):
+            assert len(dataset.classes) == len(info["classes"])
+            assert all([a == b for a, b in zip(dataset.classes, info["classes"])])
+
+
+class KittiTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.Kitti
+    FEATURE_TYPES = (PIL.Image.Image, (list, type(None)))  # test split returns None as target
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+
+    def inject_fake_data(self, tmpdir, config):
+        kitti_dir = os.path.join(tmpdir, "Kitti", "raw")
+        os.makedirs(kitti_dir)
+
+        split_to_num_examples = {
+            True: 1,
+            False: 2,
+        }
+
+        # We need to create all folders (training and testing).
+ for is_training in (True, False): + num_examples = split_to_num_examples[is_training] + + datasets_utils.create_image_folder( + root=kitti_dir, + name=os.path.join("training" if is_training else "testing", "image_2"), + file_name_fn=lambda image_idx: f"{image_idx:06d}.png", + num_examples=num_examples, + ) + if is_training: + for image_idx in range(num_examples): + target_file_dir = os.path.join(kitti_dir, "training", "label_2") + os.makedirs(target_file_dir) + target_file_name = os.path.join(target_file_dir, f"{image_idx:06d}.txt") + target_contents = "Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01\n" # noqa + with open(target_file_name, "w") as target_file: + target_file.write(target_contents) + + return split_to_num_examples[config["train"]] + + +class SvhnTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.SVHN + REQUIRED_PACKAGES = ("scipy",) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "extra")) + + def inject_fake_data(self, tmpdir, config): + import scipy.io as sio + + split = config["split"] + num_examples = { + "train": 2, + "test": 3, + "extra": 4, + }.get(split) + + file = f"{split}_32x32.mat" + images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8) + targets = np.zeros((num_examples,), dtype=np.uint8) + sio.savemat(os.path.join(tmpdir, file), {"X": images, "y": targets}) + return num_examples + + +class Places365TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Places365 + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train-standard", "train-challenge", "val"), + small=(False, True), + ) + _CATEGORIES = "categories_places365.txt" + # {split: file} + _FILE_LISTS = { + "train-standard": "places365_train_standard.txt", + "train-challenge": "places365_train_challenge.txt", + "val": "places365_val.txt", + } + # {(split, small): folder_name} + _IMAGES = { + ("train-standard", False): "data_large_standard", + ("train-challenge", False): "data_large_challenge", + ("val", False): "val_large", + ("train-standard", True): "data_256_standard", + ("train-challenge", True): "data_256_challenge", + ("val", True): "val_256", + } + # (class, idx) + _CATEGORIES_CONTENT = ( + ("/a/airfield", 0), + ("/a/apartment_building/outdoor", 8), + ("/b/badlands", 30), + ) + # (file, idx) + _FILE_LIST_CONTENT = ( + ("Places365_val_00000001.png", 0), + *((f"{category}/Places365_train_00000001.png", idx) for category, idx in _CATEGORIES_CONTENT), + ) + + @staticmethod + def _make_txt(root, name, seq): + file = os.path.join(root, name) + with open(file, "w") as fh: + for text, idx in seq: + fh.write(f"{text} {idx}\n") + + @staticmethod + def _make_categories_txt(root, name): + Places365TestCase._make_txt(root, name, Places365TestCase._CATEGORIES_CONTENT) + + @staticmethod + def _make_file_list_txt(root, name): + Places365TestCase._make_txt(root, name, Places365TestCase._FILE_LIST_CONTENT) + + @staticmethod + def _make_image(file_name, size): + os.makedirs(os.path.dirname(file_name), exist_ok=True) + PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file_name) + + @staticmethod + def _make_devkit_archive(root, split): + Places365TestCase._make_categories_txt(root, Places365TestCase._CATEGORIES) + Places365TestCase._make_file_list_txt(root, Places365TestCase._FILE_LISTS[split]) + + @staticmethod + def _make_images_archive(root, split, small): + folder_name = Places365TestCase._IMAGES[(split, small)] + image_size = (256, 256) if small else (512, 
random.randint(512, 1024)) + files, idcs = zip(*Places365TestCase._FILE_LIST_CONTENT) + images = [f.lstrip("/").replace("/", os.sep) for f in files] + for image in images: + Places365TestCase._make_image(os.path.join(root, folder_name, image), image_size) + + return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)] + + def inject_fake_data(self, tmpdir, config): + self._make_devkit_archive(tmpdir, config["split"]) + return len(self._make_images_archive(tmpdir, config["split"], config["small"])) + + def test_classes(self): + classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT)) + with self.create_dataset() as (dataset, _): + assert dataset.classes == classes + + def test_class_to_idx(self): + class_to_idx = dict(self._CATEGORIES_CONTENT) + with self.create_dataset() as (dataset, _): + assert dataset.class_to_idx == class_to_idx + + def test_images_download_preexisting(self): + with pytest.raises(RuntimeError): + with self.create_dataset({"download": True}): + pass + + +class INaturalistTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.INaturalist + FEATURE_TYPES = (PIL.Image.Image, (int, tuple)) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + target_type=("kingdom", "full", "genus", ["kingdom", "phylum", "class", "order", "family", "genus", "full"]), + version=("2021_train",), + ) + + def inject_fake_data(self, tmpdir, config): + categories = [ + "00000_Akingdom_0phylum_Aclass_Aorder_Afamily_Agenus_Aspecies", + "00001_Akingdom_1phylum_Aclass_Border_Afamily_Bgenus_Aspecies", + "00002_Akingdom_2phylum_Cclass_Corder_Cfamily_Cgenus_Cspecies", + ] + + num_images_per_category = 3 + for category in categories: + datasets_utils.create_image_folder( + root=os.path.join(tmpdir, config["version"]), + name=category, + file_name_fn=lambda idx: f"image_{idx + 1:04d}.jpg", + num_examples=num_images_per_category, + ) + + return num_images_per_category * len(categories) + + def test_targets(self): + target_types = ["kingdom", "phylum", "class", "order", "family", "genus", "full"] + + with self.create_dataset(target_type=target_types, version="2021_valid") as (dataset, _): + items = [d[1] for d in dataset] + for i, item in enumerate(items): + assert dataset.category_name("kingdom", item[0]) == "Akingdom" + assert dataset.category_name("phylum", item[1]) == f"{i // 3}phylum" + assert item[6] == i // 3 + + +class LFWPeopleTestCase(datasets_utils.DatasetTestCase): + DATASET_CLASS = datasets.LFWPeople + FEATURE_TYPES = (PIL.Image.Image, int) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled") + ) + _IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"} + _file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"} + + def inject_fake_data(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) / "lfw-py" + os.makedirs(tmpdir, exist_ok=True) + return dict( + num_examples=self._create_images_dir(tmpdir, self._IMAGES_DIR[config["image_set"]], config["split"]), + split=config["split"], + ) + + def _create_images_dir(self, root, idir, split): + idir = os.path.join(root, idir) + os.makedirs(idir, exist_ok=True) + n, flines = (10, ["10\n"]) if split == "10fold" else (1, []) + num_examples = 0 + names = [] + for _ in range(n): + num_people = random.randint(2, 5) + flines.append(f"{num_people}\n") + for i in range(num_people): + name = self._create_random_id() + no = random.randint(1, 10) + 
flines.append(f"{name}\t{no}\n")
+                names.append(f"{name}\t{no}\n")
+                datasets_utils.create_image_folder(idir, name, lambda n: f"{name}_{n+1:04d}.jpg", no, 250)
+                num_examples += no
+        with open(pathlib.Path(root) / f"people{self._file_id[split]}.txt", "w") as f:
+            f.writelines(flines)
+        with open(pathlib.Path(root) / "lfw-names.txt", "w") as f:
+            f.writelines(sorted(names))
+
+        return num_examples
+
+    def _create_random_id(self):
+        part1 = datasets_utils.create_random_string(random.randint(5, 7))
+        part2 = datasets_utils.create_random_string(random.randint(4, 7))
+        return f"{part1}_{part2}"
+
+
+class LFWPairsTestCase(LFWPeopleTestCase):
+    DATASET_CLASS = datasets.LFWPairs
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, int)
+
+    def _create_images_dir(self, root, idir, split):
+        idir = os.path.join(root, idir)
+        os.makedirs(idir, exist_ok=True)
+        num_pairs = 7  # effectively 7*2*n = 14*n
+        n, self.flines = (10, [f"10\t{num_pairs}"]) if split == "10fold" else (1, [str(num_pairs)])
+        for _ in range(n):
+            self._inject_pairs(idir, num_pairs, True)
+            self._inject_pairs(idir, num_pairs, False)
+        with open(pathlib.Path(root) / f"pairs{self._file_id[split]}.txt", "w") as f:
+            f.writelines(self.flines)
+
+        return num_pairs * 2 * n
+
+    def _inject_pairs(self, root, num_pairs, same):
+        for i in range(num_pairs):
+            name1 = self._create_random_id()
+            name2 = name1 if same else self._create_random_id()
+            no1, no2 = random.randint(1, 100), random.randint(1, 100)
+            if same:
+                self.flines.append(f"\n{name1}\t{no1}\t{no2}")
+            else:
+                self.flines.append(f"\n{name1}\t{no1}\t{name2}\t{no2}")
+
+            datasets_utils.create_image_folder(root, name1, lambda _: f"{name1}_{no1:04d}.jpg", 1, 250)
+            datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250)
+
+
+class SintelTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.Sintel
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final", "both"))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "Sintel"
+
+        num_images_per_scene = 3 if config["split"] == "train" else 4
+        num_scenes = 2
+
+        for split_dir in ("training", "test"):
+            for pass_name in ("clean", "final"):
+                image_root = root / split_dir / pass_name
+
+                for scene_id in range(num_scenes):
+                    scene_dir = image_root / f"scene_{scene_id}"
+                    datasets_utils.create_image_folder(
+                        image_root,
+                        name=str(scene_dir),
+                        file_name_fn=lambda image_idx: f"frame_000{image_idx}.png",
+                        num_examples=num_images_per_scene,
+                    )
+
+        flow_root = root / "training" / "flow"
+        for scene_id in range(num_scenes):
+            scene_dir = flow_root / f"scene_{scene_id}"
+            os.makedirs(scene_dir)
+            for i in range(num_images_per_scene - 1):
+                file_name = str(scene_dir / f"frame_000{i}.flo")
+                datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
+
+        # With e.g. num_images_per_scene = 3, a single scene has the 3 images frame_0000, frame_0001 and frame_0002.
+        # They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002),
+        # that is 3 - 1 = 2 examples. Hence the formula below.
+        num_passes = 2 if config["pass_name"] == "both" else 1
+        num_examples = (num_images_per_scene - 1) * num_scenes * num_passes
+        return num_examples
+
+    def test_flow(self):
+        # Make sure flow exists for the train split, and make sure there are as many flow values as (pairs of) images
+        h, w = self.FLOW_H, self.FLOW_W
+        expected_flow = np.arange(2 * h * w).reshape(h, w, 2).transpose(2, 0, 1)
+        with self.create_dataset(split="train") as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow in dataset:
+                assert flow.shape == (2, h, w)
+                np.testing.assert_allclose(flow, expected_flow)
+
+        # Make sure flow is always None for the test split
+        with self.create_dataset(split="test") as (dataset, _):
+            assert dataset._image_list and not dataset._flow_list
+            for _, _, flow in dataset:
+                assert flow is None
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
+            with self.create_dataset(split="bad"):
+                pass
+
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
+            with self.create_dataset(pass_name="bad"):
+                pass
+
+
+class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.KittiFlow
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "KittiFlow"
+
+        num_examples = 2 if config["split"] == "train" else 3
+        for split_dir in ("training", "testing"):
+
+            datasets_utils.create_image_folder(
+                root / split_dir,
+                name="image_2",
+                file_name_fn=lambda image_idx: f"{image_idx}_10.png",
+                num_examples=num_examples,
+            )
+            datasets_utils.create_image_folder(
+                root / split_dir,
+                name="image_2",
+                file_name_fn=lambda image_idx: f"{image_idx}_11.png",
+                num_examples=num_examples,
+            )
+
+        # For Kitti the ground-truth flows are encoded as 16-bit pngs.
+        # create_image_folder() will actually create 8-bit pngs, but it doesn't
+        # matter much: the flow reader will still be able to read the files, they
+        # will just contain garbage flow values - but we don't care about that here.
+        datasets_utils.create_image_folder(
+            root / "training",
+            name="flow_occ",
+            file_name_fn=lambda image_idx: f"{image_idx}_10.png",
+            num_examples=num_examples,
+        )
+
+        return num_examples
+
+    def test_flow_and_valid(self):
+        # Make sure flow exists for the train split, and make sure there are as many flow values as (pairs of) images
+        # Also assert that flow and valid are of the expected shape
+        with self.create_dataset(split="train") as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow, valid in dataset:
+                two, h, w = flow.shape
+                assert two == 2
+                assert valid.shape == (h, w)
+
+        # Make sure flow and valid are always None for the test split
+        with self.create_dataset(split="test") as (dataset, _):
+            assert dataset._image_list and not dataset._flow_list
+            for _, _, flow, valid in dataset:
+                assert flow is None
+                assert valid is None
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
+            with self.create_dataset(split="bad"):
+                pass
+
+
+class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.FlyingChairs
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
+
+    def _make_split_file(self, root, num_examples):
+        # We create a fake split file here, but users are asked to download the real one from the authors' website
+        split_ids = [1] * num_examples["train"] + [2] * num_examples["val"]
+        random.shuffle(split_ids)
+        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
+            for split_id in split_ids:
+                split_file.write(f"{split_id}\n")
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "FlyingChairs"
+
+        num_examples = {"train": 5, "val": 3}
+        num_examples_total = sum(num_examples.values())
+
+        datasets_utils.create_image_folder(  # img1
+            root,
+            name="data",
+            file_name_fn=lambda image_idx: f"00{image_idx}_img1.ppm",
+            num_examples=num_examples_total,
+        )
+        datasets_utils.create_image_folder(  # img2
+            root,
+            name="data",
+            file_name_fn=lambda image_idx: f"00{image_idx}_img2.ppm",
+            num_examples=num_examples_total,
+        )
+        for i in range(num_examples_total):
+            file_name = str(root / "data" / f"00{i}_flow.flo")
+            datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
+
+        self._make_split_file(root, num_examples)
+
+        return num_examples[config["split"]]
+
+    @datasets_utils.test_all_configs
+    def test_flow(self, config):
+        # Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
+        # Also make sure the flow is properly decoded
+
+        h, w = self.FLOW_H, self.FLOW_W
+        expected_flow = np.arange(2 * h * w).reshape(h, w, 2).transpose(2, 0, 1)
+        with self.create_dataset(config=config) as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow in dataset:
+                assert flow.shape == (2, h, w)
+                np.testing.assert_allclose(flow, expected_flow)
+
+
+class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.FlyingThings3D
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
+    )
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
+
+    def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "FlyingThings3D" + + num_images_per_camera = 3 if config["split"] == "train" else 4 + passes = ("frames_cleanpass", "frames_finalpass") + splits = ("TRAIN", "TEST") + letters = ("A", "B", "C") + subfolders = ("0000", "0001") + cameras = ("left", "right") + for pass_name, split, letter, subfolder, camera in itertools.product( + passes, splits, letters, subfolders, cameras + ): + current_folder = root / pass_name / split / letter / subfolder + datasets_utils.create_image_folder( + current_folder, + name=camera, + file_name_fn=lambda image_idx: f"00{image_idx}.png", + num_examples=num_images_per_camera, + ) + + directions = ("into_future", "into_past") + for split, letter, subfolder, direction, camera in itertools.product( + splits, letters, subfolders, directions, cameras + ): + current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera + os.makedirs(str(current_folder), exist_ok=True) + for i in range(num_images_per_camera): + datasets_utils.make_fake_pfm_file(self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm")) + + num_cameras = 2 if config["camera"] == "both" else 1 + num_passes = 2 if config["pass_name"] == "both" else 1 + num_examples = ( + (num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(splits) * num_passes + ) + return num_examples + + @datasets_utils.test_all_configs + def test_flow(self, config): + h, w = self.FLOW_H, self.FLOW_W + expected_flow = np.arange(3 * h * w).reshape(h, w, 3).transpose(2, 0, 1) + expected_flow = np.flip(expected_flow, axis=1) + expected_flow = expected_flow[:2, :, :] + + with self.create_dataset(config=config) as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow in dataset: + assert flow.shape == (2, self.FLOW_H, self.FLOW_W) + np.testing.assert_allclose(flow, expected_flow) + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"): + with self.create_dataset(pass_name="bad"): + pass + + with pytest.raises(ValueError, match="Unknown value 'bad' for argument camera"): + with self.create_dataset(camera="bad"): + pass + + +class HD1KTestCase(KittiFlowTestCase): + DATASET_CLASS = datasets.HD1K + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "hd1k" + + num_sequences = 4 if config["split"] == "train" else 3 + num_examples_per_train_sequence = 3 + + for seq_idx in range(num_sequences): + # Training data + datasets_utils.create_image_folder( + root / "hd1k_input", + name="image_2", + file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png", + num_examples=num_examples_per_train_sequence, + ) + datasets_utils.create_image_folder( + root / "hd1k_flow_gt", + name="flow_occ", + file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png", + num_examples=num_examples_per_train_sequence, + ) + + # Test data + datasets_utils.create_image_folder( + root / "hd1k_challenge", + name="image_2", + file_name_fn=lambda _: f"{seq_idx:06d}_10.png", + num_examples=1, + ) + datasets_utils.create_image_folder( + root / "hd1k_challenge", + name="image_2", + file_name_fn=lambda _: f"{seq_idx:06d}_11.png", + num_examples=1, + ) + + num_examples_per_sequence = num_examples_per_train_sequence if config["split"] == "train" else 2 + return num_sequences * (num_examples_per_sequence - 1) + + 
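As a cross-check of the example-counting logic shared by the flow test cases above, the expected totals can be spelled out as plain arithmetic. This is a sketch; the constants mirror the fake fixtures, not the real datasets:

```python
# Sintel, split="train", pass_name="both": consecutive frame pairs per scene,
# doubled across the clean and final passes.
num_images_per_scene, num_scenes, num_passes = 3, 2, 2
assert (num_images_per_scene - 1) * num_scenes * num_passes == 8

# HD1K, split="train": every sequence contributes consecutive frame pairs.
num_sequences, num_examples_per_sequence = 4, 3
assert num_sequences * (num_examples_per_sequence - 1) == 8
```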
+class EuroSATTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.EuroSAT + FEATURE_TYPES = (PIL.Image.Image, int) + + def inject_fake_data(self, tmpdir, config): + data_folder = os.path.join(tmpdir, "eurosat", "2750") + os.makedirs(data_folder) + + num_examples_per_class = 3 + classes = ("AnnualCrop", "Forest") + for cls in classes: + datasets_utils.create_image_folder( + root=data_folder, + name=cls, + file_name_fn=lambda idx: f"{cls}_{idx}.jpg", + num_examples=num_examples_per_class, + ) + + return len(classes) * num_examples_per_class + + +class Food101TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Food101 + FEATURE_TYPES = (PIL.Image.Image, int) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + + def inject_fake_data(self, tmpdir: str, config): + root_folder = pathlib.Path(tmpdir) / "food-101" + image_folder = root_folder / "images" + meta_folder = root_folder / "meta" + + image_folder.mkdir(parents=True) + meta_folder.mkdir() + + num_images_per_class = 5 + + metadata = {} + n_samples_per_class = 3 if config["split"] == "train" else 2 + sampled_classes = ("apple_pie", "crab_cakes", "gyoza") + for cls in sampled_classes: + im_fnames = datasets_utils.create_image_folder( + image_folder, + cls, + file_name_fn=lambda idx: f"{idx}.jpg", + num_examples=num_images_per_class, + ) + metadata[cls] = [ + "/".join(fname.relative_to(image_folder).with_suffix("").parts) + for fname in random.choices(im_fnames, k=n_samples_per_class) + ] + + with open(meta_folder / f"{config['split']}.json", "w") as file: + file.write(json.dumps(metadata)) + + return len(sampled_classes * n_samples_per_class) + + +class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.FGVCAircraft + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer") + ) + + def inject_fake_data(self, tmpdir: str, config): + split = config["split"] + annotation_level = config["annotation_level"] + annotation_level_to_file = { + "variant": "variants.txt", + "family": "families.txt", + "manufacturer": "manufacturers.txt", + } + + root_folder = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b" + data_folder = root_folder / "data" + + classes = ["707-320", "Hawk T1", "Tornado"] + num_images_per_class = 5 + + datasets_utils.create_image_folder( + data_folder, + "images", + file_name_fn=lambda idx: f"{idx}.jpg", + num_examples=num_images_per_class * len(classes), + ) + + annotation_file = data_folder / annotation_level_to_file[annotation_level] + with open(annotation_file, "w") as file: + file.write("\n".join(classes)) + + num_samples_per_class = 4 if split == "trainval" else 2 + images_classes = [] + for i in range(len(classes)): + images_classes.extend( + [ + f"{idx} {classes[i]}" + for idx in random.sample( + range(i * num_images_per_class, (i + 1) * num_images_per_class), num_samples_per_class + ) + ] + ) + + images_annotation_file = data_folder / f"images_{annotation_level}_{split}.txt" + with open(images_annotation_file, "w") as file: + file.write("\n".join(images_classes)) + + return len(classes * num_samples_per_class) + + +class SUN397TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.SUN397 + + def inject_fake_data(self, tmpdir: str, config): + data_dir = pathlib.Path(tmpdir) / "SUN397" + data_dir.mkdir() + + num_images_per_class = 5 + sampled_classes = ("abbey", "airplane_cabin", "airport_terminal") + 
im_paths = [] + + for cls in sampled_classes: + image_folder = data_dir / cls[0] + im_paths.extend( + datasets_utils.create_image_folder( + image_folder, + image_folder / cls, + file_name_fn=lambda idx: f"sun_{idx}.jpg", + num_examples=num_images_per_class, + ) + ) + + with open(data_dir / "ClassName.txt", "w") as file: + file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes)) + + num_samples = len(im_paths) + + return num_samples + + +class DTDTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.DTD + FEATURE_TYPES = (PIL.Image.Image, int) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("train", "test", "val"), + # There is no need to test the whole matrix here, since each fold is treated exactly the same + partition=(1, 5, 10), + ) + + def inject_fake_data(self, tmpdir: str, config): + data_folder = pathlib.Path(tmpdir) / "dtd" / "dtd" + + num_images_per_class = 3 + image_folder = data_folder / "images" + image_files = [] + for cls in ("banded", "marbled", "zigzagged"): + image_files.extend( + datasets_utils.create_image_folder( + image_folder, + cls, + file_name_fn=lambda idx: f"{cls}_{idx:04d}.jpg", + num_examples=num_images_per_class, + ) + ) + + meta_folder = data_folder / "labels" + meta_folder.mkdir() + image_ids = [str(path.relative_to(path.parents[1])).replace(os.sep, "/") for path in image_files] + image_ids_in_config = random.choices(image_ids, k=len(image_files) // 2) + with open(meta_folder / f"{config['split']}{config['partition']}.txt", "w") as file: + file.write("\n".join(image_ids_in_config) + "\n") + + return len(image_ids_in_config) + + +class FER2013TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.FER2013 + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + + FEATURE_TYPES = (PIL.Image.Image, (int, type(None))) + + def inject_fake_data(self, tmpdir, config): + base_folder = os.path.join(tmpdir, "fer2013") + os.makedirs(base_folder) + + num_samples = 5 + with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file: + writer = csv.DictWriter( + file, + fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",), + quoting=csv.QUOTE_NONNUMERIC, + quotechar='"', + ) + writer.writeheader() + for _ in range(num_samples): + row = dict( + pixels=" ".join( + str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist() + ) + ) + if config["split"] == "train": + row["emotion"] = str(int(torch.randint(0, 7, ()))) + + writer.writerow(row) + + return num_samples + + +class GTSRBTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.GTSRB + FEATURE_TYPES = (PIL.Image.Image, int) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + + def inject_fake_data(self, tmpdir: str, config): + root_folder = os.path.join(tmpdir, "gtsrb") + os.makedirs(root_folder, exist_ok=True) + + # Train data + train_folder = os.path.join(root_folder, "GTSRB", "Training") + os.makedirs(train_folder, exist_ok=True) + + num_examples = 3 if config["split"] == "train" else 4 + classes = ("00000", "00042", "00012") + for class_idx in classes: + datasets_utils.create_image_folder( + train_folder, + name=class_idx, + file_name_fn=lambda image_idx: f"{class_idx}_{image_idx:05d}.ppm", + num_examples=num_examples, + ) + + total_number_of_examples = num_examples * len(classes) + # Test data + test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images") + 
os.makedirs(test_folder, exist_ok=True) + + with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file: + csv_file.write("Filename;Width;Height;Roi.X1;Roi.Y1;Roi.X2;Roi.Y2;ClassId\n") + + for _ in range(total_number_of_examples): + image_file = datasets_utils.create_random_string(5, string.digits) + ".ppm" + datasets_utils.create_image_file(test_folder, image_file) + row = [ + image_file, + torch.randint(1, 100, size=()).item(), + torch.randint(1, 100, size=()).item(), + torch.randint(1, 100, size=()).item(), + torch.randint(1, 100, size=()).item(), + torch.randint(1, 100, size=()).item(), + torch.randint(1, 100, size=()).item(), + torch.randint(0, 43, size=()).item(), + ] + csv_file.write(";".join(map(str, row)) + "\n") + + return total_number_of_examples + + +class CLEVRClassificationTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.CLEVRClassification + FEATURE_TYPES = (PIL.Image.Image, (int, type(None))) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) + + def inject_fake_data(self, tmpdir, config): + data_folder = pathlib.Path(tmpdir) / "clevr" / "CLEVR_v1.0" + + images_folder = data_folder / "images" + image_files = datasets_utils.create_image_folder( + images_folder, config["split"], lambda idx: f"CLEVR_{config['split']}_{idx:06d}.png", num_examples=5 + ) + + scenes_folder = data_folder / "scenes" + scenes_folder.mkdir() + if config["split"] != "test": + with open(scenes_folder / f"CLEVR_{config['split']}_scenes.json", "w") as file: + json.dump( + dict( + info=dict(), + scenes=[ + dict(image_filename=image_file.name, objects=[dict()] * int(torch.randint(0, 10, ()))) + for image_file in image_files + ], + ), + file, + ) + + return len(image_files) + + +class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.OxfordIIITPet + FEATURE_TYPES = (PIL.Image.Image, (int, PIL.Image.Image, tuple, type(None))) + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( + split=("trainval", "test"), + target_types=("category", "segmentation", ["category", "segmentation"], []), + ) + + def inject_fake_data(self, tmpdir, config): + base_folder = os.path.join(tmpdir, "oxford-iiit-pet") + + classification_anns_meta = ( + dict(cls="Abyssinian", label=0, species="cat"), + dict(cls="Keeshond", label=18, species="dog"), + dict(cls="Yorkshire Terrier", label=37, species="dog"), + ) + split_and_classification_anns = [ + self._meta_to_split_and_classification_ann(meta, idx) + for meta, idx in itertools.product(classification_anns_meta, (1, 2, 10)) + ] + image_ids, *_ = zip(*split_and_classification_anns) + + image_files = datasets_utils.create_image_folder( + base_folder, "images", file_name_fn=lambda idx: f"{image_ids[idx]}.jpg", num_examples=len(image_ids) + ) + + anns_folder = os.path.join(base_folder, "annotations") + os.makedirs(anns_folder) + split_and_classification_anns_in_split = random.choices(split_and_classification_anns, k=len(image_ids) // 2) + with open(os.path.join(anns_folder, f"{config['split']}.txt"), "w", newline="") as file: + writer = csv.writer(file, delimiter=" ") + for split_and_classification_ann in split_and_classification_anns_in_split: + writer.writerow(split_and_classification_ann) + + segmentation_files = datasets_utils.create_image_folder( + anns_folder, "trimaps", file_name_fn=lambda idx: f"{image_ids[idx]}.png", num_examples=len(image_ids) + ) + + # The dataset has some rogue files + for path in image_files[:2]: + path.with_suffix(".mat").touch() + for path 
in segmentation_files: + path.with_name(f".{path.name}").touch() + + return len(split_and_classification_anns_in_split) + + def _meta_to_split_and_classification_ann(self, meta, idx): + image_id = "_".join( + [ + *[(str.title if meta["species"] == "cat" else str.lower)(part) for part in meta["cls"].split()], + str(idx), + ] + ) + class_id = str(meta["label"] + 1) + species = "1" if meta["species"] == "cat" else "2" + breed_id = "-1" + return (image_id, class_id, species, breed_id) + + +class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.StanfordCars + REQUIRED_PACKAGES = ("scipy",) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + + def inject_fake_data(self, tmpdir, config): + import scipy.io as io + from numpy.core.records import fromarrays + + num_examples = {"train": 5, "test": 7}[config["split"]] + num_classes = 3 + base_folder = pathlib.Path(tmpdir) / "stanford_cars" + + devkit = base_folder / "devkit" + devkit.mkdir(parents=True) + + if config["split"] == "train": + images_folder_name = "cars_train" + annotations_mat_path = devkit / "cars_train_annos.mat" + else: + images_folder_name = "cars_test" + annotations_mat_path = base_folder / "cars_test_annos_withlabels.mat" + + datasets_utils.create_image_folder( + root=base_folder, + name=images_folder_name, + file_name_fn=lambda image_index: f"{image_index:5d}.jpg", + num_examples=num_examples, + ) + + classes = np.random.randint(1, num_classes + 1, num_examples, dtype=np.uint8) + fnames = [f"{i:5d}.jpg" for i in range(num_examples)] + rec_array = fromarrays( + [classes, fnames], + names=["class", "fname"], + ) + io.savemat(annotations_mat_path, {"annotations": rec_array}) + + random_class_names = ["random_name"] * num_classes + io.savemat(devkit / "cars_meta.mat", {"class_names": random_class_names}) + + return num_examples + + +class Country211TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Country211 + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test")) + + def inject_fake_data(self, tmpdir: str, config): + split_folder = pathlib.Path(tmpdir) / "country211" / config["split"] + split_folder.mkdir(parents=True, exist_ok=True) + + num_examples = { + "train": 3, + "valid": 4, + "test": 5, + }[config["split"]] + + classes = ("AD", "BS", "GR") + for cls in classes: + datasets_utils.create_image_folder( + split_folder, + name=cls, + file_name_fn=lambda idx: f"{idx}.jpg", + num_examples=num_examples, + ) + + return num_examples * len(classes) + + +class Flowers102TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.Flowers102 + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) + REQUIRED_PACKAGES = ("scipy",) + + def inject_fake_data(self, tmpdir: str, config): + base_folder = pathlib.Path(tmpdir) / "flowers-102" + + num_classes = 3 + num_images_per_split = dict(train=5, val=4, test=3) + num_images_total = sum(num_images_per_split.values()) + datasets_utils.create_image_folder( + base_folder, + "jpg", + file_name_fn=lambda idx: f"image_{idx + 1:05d}.jpg", + num_examples=num_images_total, + ) + + label_dict = dict( + labels=np.random.randint(1, num_classes + 1, size=(1, num_images_total), dtype=np.uint8), + ) + datasets_utils.lazy_importer.scipy.io.savemat(str(base_folder / "imagelabels.mat"), label_dict) + + setid_mat = np.arange(1, num_images_total + 1, dtype=np.uint16) + np.random.shuffle(setid_mat) + setid_dict = dict( + trnid=setid_mat[: 
num_images_per_split["train"]].reshape(1, -1), + valid=setid_mat[num_images_per_split["train"] : -num_images_per_split["test"]].reshape(1, -1), + tstid=setid_mat[-num_images_per_split["test"] :].reshape(1, -1), + ) + datasets_utils.lazy_importer.scipy.io.savemat(str(base_folder / "setid.mat"), setid_dict) + + return num_images_per_split[config["split"]] + + +class PCAMTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.PCAM + + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) + REQUIRED_PACKAGES = ("h5py",) + + def inject_fake_data(self, tmpdir: str, config): + base_folder = pathlib.Path(tmpdir) / "pcam" + base_folder.mkdir() + + num_images = {"train": 2, "test": 3, "val": 4}[config["split"]] + + images_file = datasets.PCAM._FILES[config["split"]]["images"][0] + with datasets_utils.lazy_importer.h5py.File(str(base_folder / images_file), "w") as f: + f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8) + + targets_file = datasets.PCAM._FILES[config["split"]]["targets"][0] + with datasets_utils.lazy_importer.h5py.File(str(base_folder / targets_file), "w") as f: + f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8) + + return num_images + + +class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.RenderedSST2 + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) + SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"} + + def inject_fake_data(self, tmpdir: str, config): + root_folder = pathlib.Path(tmpdir) / "rendered-sst2" + image_folder = root_folder / self.SPLIT_TO_FOLDER[config["split"]] + + num_images_per_class = {"train": 5, "test": 6, "val": 7} + sampled_classes = ["positive", "negative"] + for cls in sampled_classes: + datasets_utils.create_image_folder( + image_folder, + cls, + file_name_fn=lambda idx: f"{idx}.png", + num_examples=num_images_per_class[config["split"]], + ) + + return len(sampled_classes) * num_images_per_class[config["split"]] + +if __name__ == "__main__": + unittest.main() diff --git a/testing/ut/torchvision/test_datasets_samplers.py b/testing/ut/torchvision/test_datasets_samplers.py new file mode 100644 index 00000000..08d4818c --- /dev/null +++ b/testing/ut/torchvision/test_datasets_samplers.py @@ -0,0 +1,87 @@ +import numpy as np +import pytest +import ms_adapter.pytorch as torch +from common_utils import get_list_of_videos +from ms_adapter.torchvision import io +from ms_adapter.torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler +from ms_adapter.torchvision.datasets.video_utils import VideoClips + + +@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") +class TestDatasetsSamplers: + def test_random_clip_sampler(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = RandomClipSampler(video_clips, 3) + assert len(sampler) == 3 * 3 + indices = torch.tensor(list(iter(sampler))) + videos = torch.div(indices, 5, rounding_mode="floor").numpy() + v_idxs, count = np.unique(videos, return_counts=True) + assert np.allclose(v_idxs, torch.tensor([0, 1, 2]).numpy()) + assert np.allclose(count, torch.tensor([3, 3, 3]).numpy()) + + def test_random_clip_sampler_unequal(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = 
RandomClipSampler(video_clips, 3) + assert len(sampler) == 2 + 3 + 3 + indices = list(iter(sampler)) + assert 0 in indices + assert 1 in indices + # remove elements of the first video, to simplify testing + indices.remove(0) + indices.remove(1) + indices = torch.tensor(indices) - 2 + videos = torch.div(indices, 5, rounding_mode="floor").numpy() + v_idxs, count = np.unique(videos, return_counts=True) + assert np.allclose(v_idxs, torch.tensor([0, 1]).numpy()) + assert np.allclose(count, torch.tensor([3, 3]).numpy()) + + def test_uniform_clip_sampler(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = UniformClipSampler(video_clips, 3) + assert len(sampler) == 3 * 3 + indices = torch.tensor(list(iter(sampler))) + videos = torch.div(indices, 5, rounding_mode="floor").numpy() + v_idxs, count = np.unique(videos, return_counts=True) + assert np.allclose(v_idxs, torch.tensor([0, 1, 2]).numpy()) + assert np.allclose(count, torch.tensor([3, 3, 3]).numpy()) + assert np.allclose(indices.numpy(), torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]).numpy()) + + def test_uniform_clip_sampler_insufficient_clips(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + sampler = UniformClipSampler(video_clips, 3) + assert len(sampler) == 3 * 3 + indices = torch.tensor(list(iter(sampler))) + assert np.allclose(indices.numpy(), torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]).numpy()) + + def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25]) + video_clips = VideoClips(video_list, 5, 5) + clip_sampler = UniformClipSampler(video_clips, 3) + + distributed_sampler_rank0 = DistributedSampler( + clip_sampler, + num_replicas=2, + rank=0, + group_size=3, + ) + indices = torch.tensor(list(iter(distributed_sampler_rank0))) + assert len(distributed_sampler_rank0) == 6 + assert np.allclose(indices.numpy(), torch.tensor([0, 2, 4, 10, 12, 14]).numpy()) + + distributed_sampler_rank1 = DistributedSampler( + clip_sampler, + num_replicas=2, + rank=1, + group_size=3, + ) + indices = torch.tensor(list(iter(distributed_sampler_rank1))) + assert len(distributed_sampler_rank1) == 6 + assert np.allclose(indices.numpy(), torch.tensor([5, 7, 9, 0, 2, 4]).numpy()) + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/testing/ut/torchvision/test_datasets_utils.py b/testing/ut/torchvision/test_datasets_utils.py new file mode 100644 index 00000000..789828d2 --- /dev/null +++ b/testing/ut/torchvision/test_datasets_utils.py @@ -0,0 +1,246 @@ +import contextlib +import gzip +import os +import pathlib +import re +import tarfile +import zipfile + +import pytest +import ms_adapter.torchvision.datasets.utils as utils +from ms_adapter.torchvision.datasets.folder import make_dataset +from ms_adapter.torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS + +def get_file_path_2(*path_components: str) -> str: + return os.path.join(*path_components) + +TEST_FILE = get_file_path_2( + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" +) + + +def patch_url_redirection(mocker, redirect_url): + class Response: + def __init__(self, url): + self.url = url + + @contextlib.contextmanager + def patched_opener(*args, **kwargs): + yield Response(redirect_url) + + return 
mocker.patch("torchvision.datasets.utils.urllib.request.urlopen", side_effect=patched_opener) + + +class TestDatasetsUtils: + # def test_get_redirect_url(self, mocker): #TODO pytorch error, too + # url = "https://url.org" + # expected_redirect_url = "https://redirect.url.org" + # + # mock = patch_url_redirection(mocker, expected_redirect_url) + # + # actual = utils._get_redirect_url(url) + # assert actual == expected_redirect_url + # + # assert mock.call_count == 2 + # call_args_1, call_args_2 = mock.call_args_list + # assert call_args_1[0][0].full_url == url + # assert call_args_2[0][0].full_url == expected_redirect_url + # + # def test_get_redirect_url_max_hops_exceeded(self, mocker): + # url = "https://url.org" + # redirect_url = "https://redirect.url.org" + # + # mock = patch_url_redirection(mocker, redirect_url) + # + # with pytest.raises(RecursionError): + # utils._get_redirect_url(url, max_hops=0) + # + # assert mock.call_count == 1 + # assert mock.call_args[0][0].full_url == url + + def test_check_md5(self): + fpath = TEST_FILE + correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" + false_md5 = "" + assert utils.check_md5(fpath, correct_md5) + assert not utils.check_md5(fpath, false_md5) + + def test_check_integrity(self): + existing_fpath = TEST_FILE + nonexisting_fpath = "" + correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" + false_md5 = "" + assert utils.check_integrity(existing_fpath, correct_md5) + assert not utils.check_integrity(existing_fpath, false_md5) + assert utils.check_integrity(existing_fpath) + assert not utils.check_integrity(nonexisting_fpath) + + def test_get_google_drive_file_id(self): + url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view" + expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV" + + actual = utils._get_google_drive_file_id(url) + assert actual == expected + + def test_get_google_drive_file_id_invalid_url(self): + url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz" + + assert utils._get_google_drive_file_id(url) is None + + @pytest.mark.parametrize( + "file, expected", + [ + ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), + ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), + ("foo.tar", (".tar", ".tar", None)), + ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.tbz", (".tbz", ".tar", ".bz2")), + ("foo.tbz2", (".tbz2", ".tar", ".bz2")), + ("foo.tgz", (".tgz", ".tar", ".gz")), + ("foo.bz2", (".bz2", None, ".bz2")), + ("foo.gz", (".gz", None, ".gz")), + ("foo.zip", (".zip", ".zip", None)), + ("foo.xz", (".xz", None, ".xz")), + ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.bar.gz", (".gz", None, ".gz")), + ("foo.bar.zip", (".zip", ".zip", None)), + ], + ) + def test_detect_file_type(self, file, expected): + assert utils._detect_file_type(file) == expected + + @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"]) + def test_detect_file_type_incompatible(self, file): + # tests detect file type for no extension, unknown compression and unknown partial extension + with pytest.raises(RuntimeError): + utils._detect_file_type(file) + + @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"]) + def test_decompress(self, extension, tmpdir): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}{extension}" + compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension] + + with compressed_file_opener(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + + compressed, file, content = 
create_compressed(tmpdir) + + utils._decompress(compressed) + + assert os.path.exists(file) + + with open(file) as fh: + assert fh.read() == content + + def test_decompress_no_compression(self): + with pytest.raises(RuntimeError): + utils._decompress("foo.tar") + + def test_decompress_remove_finished(self, tmpdir): + def create_compressed(root, content="this is the content"): + file = os.path.join(root, "file") + compressed = f"{file}.gz" + + with gzip.open(compressed, "wb") as fh: + fh.write(content.encode()) + + return compressed, file, content + + compressed, file, content = create_compressed(tmpdir) + + utils.extract_archive(compressed, tmpdir, remove_finished=True) + + assert not os.path.exists(compressed) + + # @pytest.mark.parametrize("extension", [".gz", ".xz"]) + # @pytest.mark.parametrize("remove_finished", [True, False]) + # def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker): + # filename = "foo" + # file = f"{filename}{extension}" + # + # mocked = mocker.patch("torchvision.datasets.utils._decompress") + # utils.extract_archive(file, remove_finished=remove_finished) + # + # mocked.assert_called_once_with(file, filename, remove_finished=remove_finished) + + def test_extract_zip(self, tmpdir): + def create_archive(root, content="this is the content"): + file = os.path.join(root, "dst.txt") + archive = os.path.join(root, "archive.zip") + + with zipfile.ZipFile(archive, "w") as zf: + zf.writestr(os.path.basename(file), content) + + return archive, file, content + + archive, file, content = create_archive(tmpdir) + + utils.extract_archive(archive, tmpdir) + + assert os.path.exists(file) + + with open(file) as fh: + assert fh.read() == content + + @pytest.mark.parametrize( + "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")] + ) + def test_extract_tar(self, extension, mode, tmpdir): + def create_archive(root, extension, mode, content="this is the content"): + src = os.path.join(root, "src.txt") + dst = os.path.join(root, "dst.txt") + archive = os.path.join(root, f"archive{extension}") + + with open(src, "w") as fh: + fh.write(content) + + with tarfile.open(archive, mode=mode) as fh: + fh.add(src, arcname=os.path.basename(dst)) + + return archive, dst, content + + archive, file, content = create_archive(tmpdir, extension, mode) + + utils.extract_archive(archive, tmpdir) + + assert os.path.exists(file) + + with open(file) as fh: + assert fh.read() == content + + def test_verify_str_arg(self): + assert "a" == utils.verify_str_arg("a", "arg", ("a",)) + pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg") + pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") + + +@pytest.mark.parametrize( + ("kwargs", "expected_error_msg"), + [ + (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), + (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")), + (dict(extensions=(".png", ".jpeg")), re.escape("classes c. 
Supported extensions are: .png, .jpeg")), + ], +) +def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg): + tmpdir = pathlib.Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "a" / "a.png").touch() + + (tmpdir / "b").mkdir() + (tmpdir / "b" / "b.jpeg").touch() + + (tmpdir / "c").mkdir() + (tmpdir / "c" / "c.unknown").touch() + + with pytest.raises(FileNotFoundError, match=expected_error_msg): + make_dataset(str(tmpdir), **kwargs) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/testing/ut/torchvision/test_datasets_video_utils.py b/testing/ut/torchvision/test_datasets_video_utils.py new file mode 100644 index 00000000..2281b60a --- /dev/null +++ b/testing/ut/torchvision/test_datasets_video_utils.py @@ -0,0 +1,105 @@ +import pytest +import numpy as np +import ms_adapter.pytorch as torch +from common_utils import get_list_of_videos +from ms_adapter.torchvision import io +from ms_adapter.torchvision.datasets.video_utils import unfold, VideoClips + + +class TestVideo: + def test_unfold(self): + a = torch.tensor([0,1,2,3,4,5,6]) + r = unfold(a, 3, 3, 1) + expected = torch.tensor( + [ + [0, 1, 2], + [3, 4, 5], + ] + ) + assert np.allclose(r.numpy(), expected.numpy()) + + r = unfold(a, 3, 2, 1) + expected = torch.tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]]) + assert np.allclose(r.numpy(), expected.numpy()) + + r = unfold(a, 3, 2, 2) + expected = torch.tensor( + [ + [0, 2, 4], + [2, 4, 6], + ] + ) + assert np.allclose(r.numpy(), expected.numpy()) + + @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") + def test_video_clips(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3) + video_clips = VideoClips(video_list, 5, 5, num_workers=2) + assert video_clips.num_clips() == 1 + 2 + 3 + for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]): + video_idx, clip_idx = video_clips.get_clip_location(i) + assert video_idx == v_idx + assert clip_idx == c_idx + + video_clips = VideoClips(video_list, 6, 6) + assert video_clips.num_clips() == 0 + 1 + 2 + for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]): + video_idx, clip_idx = video_clips.get_clip_location(i) + assert video_idx == v_idx + assert clip_idx == c_idx + + video_clips = VideoClips(video_list, 6, 1) + assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1) + for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]: + video_idx, clip_idx = video_clips.get_clip_location(i) + assert video_idx == v_idx + assert clip_idx == c_idx + + @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") + def test_video_clips_custom_fps(self, tmpdir): + video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) + num_frames = 4 + for fps in [1, 3, 4, 10]: + video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2) + for i in range(video_clips.num_clips()): + video, audio, info, video_idx = video_clips.get_clip(i) + assert video.shape[0] == num_frames + assert info["video_fps"] == fps + # TODO add tests checking that the content is right + + def test_compute_clips_for_video(self): + video_pts = torch.tensor(np.arange(30)) + # case 1: single clip + num_frames = 13 + orig_fps = 30 + duration = float(len(video_pts)) / orig_fps + new_fps = 13 + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) + resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps) + assert 
len(clips) == 1 + assert np.allclose(clips.numpy(), idxs.numpy()) + assert np.allclose(idxs[0].numpy(), resampled_idxs.numpy()) + + # case 2: all frames appear only once + num_frames = 4 + orig_fps = 30 + duration = float(len(video_pts)) / orig_fps + new_fps = 12 + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) + resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps) + assert len(clips) == 3 + assert np.allclose(clips.numpy(), idxs.numpy()) + assert np.allclose(idxs.flatten().numpy(), resampled_idxs.numpy()) + + # case 3: frames aren't enough for a clip + num_frames = 32 + orig_fps = 30 + new_fps = 13 + with pytest.warns(UserWarning): + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) + assert len(clips) == 0 + assert len(idxs) == 0 + + +if __name__ == "__main__": + pytest.main([__file__])
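+# NOTE (editorial): each suite in this diff is also directly runnable, e.g.:
+#     pytest testing/ut/torchvision/test_datasets_video_utils.py
+# (file path as given in this file's diff header)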