#212 add torchvision datasets

Merged
Erpim merged 4 commits from hanjr-patch into master 1 year ago
  1. ms_adapter/pytorch/utils/__init__.py (+1 -1)
  2. ms_adapter/torchvision/__init__.py (+5 -5)
  3. ms_adapter/torchvision/_internally_replaced_utils.py (+64 -54)
  4. ms_adapter/torchvision/datasets/_optical_flow.py (+5 -11)
  5. ms_adapter/torchvision/datasets/caltech.py (+12 -2)
  6. ms_adapter/torchvision/datasets/celeba.py (+42 -22)
  7. ms_adapter/torchvision/datasets/cifar.py (+3 -0)
  8. ms_adapter/torchvision/datasets/fakedata.py (+12 -11)
  9. ms_adapter/torchvision/datasets/fer2013.py (+1 -1)
  10. ms_adapter/torchvision/datasets/hmdb51.py (+1 -2)
  11. ms_adapter/torchvision/datasets/imagenet.py (+11 -6)
  12. ms_adapter/torchvision/datasets/kinetics.py (+1 -3)
  13. ms_adapter/torchvision/datasets/mnist.py (+18 -18)
  14. ms_adapter/torchvision/datasets/pcam.py (+4 -1)
  15. ms_adapter/torchvision/datasets/phototour.py (+25 -22)
  16. ms_adapter/torchvision/datasets/samplers/clip_sampler.py (+2 -2)
  17. ms_adapter/torchvision/datasets/video_utils.py (+41 -41)
  18. ms_adapter/torchvision/extension.py (+1 -4)
  19. ms_adapter/torchvision/io/__init__.py (+5 -10)
  20. ms_adapter/torchvision/io/_load_gpu_decoder.py (+7 -6)
  21. ms_adapter/torchvision/io/_video_opt.py (+8 -8)
  22. ms_adapter/torchvision/io/image.py (+51 -41)
  23. ms_adapter/torchvision/io/video.py (+9 -10)
  24. ms_adapter/torchvision/io/video_reader.py (+2 -2)
  25. testing/ut/torchvision/common_utils.py (+19 -18)
  26. testing/ut/torchvision/datasets_utils.py (+982 -0)
  27. testing/ut/torchvision/test_datasets.py (+2654 -0)
  28. testing/ut/torchvision/test_datasets_samplers.py (+87 -0)
  29. testing/ut/torchvision/test_datasets_utils.py (+246 -0)
  30. testing/ut/torchvision/test_datasets_video_utils.py (+105 -0)

ms_adapter/pytorch/utils/__init__.py (+1 -1)

@@ -1 +1 @@
# from ms_adapter.pytorch.utils import data
from ms_adapter.pytorch.utils import data

ms_adapter/torchvision/__init__.py (+5 -5)

@@ -2,12 +2,12 @@ import os
import warnings

# import torch
# from torchvision import datasets
# from torchvision import io
# from torchvision import models
# from torchvision import ops
from ms_adapter.torchvision import datasets
from ms_adapter.torchvision import io
# from ms_adapter.torchvision import models
# from ms_adapter.torchvision import ops
from ms_adapter.torchvision import transforms
# from torchvision import utils
# from ms_adapter.torchvision import utils

from .extension import _HAS_OPS



ms_adapter/torchvision/_internally_replaced_utils.py (+64 -54)

@@ -1,58 +1,68 @@
# import importlib.machinery
# import os
#
# from torch.hub import _get_torch_home
#
#
# _HOME = os.path.join(_get_torch_home(), "datasets", "vision")
# _USE_SHARDED_DATASETS = False
#
#
# def _download_file_from_remote_location(fpath: str, url: str) -> None:
#     pass
#
#
# def _is_remote_location_available() -> bool:
#     return False
#
#
import importlib.machinery
import os



ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
def _get_torch_home():
    torch_home = os.path.expanduser(
        os.getenv(ENV_TORCH_HOME,
                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
                                         DEFAULT_CACHE_DIR), 'torch')))
    return torch_home

_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
_USE_SHARDED_DATASETS = False



def _download_file_from_remote_location(fpath: str, url: str) -> None:
    pass


def _is_remote_location_available() -> bool:
    return False


# try:
#     from torch.hub import load_state_dict_from_url  # noqa: 401
# except ImportError:
#     from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401
#
#
# def _get_extension_path(lib_name):
#
#     lib_dir = os.path.dirname(__file__)
#     if os.name == "nt":
#         # Register the main torchvision library location on the default DLL path
#         import ctypes
#         import sys
#
#         kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
#         with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
#         prev_error_mode = kernel32.SetErrorMode(0x0001)
#
#         if with_load_library_flags:
#             kernel32.AddDllDirectory.restype = ctypes.c_void_p
#
#         if sys.version_info >= (3, 8):
#             os.add_dll_directory(lib_dir)
#         elif with_load_library_flags:
#             res = kernel32.AddDllDirectory(lib_dir)
#             if res is None:
#                 err = ctypes.WinError(ctypes.get_last_error())
#                 err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
#                 raise err
#
#         kernel32.SetErrorMode(prev_error_mode)
#
#     loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
#
#     extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
#     ext_specs = extfinder.find_spec(lib_name)
#     if ext_specs is None:
#         raise ImportError
#
#     return ext_specs.origin
def _get_extension_path(lib_name):
    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes
        import sys

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        if sys.version_info >= (3, 8):
            os.add_dll_directory(lib_dir)
        elif with_load_library_flags:
            res = kernel32.AddDllDirectory(lib_dir)
            if res is None:
                err = ctypes.WinError(ctypes.get_last_error())
                err.strerror += f' Error adding "{lib_dir}" to the DLL directories.'
                raise err

        kernel32.SetErrorMode(prev_error_mode)

    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)
    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError
    return ext_specs.origin
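
For reference, a minimal usage sketch of the reimplemented cache-dir lookup above (not part of the diff; the environment value is illustrative, and `_get_torch_home` is assumed to be the function defined in this file):

import os

# Precedence: $TORCH_HOME, then $XDG_CACHE_HOME/torch, then ~/.cache/torch
os.environ["TORCH_HOME"] = "/data/torch_home"                  # illustrative
print(_get_torch_home())                                       # -> /data/torch_home
print(os.path.join(_get_torch_home(), "datasets", "vision"))   # -> value of _HOME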

ms_adapter/torchvision/datasets/_optical_flow.py (+5 -11)

@@ -6,7 +6,8 @@ from glob import glob
from pathlib import Path

import numpy as np
import torch
# import torch
import ms_adapter.pytorch as torch
from PIL import Image

from ..io.image import _read_png_16
@@ -27,12 +28,12 @@ class FlowDataset(ABC, VisionDataset):
    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
    _has_builtin_flow_mask = False

    def __init__(self, root, transforms=None):

        super().__init__(root=root)
        self.transforms = transforms
        self._has_builtin_flow_mask = False

        self._flow_list = []
        self._image_list = []
@@ -115,8 +116,6 @@ class Sintel(FlowDataset):
            details on the different passes.
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", transforms=None):
@@ -179,13 +178,12 @@ class KittiFlow(FlowDataset):
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        self._has_builtin_flow_mask = True
        root = Path(root) / "KittiFlow" / (split + "ing")
        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
@@ -242,8 +240,6 @@ class FlyingChairs(FlowDataset):
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", transforms=None):
@@ -313,8 +309,6 @@ class FlyingThings3D(FlowDataset):
        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
@@ -400,10 +394,10 @@ class HD1K(FlowDataset):
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)
        self._has_builtin_flow_mask = True

        verify_str_arg(split, "split", valid_values=("train", "test"))
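
In short, the pattern this diff applies across all the hunks above (a sketch, not the full classes): the `_has_builtin_flow_mask` class attribute becomes an instance attribute set in `__init__`, which subclasses with a built-in mask then override after calling `super().__init__()`:

class FlowDataset:
    def __init__(self, root, transforms=None):
        self._has_builtin_flow_mask = False  # instance attribute now

class KittiFlow(FlowDataset):
    def __init__(self, root, transforms=None):
        super().__init__(root, transforms)
        self._has_builtin_flow_mask = True   # overrides the default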



ms_adapter/torchvision/datasets/caltech.py (+12 -2)

@@ -1,5 +1,6 @@
import os
import os.path
import warnings
from typing import Any, Callable, List, Optional, Union, Tuple

from PIL import Image
@@ -127,7 +128,9 @@ class Caltech101(VisionDataset):
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        warnings.warn("The dataset is stored on Google Drive; if you cannot download it from there, "
                      "please download it from the official website: "
                      "Caltech 101 <https://data.caltech.edu/records/20086>")
        download_and_extract_archive(
            "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
            self.root,
@@ -229,9 +232,16 @@ class Caltech256(VisionDataset):
print("Files already downloaded and verified")
return

# download_and_extract_archive(
# "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
# self.root,
# filename="256_ObjectCategories.tar",
# md5="67b4f42ca05d46448c6bb8ecd2220f6d",
# )
download_and_extract_archive(
"https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
"https://data.caltech.edu/records/nyy15-4j048/files/256_ObjectCategories.tar?download=1",
self.root,
filename="256_ObjectCategories.tar",
md5="67b4f42ca05d46448c6bb8ecd2220f6d",
)


ms_adapter/torchvision/datasets/celeba.py (+42 -22)

@@ -1,11 +1,12 @@
import csv
import os
import warnings
from collections import namedtuple
from typing import Any, Callable, List, Optional, Union, Tuple

import PIL
import torch
import ms_adapter.pytorch as torch
import mindspore as ms
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
from .vision import VisionDataset

@@ -40,22 +41,9 @@ class CelebA(VisionDataset):
        downloaded again.
    """

    base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
    # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
    # right now.
    file_list = [
        # File ID                             MD5 Hash                            Filename
        ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
        # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
        # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
        ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
        ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
        ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
        ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
        # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
        ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
    ]

    def __init__(
        self,
@@ -67,6 +55,19 @@ class CelebA(VisionDataset):
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.file_list = [
            # File ID                             MD5 Hash                            Filename
            ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
            # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
            # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
            ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
            ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
            ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
            ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
            # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
            ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
        ]
        self.base_folder = "celeba"
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
@@ -94,17 +95,28 @@ class CelebA(VisionDataset):
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        self.old_identity = identity
        self.old_attr = attr
        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()

        if mask == slice(None):  # if split == "all"
            self.filename = splits.index
        else:
            self.filename = [splits.index[i] for i in torch.squeeze(torch.nonzero(mask))]
        if mask != slice(None):
            self.identity = torch.tensor(ms.ops.masked_select(identity.data, ms.ops.expand_dims(mask, -1))).reshape(-1, identity.data.shape[1])
            self.bbox = torch.tensor(ms.ops.masked_select(bbox.data, ms.ops.expand_dims(mask, -1))).reshape(-1, bbox.data.shape[1])
            self.landmarks_align = torch.tensor(ms.ops.masked_select(landmarks_align.data, ms.ops.expand_dims(mask, -1))).reshape(-1, landmarks_align.data.shape[1])
            self.attr = torch.tensor(ms.ops.masked_select(attr.data, ms.ops.expand_dims(mask, -1))).reshape(-1, attr.data.shape[1])
        else:
            self.identity = torch.tensor(identity.data)
            self.bbox = torch.tensor(bbox.data)
            self.landmarks_align = torch.tensor(landmarks_align.data)
            self.attr = torch.tensor(attr.data)
        # self.identity = identity.data[mask]
        # self.bbox = bbox.data[mask]
        # self.landmarks_align = landmarks_align.data[mask]
        # self.attr = attr.data[mask]

        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header
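
A minimal sketch of the row-selection workaround used above (values illustrative): MindSpore tensors do not support torch-style boolean indexing such as `data[mask]`, so the mask is broadcast with `expand_dims` and the flat result of `masked_select` is reshaped back into rows:

import numpy as np
import mindspore as ms

data = ms.Tensor(np.arange(12).reshape(4, 3), ms.int32)  # 4 rows, 3 columns
mask = ms.Tensor([True, False, True, False])             # keep rows 0 and 2
rows = ms.ops.masked_select(data, ms.ops.expand_dims(mask, -1)).reshape(-1, data.shape[1])
# rows -> [[0, 1, 2], [6, 7, 8]]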
@@ -145,6 +157,9 @@ class CelebA(VisionDataset):
        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        warnings.warn("The dataset is stored on Google Drive; if you cannot download it from there, "
                      "please download it from the official website: "
                      "Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>")

        for (file_id, md5, filename) in self.file_list:
            download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
@@ -153,7 +168,12 @@ class CelebA(VisionDataset):

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        # print("self.landmarks_align", self.landmarks_align.shape)
        # print("self.old_identity", self.old_identity.data.shape)
        # print("self.bbox", self.bbox.shape)
        # print("mask", self.mask.shape)
        # print("old_attr.data", self.old_attr.data.shape)
        # print("self.attr", self.attr.shape)
        target: Any = []
        for t in self.target_type:
            if t == "attr":


ms_adapter/torchvision/datasets/cifar.py (+3 -0)

@@ -78,8 +78,11 @@ class CIFAR10(VisionDataset):
        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            print("file_path", file_path)
            with open(file_path, "rb") as f:
                print("==============================5")
                entry = pickle.load(f, encoding="latin1")
                print("==============================6")
                self.data.append(entry["data"])
                if "labels" in entry:
                    self.targets.extend(entry["labels"])


ms_adapter/torchvision/datasets/fakedata.py (+12 -11)

@@ -1,9 +1,9 @@
from typing import Any, Callable, Optional, Tuple
import numpy as np
import ms_adapter.pytorch as torch

import torch

from .. import transforms
from .vision import VisionDataset
from ms_adapter.torchvision import transforms
from ms_adapter.torchvision.datasets.vision import VisionDataset


class FakeData(VisionDataset):
@@ -25,7 +25,7 @@ class FakeData(VisionDataset):
    def __init__(
        self,
        size: int = 1000,
        image_size: Tuple[int, int, int] = (3, 224, 224),
        image_size: Tuple[int, int, int] = (224, 224, 3),
        num_classes: int = 10,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
@@ -48,14 +48,15 @@ class FakeData(VisionDataset):
        # create random image that is consistent with the index id
        if index >= len(self):
            raise IndexError(f"{self.__class__.__name__} index out of range")
        rng_state = torch.get_rng_state()
        torch.manual_seed(index + self.random_offset)
        img = torch.randn(*self.image_size)
        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
        torch.set_rng_state(rng_state)
        rng = np.random.RandomState()
        rng_state = rng.get_state()
        rng.seed(index + self.random_offset)
        img = np.asarray(rng.randn(*self.image_size), dtype=np.uint8)
        target = rng.randint(0, self.num_classes, size=(1,), dtype=np.long)[0]
        rng.set_state(rng_state)
        # convert to PIL Image
        img = transforms.ToPILImage()(img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
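
The numpy-based replacement keeps FakeData deterministic per index in the same way the torch RNG code did; a sketch of the save/seed/restore pattern (seed value illustrative):

import numpy as np

rng = np.random.RandomState()
state = rng.get_state()              # save the current stream
rng.seed(7)                          # index + random_offset in the dataset
a = rng.randint(0, 10, size=(1,))[0]
rng.seed(7)
b = rng.randint(0, 10, size=(1,))[0]
rng.set_state(state)                 # restore the saved stream
assert a == b                        # same index -> same sample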


ms_adapter/torchvision/datasets/fer2013.py (+1 -1)

@@ -2,7 +2,7 @@ import csv
import pathlib
from typing import Any, Callable, Optional, Tuple

import torch
import ms_adapter.pytorch as torch
from PIL import Image

from .utils import verify_str_arg, check_integrity


ms_adapter/torchvision/datasets/hmdb51.py (+1 -2)

@@ -2,7 +2,6 @@ import glob
import os
from typing import Optional, Callable, Tuple, Dict, Any, List

from torch import Tensor

from .folder import find_classes, make_dataset
from .video_utils import VideoClips
@@ -140,7 +139,7 @@ class HMDB51(VisionDataset):
    def __len__(self) -> int:
        return self.video_clips.num_clips()

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
    def __getitem__(self, idx: int):
        video, audio, _, video_idx = self.video_clips.get_clip(idx)
        sample_index = self.indices[video_idx]
        _, class_index = self.samples[sample_index]


ms_adapter/torchvision/datasets/imagenet.py (+11 -6)

@@ -1,10 +1,10 @@
import os
import shutil
import tempfile
import pickle
from contextlib import contextmanager
from typing import Any, Dict, List, Iterator, Optional, Tuple

import torch

from .folder import ImageFolder
from .utils import check_integrity, extract_archive, verify_str_arg
@@ -43,9 +43,10 @@ class ImageNet(ImageFolder):
        root = self.root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", ("train", "val"))

        self.parse_archives()
        self.parse_archives()  # TODO: save the dataset into MindRecord
        wnid_to_classes = load_meta_file(self.root)[0]

        # wnid_to_classes, self.val_wnid = parse_devkit_archive(self.root)
        super().__init__(self.split_folder, **kwargs)
        self.root = root

@@ -78,7 +79,9 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
    file = os.path.join(root, file)

    if check_integrity(file):
        return torch.load(file)
        with open(file, "rb") as f:
            data = pickle.load(f)
        return data
    else:
        msg = (
            "The meta file {} is not present in the root directory or is corrupted. "
@@ -96,7 +99,7 @@ def _verify_archive(root: str, file: str, md5: str) -> None:
        raise RuntimeError(msg.format(file, root))


def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
def parse_devkit_archive(root: str, file: Optional[str] = None):
    """Parse the devkit archive of the ImageNet2012 classification dataset and save
    the meta information in a binary file.

@@ -147,8 +150,10 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
        val_idcs = parse_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        with open(os.path.join(root, META_FILE), "wb") as f:
            pickle.dump((wnid_to_classes, val_wnids), f)
        # torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
        # return wnid_to_classes, val_wnids


def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
    """Parse the train images archive of the ImageNet2012 classification dataset and


ms_adapter/torchvision/datasets/kinetics.py (+1 -3)

@@ -8,8 +8,6 @@ from multiprocessing import Pool
from os import path
from typing import Any, Callable, Dict, Optional, Tuple

from torch import Tensor

from .folder import find_classes, make_dataset
from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity
from .video_utils import VideoClips
@@ -238,7 +236,7 @@ class Kinetics(VisionDataset):
    def __len__(self) -> int:
        return self.video_clips.num_clips()

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
    def __getitem__(self, idx: int):
        video, audio, info, video_idx = self.video_clips.get_clip(idx)
        label = self.samples[video_idx][1]



ms_adapter/torchvision/datasets/mnist.py (+18 -18)

@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError

import numpy as np
import torch
import ms_adapter.pytorch as torch
from PIL import Image

from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
@@ -91,9 +91,9 @@ class MNIST(VisionDataset):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        # if self._check_legacy_exist():
        #     self.data, self.targets = self._load_legacy_data()
        #     return

        if download:
            self.download()
@@ -103,20 +103,20 @@ class MNIST(VisionDataset):

        self.data, self.targets = self._load_data()

    # def _check_legacy_exist(self):
    #     processed_folder_exists = os.path.exists(self.processed_folder)
    #     if not processed_folder_exists:
    #         return False
    #
    #     return all(
    #         check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
    #     )
    #
    # def _load_legacy_data(self):
    #     # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
    #     # directly.
    #     data_file = self.training_file if self.train else self.test_file
    #     return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"

ms_adapter/torchvision/datasets/pcam.py (+4 -1)

@@ -1,4 +1,5 @@
import pathlib
import warnings
from typing import Any, Callable, Optional, Tuple

from PIL import Image
@@ -123,7 +124,9 @@ class PCAM(VisionDataset):
    def _download(self) -> None:
        if self._check_exists():
            return

        warnings.warn("The dataset is stored on Google Drive; if you cannot download it from there, "
                      "please download it from the official website: "
                      "PCAM Dataset <https://github.com/basveeling/pcam>")
        for file_name, file_id, md5 in self._FILES[self._split].values():
            archive_name = file_name + ".gz"
            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)


ms_adapter/torchvision/datasets/phototour.py (+25 -22)

@@ -2,7 +2,7 @@ import os
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch
import ms_adapter.pytorch as torch
from PIL import Image

from .utils import download_url
@@ -93,7 +93,7 @@ class PhotoTour(VisionDataset):
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, f"{name}.zip")
        # self.data_file = os.path.join(self.root, f"{name}.pt")

        self.train = train
        self.mean = self.means[name]
@@ -102,11 +102,11 @@ class PhotoTour(VisionDataset):
        if download:
            self.download()

        # if not self._check_datafile_exists():
        #     self.cache()
        self.data, self.labels, self.matches = self.cache()
        # load the serialized data
        # self.data, self.labels, self.matches = torch.load(self.data_file)

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """
@@ -131,16 +131,16 @@ class PhotoTour(VisionDataset):
    def __len__(self) -> int:
        return len(self.data if self.train else self.matches)

    # def _check_datafile_exists(self) -> bool:
    #     return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        # if self._check_datafile_exists():
        #     print(f"# Found cached data {self.data_file}")
        #     return

        if not self._check_downloaded():
            # download files
@@ -160,18 +160,21 @@ class PhotoTour(VisionDataset):

        os.unlink(fpath)

    def cache(self) -> None:
    def cache(self):
        # process and save as torch files
        # print(f"# Caching data {self.data_file}")

        # dataset = (
        #     read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
        #     read_info_file(self.data_dir, self.info_file),
        #     read_matches_files(self.data_dir, self.matches_files),
        # )
        data = read_image_file(self.data_dir, self.image_ext, self.lens[self.name])
        info = read_info_file(self.data_dir, self.info_file)
        matches = read_matches_files(self.data_dir, self.matches_files)
        return data, info, matches
        # with open(self.data_file, "wb") as f:
        #     torch.save(dataset, f)

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"


ms_adapter/torchvision/datasets/samplers/clip_sampler.py (+2 -2)

@@ -88,7 +88,7 @@ class DistributedSampler(Sampler):

        # subsample
        indices = indices[self.rank : total_group_size : self.num_replicas, :]
        indices = torch.reshape(indices, (-1,)).tolist()
        indices = torch.reshape(indices, (-1,)).numpy().tolist()
        assert len(indices) == self.num_samples

        if isinstance(self.dataset, Sampler):
@@ -134,7 +134,7 @@ class UniformClipSampler(Sampler):
            sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64)
            s += length
            idxs.append(sampled)
        return iter(cast(List[int], torch.cat(idxs).tolist()))
        return iter(cast(List[int], torch.cat(idxs).numpy().tolist()))

    def __len__(self) -> int:
        return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0)
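
The extra `.numpy()` hop is the general pattern for adapter tensors, which are converted to numpy before `tolist()`; a minimal sketch (assuming `ms_adapter` is installed):

import ms_adapter.pytorch as torch

idxs = torch.arange(start=0, end=6)
as_list = idxs.numpy().tolist()   # convert via numpy; -> [0, 1, 2, 3, 4, 5]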


ms_adapter/torchvision/datasets/video_utils.py (+41 -41)

@@ -4,10 +4,10 @@ import warnings
from fractions import Fraction
from typing import Any, Dict, List, Optional, Callable, Union, Tuple, TypeVar, cast

import torch
from torchvision.io import (
_probe_video_from_file,
_read_video_from_file,
import ms_adapter.pytorch as torch
from ms_adapter.torchvision.io import (
# _probe_video_from_file,
# _read_video_from_file,
read_video,
read_video_timestamps,
)
@@ -145,9 +145,9 @@ class VideoClips:

        # strategy: use a DataLoader to parallelize read_video_timestamps
        # so need to create a dummy dataset first
        import torch.utils.data
        import ms_adapter.pytorch.utils.data

        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dl: ms_adapter.pytorch.utils.data.DataLoader = ms_adapter.pytorch.utils.data.DataLoader(
            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
            batch_size=16,
            num_workers=self.num_workers,
@@ -252,7 +252,7 @@ class VideoClips:
            self.clips.append(clips)
            self.resampling_idxs.append(idxs)
        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
        self.cumulative_sizes = clip_lengths.cumsum(0).asnumpy().tolist()

    def __len__(self) -> int:
        return self.num_clips()
@@ -286,7 +286,7 @@ class VideoClips:
            # advanced indexing
            step = int(step)
            return slice(None, None, step)
        idxs = torch.arange(num_frames, dtype=torch.float32) * step
        idxs = torch.arange(start=0, end=num_frames, dtype=torch.float32) * step
        idxs = idxs.floor().to(torch.int64)
        return idxs
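
A sketch of the resampling math above with concrete numbers (assuming the adapter's keyword-style `arange` as used in the diff; the step value is hypothetical):

import ms_adapter.pytorch as torch

step = 2.5  # original_fps / new_fps, hypothetically
idxs = (torch.arange(start=0, end=4, dtype=torch.float32) * step).floor().to(torch.int64)
# idxs -> [0, 2, 5, 7]: the source frames sampled for the resampled clip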

@@ -330,39 +330,39 @@ class VideoClips:
            start_pts = clip_pts[0].item()
            end_pts = clip_pts[-1].item()
            video, audio, info = read_video(video_path, start_pts, end_pts)
        # else:  # TODO: only the pyav backend is supported
        #     _info = _probe_video_from_file(video_path)
        #     video_fps = _info.video_fps
        #     audio_fps = None
        #
        #     video_start_pts = cast(int, clip_pts[0].item())
        #     video_end_pts = cast(int, clip_pts[-1].item())
        #
        #     audio_start_pts, audio_end_pts = 0, -1
        #     audio_timebase = Fraction(0, 1)
        #     video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
        #     if _info.has_audio:
        #         audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
        #         audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
        #         audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
        #         audio_fps = _info.audio_sample_rate
        #     video, audio, _ = _read_video_from_file(
        #         video_path,
        #         video_width=self._video_width,
        #         video_height=self._video_height,
        #         video_min_dimension=self._video_min_dimension,
        #         video_max_dimension=self._video_max_dimension,
        #         video_pts_range=(video_start_pts, video_end_pts),
        #         video_timebase=video_timebase,
        #         audio_samples=self._audio_samples,
        #         audio_channels=self._audio_channels,
        #         audio_pts_range=(audio_start_pts, audio_end_pts),
        #         audio_timebase=audio_timebase,
        #     )
        #
        #     info = {"video_fps": video_fps}
        #     if audio_fps is not None:
        #         info["audio_fps"] = audio_fps

        if self.frame_rate is not None:
            resampling_idx = self.resampling_idxs[video_idx][clip_idx]


ms_adapter/torchvision/extension.py (+1 -4)

@@ -2,10 +2,7 @@ import ctypes
import os
import sys
from warnings import warn
#
# import torch
#
# from ._internally_replaced_utils import _get_extension_path
from ._internally_replaced_utils import _get_extension_path
#
#
_HAS_OPS = False


ms_adapter/torchvision/io/__init__.py (+5 -10)

@@ -1,9 +1,4 @@
from typing import Any, Dict, Iterator

import torch

from ..utils import _log_api_usage_once

try:
    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
@@ -37,7 +32,7 @@ from .video import (
    read_video_timestamps,
    write_video,
)
from .video_reader import VideoReader
# from .video_reader import VideoReader


__all__ = [
@@ -52,8 +47,8 @@ __all__ = [
"_probe_video_from_memory",
"_HAS_VIDEO_OPT",
"_HAS_GPU_VIDEO_DECODER",
"_read_video_clip_from_memory",
"_read_video_meta_data",
# "_read_video_clip_from_memory", # TODO depend on torch.ops
# "_read_video_meta_data",# TODO depend on torch.ops
"VideoMetaData",
"Timebase",
"ImageReadMode",
@@ -67,6 +62,6 @@ __all__ = [
"write_file",
"write_jpeg",
"write_png",
"Video",
"VideoReader",
# "Video",
# "VideoReader", # TODO depend on torch.ops
]

ms_adapter/torchvision/io/_load_gpu_decoder.py (+7 -6)

@@ -1,8 +1,9 @@
# from ..extension import _load_library


# try:
#     _load_library("Decoder")
#     _HAS_GPU_VIDEO_DECODER = True
# except (ImportError, OSError):
#     _HAS_GPU_VIDEO_DECODER = False
_HAS_GPU_VIDEO_DECODER = False

ms_adapter/torchvision/io/_video_opt.py (+8 -8)

@@ -3,17 +3,17 @@ import warnings
from fractions import Fraction
from typing import List, Tuple, Dict, Optional, Union

import torch

import ms_adapter.pytorch as torch
# from ..extension import _load_library


# try:
#     _load_library("video_reader")
#     _HAS_VIDEO_OPT = True
# except (ImportError, OSError):
#     _HAS_VIDEO_OPT = False
_HAS_VIDEO_OPT = False
default_timebase = Fraction(0, 1)




ms_adapter/torchvision/io/image.py (+51 -41)

@@ -1,16 +1,13 @@
from enum import Enum
from warnings import warn

import torch
import ms_adapter.pytorch as torch
from mindspore.dataset import vision

from ..extension import _load_library
from ..utils import _log_api_usage_once


# try:
#     _load_library("image")
# except (ImportError, OSError) as e:
#     warn(f"Failed to load image Python extension: {e}")


class ImageReadMode(Enum):
@@ -42,9 +39,9 @@ def read_file(path: str) -> torch.Tensor:
    Returns:
        data (Tensor)
    """
    data = torch.ops.image.read_file(path)
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(read_file)
    data = vision.read_file(path)
    return data


@@ -57,9 +54,9 @@ def write_file(filename: str, data: torch.Tensor) -> None:
        filename (str): the path to the file to be written
        data (Tensor): the contents to be written to the output file
    """
    torch.ops.image.write_file(filename, data)
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(write_file)
    vision.write_file(filename, data)


def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
@@ -79,9 +76,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    output = torch.ops.image.decode_png(input, mode.value, False)
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(decode_png)
    # output = vision.decode_png(input, mode.value, False)
    output = vision.Decode()(input)
    output = torch.tensor(output, dtype=torch.uint8)
    output = output.permute(2, 0, 1)
    return output


@@ -100,9 +100,9 @@ def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
            PNG file.
    """
    output = torch.ops.image.encode_png(input, compression_level)
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(encode_png)
    output = vision.encode_png(input, compression_level)
    return output


@@ -118,8 +118,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(write_png)
    output = encode_png(input, compression_level)
    write_file(filename, output)

@@ -154,13 +154,16 @@ def decode_jpeg(
    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(decode_jpeg)
    # device = torch.device(device)
    # if device.type == "cuda":
    #     output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    # else:
    #     output = torch.ops.image.decode_jpeg(input, mode.value)
    output = vision.Decode()(input)
    output = torch.tensor(output, dtype=torch.uint8)
    output = output.permute(2, 0, 1)
    return output


@@ -179,12 +182,12 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
            JPEG file.
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(encode_jpeg)
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")

    output = torch.ops.image.encode_jpeg(input, quality)
    output = vision.encode_jpeg(input, quality)
    return output


@@ -199,8 +202,8 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(write_jpeg)
    output = encode_jpeg(input, quality)
    write_file(filename, output)

@@ -224,9 +227,12 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    output = torch.ops.image.decode_image(input, mode.value)
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(decode_image)
    # output = vision.decode_image(input, mode.value)
    output = vision.Decode()(input)
    output = torch.tensor(output, dtype=torch.uint8)
    output = output.permute(2, 0, 1)
    return output


@@ -246,12 +252,16 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(read_image)
    data = read_file(path)
    return decode_image(data, mode)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    data = read_file(path)
    return torch.ops.image.decode_png(data, mode.value, True)
    # return vision.decode_png(data, mode.value, True)
    output = vision.Decode()(data)
    output = torch.tensor(output, dtype=torch.uint8)
    output = output.permute(2, 0, 1)
    return output
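
All the decode paths above share the same shape convention; a sketch of the round trip (file path illustrative, relying on the `mindspore.dataset.vision` ops used in this diff):

import ms_adapter.pytorch as torch
from mindspore.dataset import vision

data = vision.read_file("sample.jpg")          # raw file bytes
hwc = vision.Decode()(data)                    # numpy array, H x W x C, uint8
chw = torch.tensor(hwc, dtype=torch.uint8).permute(2, 0, 1)  # torchvision's CHW layout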

ms_adapter/torchvision/io/video.py (+9 -10)

@@ -7,9 +7,8 @@ from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import ms_adapter.pytorch as torch

from ..utils import _log_api_usage_once
from . import _video_opt


@@ -78,8 +77,8 @@ def write_video(
        audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

@@ -260,14 +259,14 @@ def read_video(
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
    """
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend
    from ms_adapter.torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")
@@ -381,9 +380,9 @@ def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[in
        video_fps (float, optional): the frame rate for the video

    """
    from torchvision import get_video_backend
    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    #     _log_api_usage_once(read_video_timestamps)
    from ms_adapter.torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)


ms_adapter/torchvision/io/video_reader.py (+2 -2)

@@ -2,7 +2,7 @@ from typing import Any, Dict, Iterator

import torch

from ..utils import _log_api_usage_once

try:
    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
@@ -91,7 +91,7 @@ class VideoReader:
    """

    def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
        # _log_api_usage_once(self)
        self.is_cuda = False
        device = torch.device(device)
        if device.type == "cuda":


testing/ut/torchvision/common_utils.py (+19 -18)

@@ -10,7 +10,7 @@ import ms_adapter.pytorch as torch
from ms_adapter.pytorch.tensor import cast_to_adapter_tensor

from PIL import Image
# from torchvision import io
from ms_adapter.torchvision import io

import __main__ # noqa: 401

@@ -143,23 +143,24 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
# assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)


# def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
#     names = []
#     for i in range(num_videos):
#         if sizes is None:
#             size = 5 * (i + 1)
#         else:
#             size = sizes[i]
#         if fps is None:
#             f = 5
#         else:
#             f = fps[i]
#         data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
#         name = os.path.join(tmpdir, f"{i}.mp4")
#         names.append(name)
#         io.write_video(name, data, fps=f)
#
#     return names
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
    names = []
    for i in range(num_videos):
        print(i)
        if sizes is None:
            size = 5 * (i + 1)
        else:
            size = sizes[i]
        if fps is None:
            f = 5
        else:
            f = fps[i]
        data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
        name = os.path.join(tmpdir, f"{i}.mp4")
        names.append(name)
        io.write_video(name, data, fps=f)

    return names
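
A usage sketch for the re-enabled helper (assuming PyAV is available for `io.write_video`; arguments illustrative):

import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    paths = get_list_of_videos(tmpdir, num_videos=2, sizes=[5, 10], fps=[5, 10])
    # -> ["<tmpdir>/0.mp4", "<tmpdir>/1.mp4"], each written via io.write_video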


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):


testing/ut/torchvision/datasets_utils.py (+982 -0)

@@ -0,0 +1,982 @@
import contextlib
import functools
import importlib
import inspect
import itertools
import os
import pathlib
import random
import shutil
import string
import struct
import tarfile
import unittest
import unittest.mock
import zipfile
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np

import PIL
import PIL.Image
import pytest
import ms_adapter.pytorch as torch
from ms_adapter.torchvision.datasets import VisionDataset
from ms_adapter.torchvision.io import write_video
from common_utils import disable_console_output, get_tmp_dir
from ms_adapter.torchvision.transforms.functional import get_dimensions

__all__ = [
    "UsageError",
    "lazy_importer",
    "test_all_configs",
    "DatasetTestCase",
    "ImageDatasetTestCase",
    "VideoDatasetTestCase",
    "create_image_or_video_tensor",
    "create_image_file",
    "create_image_folder",
    "create_video_file",
    "create_video_folder",
    "make_tar",
    "make_zip",
    "create_random_string",
]


class UsageError(Exception):
    """Should be raised in case an error happens in the setup rather than the test."""


class LazyImporter:
    r"""Lazy importer for additional dependencies.

    Some datasets require additional packages that are not direct dependencies of torchvision. Instances of this
    class provide modules listed in MODULES as attributes. They are only imported when accessed.

    """
    MODULES = (
        "av",
        "lmdb",
        "pycocotools",
        "requests",
        "scipy.io",
        "scipy.sparse",
        "h5py",
    )

    def __init__(self):
        modules = defaultdict(list)
        for module in self.MODULES:
            module, *submodules = module.split(".", 1)
            if submodules:
                modules[module].append(submodules[0])
            else:
                # This introduces the module so that it is known when we later iterate over the dictionary.
                modules.__missing__(module)

        for module, submodules in modules.items():
            # We need the quirky 'module=module' and 'submodules=submodules' arguments to the lambda since otherwise
            # the lookup for these would happen at runtime rather than at definition. Thus, without it, every
            # property would try to import the last item in 'modules'.
            setattr(
                type(self),
                module,
                property(lambda self, module=module, submodules=submodules: LazyImporter._import(module, submodules)),
            )

    @staticmethod
    def _import(package, subpackages):
        try:
            module = importlib.import_module(package)
        except ImportError as error:
            raise UsageError(
                f"Failed to import module '{package}'. "
                f"This probably means that the current test case needs '{package}' installed, "
                f"but it is not a dependency of torchvision. "
                f"You need to install it manually, for example 'pip install {package}'."
            ) from error

        for name in subpackages:
            importlib.import_module(f".{name}", package=package)

        return module


lazy_importer = LazyImporter()


def requires_lazy_imports(*modules):
    def outer_wrapper(fn):
        @functools.wraps(fn)
        def inner_wrapper(*args, **kwargs):
            for module in modules:
                getattr(lazy_importer, module.replace(".", "_"))
            return fn(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper


def test_all_configs(test):
    """Decorator to run a test against all configurations.

    Add this as decorator to an arbitrary test to run it against all configurations. This includes
    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.

    The current configuration is provided as the first parameter for the test:

    .. code-block::

        @test_all_configs()
        def test_foo(self, config):
            pass

    .. note::

        This will try to remove duplicate configurations. During this process it will not preserve a potential
        ordering of the configurations or an inner ordering of a configuration.
    """

    def maybe_remove_duplicates(configs):
        try:
            return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
        except TypeError:
            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case
            # duplicate removal would be a lot more elaborate and we simply bail out.
            return configs

    @functools.wraps(test)
    def wrapper(self):
        configs = []
        if self.DEFAULT_CONFIG is not None:
            configs.append(self.DEFAULT_CONFIG)
        if self.ADDITIONAL_CONFIGS is not None:
            configs.extend(self.ADDITIONAL_CONFIGS)

        if not configs:
            configs = [self._KWARG_DEFAULTS.copy()]
        else:
            configs = maybe_remove_duplicates(configs)

        for config in configs:
            with self.subTest(**config):
                test(self, config)

    return wrapper


def combinations_grid(**kwargs):
    """Creates a grid of input combinations.

    Each element in the returned sequence is a dictionary containing one possible combination as values.

    Example:
        >>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
        [
            {'foo': 'bar', 'spam': 'eggs'},
            {'foo': 'bar', 'spam': 'ham'},
            {'foo': 'baz', 'spam': 'eggs'},
            {'foo': 'baz', 'spam': 'ham'}
        ]
    """
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


class DatasetTestCase(unittest.TestCase):
    """Abstract base class for all dataset testcases.

    You have to overwrite the following class attributes:

    - DATASET_CLASS (torchvision.datasets.VisionDataset): Class of dataset to be tested.
    - FEATURE_TYPES (Sequence[Any]): Types of the elements returned by index access of the dataset. Instead of
      providing these manually, you can instead subclass ``ImageDatasetTestCase`` or ``VideoDatasetTestCase`` to
      get a reasonable default, that should work for most cases. Each entry of the sequence may be a tuple,
      to indicate multiple possible values.

    Optionally, you can overwrite the following class attributes:

    - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
      keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
      ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
      not provide one.
    - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
      contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
      ``transforms``, or ``download``.
    - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
      available, the tests are skipped.

    Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely
    on. The fake data should resemble the original data as close as necessary, while containing only few examples.
    During the creation of the dataset the check-, download-, and extract-functions from
    ``torchvision.datasets.utils`` are disabled.

    Without further configuration, the testcase will test if

    1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
       corrupted,
    2. the dataset inherits from `torchvision.datasets.VisionDataset`,
    3. the dataset can be turned into a string,
    4. the feature types of a returned example matches ``FEATURE_TYPES``,
    5. the number of examples matches the injected fake data, and
    6. the dataset calls ``transform``, ``target_transform``, or ``transforms`` if available when accessing data.

    Cases 3 to 6 are tested against all configurations in ``CONFIGS``.

    To add dataset-specific tests, create a new method that takes no arguments with ``test_`` as a name prefix:

    .. code-block::

        def test_foo(self):
            pass

    If you want to run the test against all configs, add the ``@test_all_configs`` decorator to the definition and
    accept a single argument:

    .. code-block::

        @test_all_configs
        def test_bar(self, config):
            pass

    Within the test you can use the ``create_dataset()`` method that yields the dataset as well as additional
    information provided by the ``inject_fake_data()`` method:

    .. code-block::

        def test_baz(self):
            with self.create_dataset() as (dataset, info):
                pass
    """

    DATASET_CLASS = None
    FEATURE_TYPES = None

    DEFAULT_CONFIG = None
    ADDITIONAL_CONFIGS = None
    REQUIRED_PACKAGES = None

    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
    _TRANSFORM_KWARGS = {
        "transform",
        "target_transform",
        "transforms",
    }
    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or
    # ADDITIONAL_CONFIGS.
    _SPECIAL_KWARGS = {
        *_TRANSFORM_KWARGS,
        "download",
    }

    # These fields are populated during setUpClass() within _populate_private_class_attributes()

    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
    # the dataset constructor.
    _KWARG_DEFAULTS = None
    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
    _HAS_SPECIAL_KWARG = None

    # These functions are disabled during dataset creation in create_dataset().
    _CHECK_FUNCTIONS = {
        "check_md5",
        "check_integrity",
    }
    _DOWNLOAD_EXTRACT_FUNCTIONS = {
        "download_url",
        "download_file_from_google_drive",
        "extract_archive",
        "download_and_extract_archive",
    }

    def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
        """Define positional arguments passed to the dataset.

        .. note::

            The default behavior is only valid if the dataset to be tested has ``root`` as the only required
            parameter. Otherwise you need to overwrite this method.

        Args:
            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                to be created and in turn also for the fake data injected here.
            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
                least fields for all dataset parameters with default values.

        Returns:
            (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
        """
        return (tmpdir,)

    def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
        """Inject fake data for dataset into a temporary directory.

        During the creation of the dataset the download and extract logic is disabled. Thus, the fake data injected
        here needs to resemble the raw data, i.e. the state of the dataset directly after the files are downloaded
        and potentially extracted.

        Args:
            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                to be created and in turn also for the fake data injected here.
            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
                least fields for all dataset parameters with default values.

        Needs to return one of the following:

            1. (int): Number of examples in the dataset to be created, or
            2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
               ``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
        """
        raise NotImplementedError("You need to provide fake data in order for the tests to run.")

@contextlib.contextmanager
def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
patch_checks: Optional[bool] = None,
**kwargs: Any,
):
r"""Create the dataset in a temporary directory.

The configuration passed to the dataset is populated to contain at least all parameters with default values.
For this the following order of precedence is used:

1. Parameters in :attr:`kwargs`.
2. Configuration in :attr:`config`.
3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
4. Default parameters of the dataset.

Args:
config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset.
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
omitted defaults to the same value as ``inject_fake_data``.
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
overlap with ``config``.

Yields:
dataset (torchvision.dataset.VisionDataset): Dataset.
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details.
"""
if patch_checks is None:
patch_checks = inject_fake_data

special_kwargs, other_kwargs = self._split_kwargs(kwargs)

complete_config = self._KWARG_DEFAULTS.copy()
if self.DEFAULT_CONFIG:
complete_config.update(self.DEFAULT_CONFIG)
if config:
complete_config.update(config)
if other_kwargs:
complete_config.update(other_kwargs)

if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
            # downloads are patched out during dataset creation, so force a truthy 'download' argument to False
special_kwargs["download"] = False

patchers = self._patch_download_extract()
if patch_checks:
patchers.update(self._patch_checks())
with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, complete_config)
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)

yield dataset, info
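
        # Hypothetical usage from inside a test of a subclass (``split`` is an
        # illustrative dataset parameter, not something this base class requires):
        #
        #     with self.create_dataset(config=dict(split="train")) as (dataset, info):
        #         assert len(dataset) == info["num_examples"]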

@classmethod
def setUpClass(cls):
cls._verify_required_public_class_attributes()
cls._populate_private_class_attributes()
cls._process_optional_public_class_attributes()
super().setUpClass()

@classmethod
def _verify_required_public_class_attributes(cls):
if cls.DATASET_CLASS is None:
raise UsageError(
"The class attribute 'DATASET_CLASS' needs to be overwritten. "
"It should contain the class of the dataset to be tested."
)
if cls.FEATURE_TYPES is None:
raise UsageError(
"The class attribute 'FEATURE_TYPES' needs to be overwritten. "
"It should contain a sequence of types that the dataset returns when accessed by index."
)

@classmethod
def _populate_private_class_attributes(cls):
defaults = []
for cls_ in cls.DATASET_CLASS.__mro__:
if cls_ is VisionDataset:
break

argspec = inspect.getfullargspec(cls_.__init__)

if not argspec.defaults:
continue

defaults.append(
{
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)
if not kwarg.startswith("_")
}
)

if not argspec.varkw:
break

kwarg_defaults = dict()
for config in reversed(defaults):
kwarg_defaults.update(config)

has_special_kwargs = set()
for name in cls._SPECIAL_KWARGS:
if name not in kwarg_defaults:
continue

del kwarg_defaults[name]
has_special_kwargs.add(name)

cls._KWARG_DEFAULTS = kwarg_defaults
cls._HAS_SPECIAL_KWARG = has_special_kwargs
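
        # Illustration with a hypothetical constructor: for
        # ``def __init__(self, root, split="train", download=False)`` the extracted defaults are
        # ``{"split": "train", "download": False}``. Since ``"download"`` is a special kwarg, it is moved to
        # ``_HAS_SPECIAL_KWARG``, leaving ``_KWARG_DEFAULTS`` as ``{"split": "train"}``.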

@classmethod
def _process_optional_public_class_attributes(cls):
        def check_config(config, name):
            special_kwargs = tuple(f"'{kwarg}'" for kwarg in cls._SPECIAL_KWARGS if kwarg in config)
if special_kwargs:
raise UsageError(
f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
f"These are handled separately by the test case and should not be set here. "
f"If you need to test some custom behavior regarding these parameters, "
f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
)

if cls.DEFAULT_CONFIG is not None:
check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")

if cls.ADDITIONAL_CONFIGS is not None:
for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
check_config(config, f"CONFIGS[{idx}]")

if cls.REQUIRED_PACKAGES:
missing_pkgs = []
for pkg in cls.REQUIRED_PACKAGES:
try:
importlib.import_module(pkg)
except ImportError:
missing_pkgs.append(f"'{pkg}'")

if missing_pkgs:
raise unittest.SkipTest(
f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
f"'{cls.DATASET_CLASS.__name__}', but are not installed."
)

def _split_kwargs(self, kwargs):
special_kwargs = kwargs.copy()
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
return special_kwargs, other_kwargs

def _inject_fake_data(self, tmpdir, config):
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
return info

def _patch_download_extract(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}

def _patch_checks(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}

@contextlib.contextmanager
def _maybe_apply_patches(self, patchers):
with contextlib.ExitStack() as stack:
mocks = {}
for patcher in patchers:
with contextlib.suppress(AttributeError):
mocks[patcher.target] = stack.enter_context(patcher)
yield mocks

def test_not_found_or_corrupted(self):
with pytest.raises((FileNotFoundError, RuntimeError)):
with self.create_dataset(inject_fake_data=False):
pass

def test_smoke(self):
with self.create_dataset() as (dataset, _):
assert isinstance(dataset, VisionDataset)

@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
assert isinstance(str(dataset), str)

@test_all_configs
def test_feature_types(self, config):
with self.create_dataset(config) as (dataset, _):

example = dataset[0]

if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
                assert actual == expected, (
                    "The number of the returned features does not match the number of elements "
                    f"in FEATURE_TYPES: {actual} != {expected}"
                )
else:
example = (example,)

for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
assert isinstance(feature, expected_feature_type)

@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
assert len(dataset) == info["num_examples"]

@test_all_configs
def test_transforms(self, config):
mock = unittest.mock.Mock(wraps=lambda *args: args[0] if len(args) == 1 else args)
for kwarg in self._TRANSFORM_KWARGS:
if kwarg not in self._HAS_SPECIAL_KWARG:
continue

mock.reset_mock()

with self.subTest(kwarg=kwarg):
with self.create_dataset(config, **{kwarg: mock}) as (dataset, _):
dataset[0]

mock.assert_called()
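
    # A hedged end-to-end sketch of a concrete test case; ``FakeImageDataset`` is a
    # placeholder for any dataset whose constructor is ``FakeImageDataset(root, transform=None)``:
    #
    #     class FakeImageDatasetTestCase(ImageDatasetTestCase):
    #         DATASET_CLASS = FakeImageDataset
    #
    #         def inject_fake_data(self, tmpdir, config):
    #             return len(create_image_folder(tmpdir, "images", lambda idx: f"{idx}.png", num_examples=4))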


class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.

- Overwrites the FEATURE_TYPES class attribute to expect a :class:`PIL.Image.Image` and an integer label.
"""

FEATURE_TYPES = (PIL.Image.Image, int)

@contextlib.contextmanager
def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
patch_checks: Optional[bool] = None,
**kwargs: Any,
):
with super().create_dataset(
config=config,
inject_fake_data=inject_fake_data,
patch_checks=patch_checks,
**kwargs,
) as (dataset, info):
            # PIL.Image.open() only loads the image metadata upfront and keeps the file open until the pixel data is
            # accessed for the first time. Trying to delete such a file results in a PermissionError on Windows. Thus,
            # we force-load opened images.
            # This problem only occurs during testing, since some tests, e.g. DatasetTestCase.test_feature_types, open
            # an image but never use the underlying data. During normal operation it is reasonable to assume that a
            # user wants to work with the image they just opened rather than delete the underlying file.
with self._force_load_images():
yield dataset, info

@contextlib.contextmanager
def _force_load_images(self):
open = PIL.Image.open

def new(fp, *args, **kwargs):
image = open(fp, *args, **kwargs)
if isinstance(fp, (str, pathlib.Path)):
image.load()
return image

with unittest.mock.patch("PIL.Image.open", new=new):
yield


class VideoDatasetTestCase(DatasetTestCase):
"""Abstract base class for video dataset testcases.

    - Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` objects for the video and
      audio as well as an integer label.
    - Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``).
    - Adds the 'DEFAULT_FRAMES_PER_CLIP' class attribute. If no 'frames_per_clip' is provided by 'dataset_args()'
      and it is the last parameter without a default value in the dataset constructor, the value of the
      'DEFAULT_FRAMES_PER_CLIP' class attribute is appended to the output.
"""

FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
REQUIRED_PACKAGES = ("av",)

DEFAULT_FRAMES_PER_CLIP = 1

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)

    def _set_default_frames_per_clip(self, dataset_args):
        argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
        args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
        frames_per_clip_last = args_without_default[-1] == "frames_per_clip"

        @functools.wraps(dataset_args)
        def wrapper(tmpdir, config):
            args = dataset_args(tmpdir, config)
            if frames_per_clip_last and len(args) == len(args_without_default) - 1:
                args = (*args, self.DEFAULT_FRAMES_PER_CLIP)

            return args

return wrapper
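
    # Illustration: for a hypothetical constructor ``def __init__(self, root, frames_per_clip, step=1)``,
    # ``args_without_default`` is ``['root', 'frames_per_clip']``. If ``dataset_args`` only returns
    # ``(tmpdir,)``, the wrapper appends ``DEFAULT_FRAMES_PER_CLIP`` so the dataset is instantiated
    # as ``DATASET_CLASS(tmpdir, 1, ...)``.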


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.

Args:
size (Sequence[int]): Size of the tensor.
"""
return torch.randint(0, 256, size, dtype=torch.uint8)


def create_image_file(
root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = 10, **kwargs: Any
) -> pathlib.Path:
"""Create an image file from random data.

Args:
root (Union[str, pathlib.Path]): Root directory the image file will be placed in.
name (Union[str, pathlib.Path]): Name of the image file.
        size (Union[Sequence[int], int]): Size of the image as ``(num_channels, height, width)``. If
            scalar, the value is used for both height and width. If the channel dimension is omitted, three channels
            are assumed.
kwargs (Any): Additional parameters passed to :meth:`PIL.Image.Image.save`.

Returns:
pathlib.Path: Path to the created image file.
"""
if isinstance(size, int):
size = (size, size)
if len(size) == 2:
size = (3, *size)
if len(size) != 3:
raise UsageError(
f"The 'size' argument should either be an int or a sequence of length 2 or 3. Got {len(size)} instead"
)

image = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name

# torch (num_channels x height x width) -> PIL (width x height x num_channels)
image = image.permute(2, 1, 0)
# For grayscale images PIL doesn't use a channel dimension
if image.shape[2] == 1:
image = torch.squeeze(image, 2)
PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
return file
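
# Usage sketch (path and name are illustrative): a single random 3x32x32 PNG:
#
#     create_image_file("/tmp", "example.png", size=(3, 32, 32))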


def create_image_folder(
root: Union[pathlib.Path, str],
name: Union[pathlib.Path, str],
file_name_fn: Callable[[int], str],
num_examples: int,
size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
**kwargs: Any,
) -> List[pathlib.Path]:
"""Create a folder of random images.

Args:
root (Union[str, pathlib.Path]): Root directory the image folder will be placed in.
name (Union[str, pathlib.Path]): Name of the image folder.
file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
num_examples (int): Number of images to create.
size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the images. If
callable, will be called with the index of the corresponding file. If omitted, a random height and width
between 3 and 10 pixels is selected on a per-image basis.
kwargs (Any): Additional parameters passed to :func:`create_image_file`.

Returns:
List[pathlib.Path]: Paths to all created image files.

.. seealso::

- :func:`create_image_file`
"""
if size is None:

def size(idx: int) -> Tuple[int, int, int]:
num_channels = 3
height, width = torch.randint(3, 11, size=(2,), dtype=torch.int32).tolist()
return (num_channels, height, width)

root = pathlib.Path(root) / name
os.makedirs(root, exist_ok=True)

return [
create_image_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples)
]
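
# Usage sketch: ten random JPEGs named ``0.jpg`` ... ``9.jpg`` below ``<root>/images``:
#
#     create_image_folder(root, "images", lambda idx: f"{idx}.jpg", num_examples=10)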


def shape_test_for_stereo(
left: PIL.Image.Image,
right: PIL.Image.Image,
disparity: Optional[np.ndarray] = None,
valid_mask: Optional[np.ndarray] = None,
):
left_dims = get_dimensions(left)
right_dims = get_dimensions(right)
c, h, w = left_dims
# check that left and right are the same size
assert left_dims == right_dims
assert c == 3

# check that the disparity has the same spatial dimensions
# as the input
if disparity is not None:
assert disparity.ndim == 3
assert disparity.shape == (1, h, w)

    if valid_mask is not None:
        # check that the valid mask is the same size as the disparity
        assert disparity is not None, "a valid_mask requires a matching disparity map"
        _, dh, dw = disparity.shape
        mh, mw = valid_mask.shape
        assert dh == mh
        assert dw == mw


@requires_lazy_imports("av")
def create_video_file(
root: Union[pathlib.Path, str],
name: Union[pathlib.Path, str],
size: Union[Sequence[int], int] = (1, 3, 10, 10),
fps: float = 25,
**kwargs: Any,
) -> pathlib.Path:
"""Create an video file from random data.

Args:
root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
name (Union[str, pathlib.Path]): Name of the video file.
        size (Union[Sequence[int], int]): Size of the video as
            ``(num_frames, num_channels, height, width)``. If scalar, the value is used for the height and width.
            If ``num_frames`` or ``num_channels`` are omitted, ``num_frames=1`` and ``num_channels=3`` are assumed.
fps (float): Frame rate in frames per second.
kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`.

Returns:
        pathlib.Path: Path to the created video file.

Raises:
UsageError: If PyAV is not available.
"""
if isinstance(size, int):
size = (size, size)
if len(size) == 2:
size = (3, *size)
if len(size) == 3:
size = (1, *size)
if len(size) != 4:
raise UsageError(
f"The 'size' argument should either be an int or a sequence of length 2, 3, or 4. Got {len(size)} instead"
)

video = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name
write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs)
return file
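
# Usage sketch (requires PyAV): a random 25-frame 3x10x10 video at the default 25 fps:
#
#     create_video_file(root, "example.mp4", size=(25, 3, 10, 10))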


@requires_lazy_imports("av")
def create_video_folder(
root: Union[str, pathlib.Path],
name: Union[str, pathlib.Path],
file_name_fn: Callable[[int], str],
num_examples: int,
size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
fps=25,
**kwargs,
) -> List[pathlib.Path]:
"""Create a folder of random videos.

Args:
root (Union[str, pathlib.Path]): Root directory the video folder will be placed in.
name (Union[str, pathlib.Path]): Name of the video folder.
file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
num_examples (int): Number of videos to create.
size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the videos. If
callable, will be called with the index of the corresponding file. If omitted, a random even height and
width between 4 and 10 pixels is selected on a per-video basis.
fps (float): Frame rate in frames per second.
kwargs (Any): Additional parameters passed to :func:`create_video_file`.

Returns:
List[pathlib.Path]: Paths to all created video files.

Raises:
UsageError: If PyAV is not available.

.. seealso::

- :func:`create_video_file`
"""
if size is None:

def size(idx):
num_frames = 1
num_channels = 3
# The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and
# width of the video to be divisible by 2.
height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int32) * 2).tolist()
return (num_frames, num_channels, height, width)

root = pathlib.Path(root) / name
os.makedirs(root, exist_ok=True)

return [
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples)
]


def _split_files_or_dirs(root, *files_or_dirs):
files = set()
dirs = set()
for file_or_dir in files_or_dirs:
path = pathlib.Path(file_or_dir)
if not path.is_absolute():
path = root / path
if path.is_file():
files.add(path)
else:
dirs.add(path)
for sub_file_or_dir in path.glob("**/*"):
if sub_file_or_dir.is_file():
files.add(sub_file_or_dir)
else:
dirs.add(sub_file_or_dir)

if root in dirs:
dirs.remove(root)

return files, dirs


def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
archive = pathlib.Path(root) / name
if not files_or_dirs:
        # We need to invoke `Path.with_suffix("")` repeatedly, since a single call only removes the last suffix when
        # multiple suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
file_or_dir = archive
for _ in range(len(archive.suffixes)):
file_or_dir = file_or_dir.with_suffix("")
if file_or_dir.exists():
files_or_dirs = (file_or_dir,)
else:
raise ValueError("No file or dir provided.")

files, dirs = _split_files_or_dirs(root, *files_or_dirs)

with opener(archive) as fh:
for file in sorted(files):
adder(fh, file, file.relative_to(root))

if remove:
for file in files:
os.remove(file)
for dir in dirs:
shutil.rmtree(dir, ignore_errors=True)

return archive


def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
# TODO: detect compression from name
return _make_archive(
root,
name,
*files_or_dirs,
opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
remove=remove,
)


def make_zip(root, name, *files_or_dirs, remove=True):
return _make_archive(
root,
name,
*files_or_dirs,
opener=lambda archive: zipfile.ZipFile(archive, "w"),
adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
remove=remove,
)
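
# Usage sketch: bundle a previously created ``data`` folder below ``root`` (a
# ``pathlib.Path``) into a gzipped tarball and delete the originals, mimicking the
# on-disk state right after a dataset download:
#
#     make_tar(root, "archive.tar.gz", "data", compression="gz")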


def create_random_string(length: int, *digits: str) -> str:
"""Create a random string.

Args:
length (int): Number of characters in the generated string.
        *digits (str): Characters to sample from. If omitted, defaults to :attr:`string.ascii_lowercase`.
"""
if not digits:
digits = string.ascii_lowercase
else:
digits = "".join(itertools.chain(*digits))

return "".join(random.choice(digits) for _ in range(length))


def make_fake_pfm_file(h, w, file_name):
values = list(range(3 * h * w))
    # Everything is packed little endian: the "-1.0" scale in the header signals it, and "<" enforces it in struct.pack
content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
with open(file_name, "wb") as f:
f.write(content)


def make_fake_flo_file(h, w, file_name):
"""Creates a fake flow file in .flo format."""
# Everything needs to be in little Endian according to
# https://vision.middlebury.edu/flow/code/flow-code/README.txt
values = list(range(2 * h * w))
content = (
struct.pack("<4c", *(c.encode() for c in "PIEH"))
+ struct.pack("<i", w)
+ struct.pack("<i", h)
+ struct.pack("<" + "f" * len(values), *values)
)
with open(file_name, "wb") as f:
f.write(content)
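
# Verification sketch: the header written by ``make_fake_flo_file`` can be read back
# with ``struct`` (everything little endian):
#
#     with open(file_name, "rb") as f:
#         assert f.read(4) == b"PIEH"
#         w, h = struct.unpack("<2i", f.read(8))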

+ 2654
- 0
testing/ut/torchvision/test_datasets.py
File diff suppressed because it is too large
View File


+ 87
- 0
testing/ut/torchvision/test_datasets_samplers.py View File

@@ -0,0 +1,87 @@
import numpy as np
import pytest
import ms_adapter.pytorch as torch
from common_utils import get_list_of_videos
from ms_adapter.torchvision import io
from ms_adapter.torchvision.datasets.samplers import DistributedSampler, RandomClipSampler, UniformClipSampler
from ms_adapter.torchvision.datasets.video_utils import VideoClips


@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class TestDatasetsSamplers:
def test_random_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode="floor").numpy()
v_idxs, count = np.unique(videos, return_counts=True)
assert np.allclose(v_idxs, torch.tensor([0, 1, 2]).numpy())
assert np.allclose(count, torch.tensor([3, 3, 3]).numpy())

def test_random_clip_sampler_unequal(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 2 + 3 + 3
indices = list(iter(sampler))
assert 0 in indices
assert 1 in indices
# remove elements of the first video, to simplify testing
indices.remove(0)
indices.remove(1)
indices = torch.tensor(indices) - 2
videos = torch.div(indices, 5, rounding_mode="floor").numpy()
v_idxs, count = np.unique(videos, return_counts=True)
assert np.allclose(v_idxs, torch.tensor([0, 1]).numpy())
assert np.allclose(count, torch.tensor([3, 3]).numpy())

def test_uniform_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode="floor").numpy()
v_idxs, count = np.unique(videos, return_counts=True)
assert np.allclose(v_idxs, torch.tensor([0, 1, 2]).numpy())
assert np.allclose(count, torch.tensor([3, 3, 3]).numpy())
assert np.allclose(indices.numpy(), torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]).numpy())

def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
assert np.allclose(indices.numpy(), torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]).numpy())

def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
clip_sampler = UniformClipSampler(video_clips, 3)

distributed_sampler_rank0 = DistributedSampler(
clip_sampler,
num_replicas=2,
rank=0,
group_size=3,
)
indices = torch.tensor(list(iter(distributed_sampler_rank0)))
assert len(distributed_sampler_rank0) == 6
assert np.allclose(indices.numpy(), torch.tensor([0, 2, 4, 10, 12, 14]).numpy())

distributed_sampler_rank1 = DistributedSampler(
clip_sampler,
num_replicas=2,
rank=1,
group_size=3,
)
indices = torch.tensor(list(iter(distributed_sampler_rank1)))
assert len(distributed_sampler_rank1) == 6
assert np.allclose(indices.numpy(), torch.tensor([5, 7, 9, 0, 2, 4]).numpy())


if __name__ == "__main__":
pytest.main([__file__])

+ 246
- 0
testing/ut/torchvision/test_datasets_utils.py View File

@@ -0,0 +1,246 @@
import contextlib
import gzip
import os
import pathlib
import re
import tarfile
import zipfile

import pytest
import ms_adapter.torchvision.datasets.utils as utils
from ms_adapter.torchvision.datasets.folder import make_dataset
from ms_adapter.torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS

def get_file_path_2(*path_components: str) -> str:
return os.path.join(*path_components)

TEST_FILE = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)


def patch_url_redirection(mocker, redirect_url):
class Response:
def __init__(self, url):
self.url = url

@contextlib.contextmanager
def patched_opener(*args, **kwargs):
yield Response(redirect_url)

return mocker.patch("torchvision.datasets.utils.urllib.request.urlopen", side_effect=patched_opener)


class TestDatasetsUtils:
# def test_get_redirect_url(self, mocker): #TODO pytorch error, too
# url = "https://url.org"
# expected_redirect_url = "https://redirect.url.org"
#
# mock = patch_url_redirection(mocker, expected_redirect_url)
#
# actual = utils._get_redirect_url(url)
# assert actual == expected_redirect_url
#
# assert mock.call_count == 2
# call_args_1, call_args_2 = mock.call_args_list
# assert call_args_1[0][0].full_url == url
# assert call_args_2[0][0].full_url == expected_redirect_url
#
# def test_get_redirect_url_max_hops_exceeded(self, mocker):
# url = "https://url.org"
# redirect_url = "https://redirect.url.org"
#
# mock = patch_url_redirection(mocker, redirect_url)
#
# with pytest.raises(RecursionError):
# utils._get_redirect_url(url, max_hops=0)
#
# assert mock.call_count == 1
# assert mock.call_args[0][0].full_url == url

def test_check_md5(self):
fpath = TEST_FILE
correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
false_md5 = ""
assert utils.check_md5(fpath, correct_md5)
assert not utils.check_md5(fpath, false_md5)

def test_check_integrity(self):
existing_fpath = TEST_FILE
nonexisting_fpath = ""
correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
false_md5 = ""
assert utils.check_integrity(existing_fpath, correct_md5)
assert not utils.check_integrity(existing_fpath, false_md5)
assert utils.check_integrity(existing_fpath)
assert not utils.check_integrity(nonexisting_fpath)

def test_get_google_drive_file_id(self):
url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"

actual = utils._get_google_drive_file_id(url)
assert actual == expected

def test_get_google_drive_file_id_invalid_url(self):
url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

assert utils._get_google_drive_file_id(url) is None

@pytest.mark.parametrize(
"file, expected",
[
("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
("foo.tar", (".tar", ".tar", None)),
("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.tbz", (".tbz", ".tar", ".bz2")),
("foo.tbz2", (".tbz2", ".tar", ".bz2")),
("foo.tgz", (".tgz", ".tar", ".gz")),
("foo.bz2", (".bz2", None, ".bz2")),
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)),
],
)
def test_detect_file_type(self, file, expected):
assert utils._detect_file_type(file) == expected

@pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
def test_detect_file_type_incompatible(self, file):
# tests detect file type for no extension, unknown compression and unknown partial extension
with pytest.raises(RuntimeError):
utils._detect_file_type(file)

@pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
def test_decompress(self, extension, tmpdir):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}{extension}"
compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

with compressed_file_opener(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

compressed, file, content = create_compressed(tmpdir)

utils._decompress(compressed)

assert os.path.exists(file)

with open(file) as fh:
assert fh.read() == content

def test_decompress_no_compression(self):
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")

def test_decompress_remove_finished(self, tmpdir):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"

with gzip.open(compressed, "wb") as fh:
fh.write(content.encode())

return compressed, file, content

compressed, file, content = create_compressed(tmpdir)

utils.extract_archive(compressed, tmpdir, remove_finished=True)

assert not os.path.exists(compressed)

# @pytest.mark.parametrize("extension", [".gz", ".xz"])
# @pytest.mark.parametrize("remove_finished", [True, False])
# def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
# filename = "foo"
# file = f"{filename}{extension}"
#
# mocked = mocker.patch("torchvision.datasets.utils._decompress")
# utils.extract_archive(file, remove_finished=remove_finished)
#
# mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

def test_extract_zip(self, tmpdir):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")

with zipfile.ZipFile(archive, "w") as zf:
zf.writestr(os.path.basename(file), content)

return archive, file, content

archive, file, content = create_archive(tmpdir)

utils.extract_archive(archive, tmpdir)

assert os.path.exists(file)

with open(file) as fh:
assert fh.read() == content

@pytest.mark.parametrize(
"extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
)
def test_extract_tar(self, extension, mode, tmpdir):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
archive = os.path.join(root, f"archive{extension}")

with open(src, "w") as fh:
fh.write(content)

with tarfile.open(archive, mode=mode) as fh:
fh.add(src, arcname=os.path.basename(dst))

return archive, dst, content

archive, file, content = create_archive(tmpdir, extension, mode)

utils.extract_archive(archive, tmpdir)

assert os.path.exists(file)

with open(file) as fh:
assert fh.read() == content

def test_verify_str_arg(self):
assert "a" == utils.verify_str_arg("a", "arg", ("a",))
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")


@pytest.mark.parametrize(
("kwargs", "expected_error_msg"),
[
(dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
(dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
(dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
tmpdir = pathlib.Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "a" / "a.png").touch()

(tmpdir / "b").mkdir()
(tmpdir / "b" / "b.jpeg").touch()

(tmpdir / "c").mkdir()
(tmpdir / "c" / "c.unknown").touch()

with pytest.raises(FileNotFoundError, match=expected_error_msg):
make_dataset(str(tmpdir), **kwargs)


if __name__ == "__main__":
pytest.main([__file__])

+ 105
- 0
testing/ut/torchvision/test_datasets_video_utils.py View File

@@ -0,0 +1,105 @@
import pytest
import numpy as np
import ms_adapter.pytorch as torch
from common_utils import get_list_of_videos
from ms_adapter.torchvision import io
from ms_adapter.torchvision.datasets.video_utils import unfold, VideoClips


class TestVideo:
def test_unfold(self):
        a = torch.tensor([0, 1, 2, 3, 4, 5, 6])
r = unfold(a, 3, 3, 1)
expected = torch.tensor(
[
[0, 1, 2],
[3, 4, 5],
]
)
assert np.allclose(r.numpy(), expected.numpy())

r = unfold(a, 3, 2, 1)
expected = torch.tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]])
assert np.allclose(r.numpy(), expected.numpy())

r = unfold(a, 3, 2, 2)
expected = torch.tensor(
[
[0, 2, 4],
[2, 4, 6],
]
)
assert np.allclose(r.numpy(), expected.numpy())

@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3)
video_clips = VideoClips(video_list, 5, 5, num_workers=2)
assert video_clips.num_clips() == 1 + 2 + 3
for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
assert video_idx == v_idx
assert clip_idx == c_idx

video_clips = VideoClips(video_list, 6, 6)
assert video_clips.num_clips() == 0 + 1 + 2
for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
video_idx, clip_idx = video_clips.get_clip_location(i)
assert video_idx == v_idx
assert clip_idx == c_idx

video_clips = VideoClips(video_list, 6, 1)
assert video_clips.num_clips() == 0 + (10 - 6 + 1) + (15 - 6 + 1)
for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
video_idx, clip_idx = video_clips.get_clip_location(i)
assert video_idx == v_idx
assert clip_idx == c_idx

@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips_custom_fps(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6])
num_frames = 4
for fps in [1, 3, 4, 10]:
video_clips = VideoClips(video_list, num_frames, num_frames, fps, num_workers=2)
for i in range(video_clips.num_clips()):
video, audio, info, video_idx = video_clips.get_clip(i)
assert video.shape[0] == num_frames
assert info["video_fps"] == fps
# TODO add tests checking that the content is right

def test_compute_clips_for_video(self):
video_pts = torch.tensor(np.arange(30))
# case 1: single clip
num_frames = 13
orig_fps = 30
duration = float(len(video_pts)) / orig_fps
new_fps = 13
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
assert len(clips) == 1
assert np.allclose(clips.numpy(), idxs.numpy())
assert np.allclose(idxs[0].numpy(), resampled_idxs.numpy())

# case 2: all frames appear only once
num_frames = 4
orig_fps = 30
duration = float(len(video_pts)) / orig_fps
new_fps = 12
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
assert len(clips) == 3
assert np.allclose(clips.numpy(), idxs.numpy())
assert np.allclose(idxs.flatten().numpy(), resampled_idxs.numpy())

# case 3: frames aren't enough for a clip
num_frames = 32
orig_fps = 30
new_fps = 13
with pytest.warns(UserWarning):
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps)
assert len(clips) == 0
assert len(idxs) == 0


if __name__ == "__main__":
pytest.main([__file__])
