|
- """
- Create dataloader
- """
-
- import warnings
-
- import mindspore as ms
- from mindspore.dataset import transforms#, vision
-
- from .transforms_factory import create_transforms
- from .mixup import Mixup
-
- __all__ = ["create_loader"]
-
-
def create_loader(
    dataset,
    batch_size,
    drop_remainder=False,
    is_training=False,
    mixup=0.0,
    cutmix=0.0,
    cutmix_prob=0.0,
    num_classes=1000,
    transform=None,
    target_transform=None,
    num_parallel_workers=None,
    python_multiprocessing=False,
):
    r"""Creates dataloader.

    Applies operations such as transform and batch to the `ms.dataset.Dataset` object
    created by the `create_dataset` function to get the dataloader.

    Args:
        dataset (ms.dataset.Dataset): dataset object created by `create_dataset`.
        batch_size (int or function): The number of rows each batch is created with. An
            int or callable object which takes exactly 1 parameter, BatchInfo.
        drop_remainder (bool, optional): Determines whether to drop the last block
            whose data row number is less than batch size (default=False). If True, and if there are less
            than batch_size rows available to make the last batch, then those rows will
            be dropped and not propagated to the child node.
        is_training (bool): whether it is in train mode. Default: False.
        mixup (float): mixup alpha, mixup will be enabled if > 0. (default=0.0).
        cutmix (float): cutmix alpha, cutmix will be enabled if > 0. (default=0.0). This operation is experimental.
        cutmix_prob (float): probability of applying cutmix and/or mixup to a batch (default=0.0).
        num_classes (int): the number of classes. Default: 1000.
        transform (list or None): the list of transformations that will be applied on the image,
            which is obtained by `create_transforms`. If None, the default imagenet transformation
            for evaluation will be applied. Default: None.
        target_transform (list or None): the list of transformations that will be applied on the label.
            If None, the label will be converted to the type of ms.int32. Default: None.
        num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel
            (default=None).
        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
            option could be beneficial if the Python operation is computationally heavy (default=False).

    Note:
        1. cutmix is now experimental (which means performance gain is not guaranteed) and
           cannot be used together with mixup due to the label int type conflict.
        2. `is_training`, `mixup`, `num_classes` are used for MixUp, which is a kind of transform operation.
           However, we are not able to merge it into `transform`, due to the limitations of
           the `mindspore.dataset` API.

    Returns:
        BatchDataset, dataset batched.
    """
    if transform is None:
        warnings.warn("Using None as the default value of transform will set it back to "
                      "traditional image transform, which is not recommended. "
                      "You should explicitly call `create_transforms` and pass it to `create_loader`.")
        transform = create_transforms("imagenet", is_training=False)
    dataset = dataset.map(
        operations=transform,
        input_columns='image',
        num_parallel_workers=num_parallel_workers,
        python_multiprocessing=python_multiprocessing,
    )

    if target_transform is None:
        # Default: cast raw labels to int32, as expected by downstream loss modules.
        target_transform = transforms.TypeCast(ms.int32)
        is_onehot_target = False
    else:
        # A user-supplied target_transform is assumed to yield one-hot labels;
        # Mixup must know this to skip its own one-hot conversion.
        is_onehot_target = True

    # CIFAR-style datasets name the label column 'fine_label' instead of 'label'.
    target_input_columns = 'label' if 'label' in dataset.get_col_names() else 'fine_label'
    dataset = dataset.map(
        operations=target_transform,
        input_columns=target_input_columns,
        num_parallel_workers=num_parallel_workers,
        python_multiprocessing=python_multiprocessing,
    )

    dataset = dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)

    # Batch-level mixing only makes sense in training and with at least 2 samples per batch.
    if is_training and (mixup + cutmix > 0.0) and batch_size > 1:
        # TODO: switch to mindspore.dataset.vision CutMix/MixUp once the label
        # int-type conflict is fixed in a later release.
        # label_smoothing is 0 here since label smoothing is computed in the loss module.
        mixup_fn = Mixup(
            mixup_alpha=mixup,
            cutmix_alpha=cutmix,
            cutmix_minmax=None,
            prob=cutmix_prob,
            switch_prob=0.5,
            label_smoothing=0.0,
            num_classes=num_classes,
            is_onehot_label=is_onehot_target,
        )
        # Images in a batch are mixed; labels are converted to soft one-hot labels.
        dataset = dataset.map(
            input_columns=["image", target_input_columns],
            num_parallel_workers=num_parallel_workers,
            operations=mixup_fn,
        )

    return dataset
|