|
- import os
- import sys
-
- sys.path.append(".")
-
- import pytest
-
- import mindspore as ms
-
- from mindcv.data import create_dataset, get_dataset_download_root
- from mindcv.utils.download import DownLoad
-
-
# test ImageNet-style (image folder) dataset
@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["ImageNet"])
@pytest.mark.parametrize("split", ["train", "val"])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("num_samples", [2, None])
@pytest.mark.parametrize("num_parallel_workers", [2])
def test_create_dataset_standalone_imagenet(mode, name, split, shuffle, num_samples, num_parallel_workers):
    """
    Test the create_dataset API (standalone) on an ImageNet-style folder dataset.

    The small "Canidae" archive is downloaded once and used as a stand-in for
    ImageNet, so the test does not need the real ImageNet data on disk.

    command: pytest -s test_dataset.py::test_create_dataset_standalone_imagenet

    API Args:
        name: str = '',
        root: str = './',
        split: str = 'train',
        shuffle: bool = True,
        num_samples: Optional[int] = None,
        num_shards: Optional[int] = None,
        shard_id: Optional[int] = None,
        num_parallel_workers: Optional[int] = None,
        download: bool = False,
        **kwargs
    """
    ms.set_context(mode=mode)

    dataset_url = (
        "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
    )
    root_dir = os.path.join(get_dataset_download_root(), "Canidae")
    # The unzipped archive nests its content under the prefix "data/Canidae".
    data_dir = os.path.join(root_dir, "data", "Canidae")
    # Only fetch the archive when the extracted data is not already present.
    if not os.path.exists(data_dir):
        DownLoad().download_and_extract_archive(dataset_url, root_dir)

    dataset = create_dataset(
        name=name,
        root=data_dir,
        split=split,
        shuffle=shuffle,
        num_samples=num_samples,
        num_parallel_workers=num_parallel_workers,
        download=False,
    )

    # Check for None first so a missing dataset is reported as such,
    # then verify the type through the public mindspore.dataset namespace.
    assert dataset is not None
    assert isinstance(dataset, ms.dataset.ImageFolderDataset)
-
-
# test MNIST CIFAR10
@pytest.mark.parametrize("mode", [0, 1])
@pytest.mark.parametrize("name", ["MNIST", "CIFAR10"])
@pytest.mark.parametrize("split", ["train", "test"])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("num_samples", [2, None])
@pytest.mark.parametrize("num_parallel_workers", [2])
@pytest.mark.parametrize("download", [True])
def test_create_dataset_standalone_mc(mode, name, split, shuffle, num_samples, num_parallel_workers, download):
    """
    Test the create_dataset API (standalone) on the MNIST and CIFAR10 datasets.

    No ``root`` is passed, so create_dataset uses its own default location;
    ``download`` is always True in this parametrization.

    command: pytest -s test_dataset.py::test_create_dataset_standalone_mc

    API Args:
        name: str = '',
        root: str = './',
        split: str = 'train',
        shuffle: bool = True,
        num_samples: Optional[int] = None,
        num_shards: Optional[int] = None,
        shard_id: Optional[int] = None,
        num_parallel_workers: Optional[int] = None,
        download: bool = False,
        **kwargs
    """
    ms.set_context(mode=mode)

    dataset = create_dataset(
        name=name,
        split=split,
        shuffle=shuffle,
        num_samples=num_samples,
        num_parallel_workers=num_parallel_workers,
        download=download,
    )

    # Tie the expected dataset class to the requested name, so that e.g.
    # name="MNIST" returning a CIFAR10 dataset is caught instead of accepted.
    expected_types = {
        "MNIST": ms.dataset.MnistDataset,
        "CIFAR10": ms.dataset.Cifar10Dataset,
    }

    assert dataset is not None
    assert isinstance(dataset, expected_types[name])
|