|
- from .transforms import make_transforms
- from . import samplers
- import torch
- import torch.utils.data
- import imp
- import os
- from .collate_batch import make_collator
- import numpy as np
- import time
- from lib.config.config import cfg
-
-
def _dataset_factory(is_train):
    """Instantiate the configured Dataset class for the train or test split.

    Reads the module name, source-file path, and constructor kwargs from the
    global ``cfg`` (train_* keys when ``is_train`` is truthy, test_* keys
    otherwise), loads that source file as a module, and returns
    ``Dataset(**args)`` from it.

    Args:
        is_train: select the train-split config when truthy, else test-split.

    Returns:
        An instance of the ``Dataset`` class defined in the configured file.
    """
    # Local imports keep this fix self-contained; `imp.load_source` (used
    # previously) was deprecated since Python 3.4 and removed in 3.12, so we
    # load the module from its file path with importlib instead.
    import importlib.util
    import sys

    if is_train:
        module = cfg.train_dataset_module
        path = cfg.train_dataset_path
        args = cfg.train_dataset
    else:
        module = cfg.test_dataset_module
        path = cfg.test_dataset_path
        args = cfg.test_dataset

    spec = importlib.util.spec_from_file_location(module, path)
    mod = importlib.util.module_from_spec(spec)
    # imp.load_source also registered the module in sys.modules; keep that
    # behavior so repeated loads and intra-module imports still resolve.
    sys.modules[module] = mod
    spec.loader.exec_module(mod)
    return mod.Dataset(**args)
-
-
def make_dataset(cfg, dataset_name, transforms, is_train=True):
    """Build the dataset for the requested split.

    Note: ``cfg``, ``dataset_name`` and ``transforms`` are accepted for
    interface compatibility but are currently unused — the dataset is
    constructed entirely from the global config by ``_dataset_factory``.
    """
    return _dataset_factory(is_train)
-
-
def make_data_sampler(dataset, shuffle, is_distributed, is_train):
    """Choose an index sampler for ``dataset``.

    Priority: the project's FrameSampler (test split only, when configured),
    then a distributed sampler, then plain random/sequential sampling
    depending on ``shuffle``.
    """
    # Test-time override: a config-selected frame-level sampler.
    if not is_train and cfg.test.sampler == 'FrameSampler':
        return samplers.FrameSampler(dataset)

    # Multi-process training shards indices across ranks.
    if is_distributed:
        return samplers.DistributedSampler(dataset, shuffle=shuffle)

    sampler_cls = (torch.utils.data.sampler.RandomSampler
                   if shuffle
                   else torch.utils.data.sampler.SequentialSampler)
    return sampler_cls(dataset)
-
-
def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter,
                            is_train):
    """Wrap ``sampler`` in a batch sampler chosen by the split's config.

    ``'default'`` yields fixed-size batches via torch's BatchSampler;
    ``'image_size'`` uses the project's size-aware batcher.  When
    ``max_iter`` is not -1 the result is additionally wrapped so iteration
    stops after exactly ``max_iter`` batches.
    """
    split_cfg = cfg.train if is_train else cfg.test
    batch_sampler = split_cfg.batch_sampler
    sampler_meta = split_cfg.sampler_meta

    if batch_sampler == 'default':
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, batch_size, drop_last)
    elif batch_sampler == 'image_size':
        batch_sampler = samplers.ImageSizeBatchSampler(
            sampler, batch_size, drop_last, sampler_meta)

    # -1 means "no iteration cap"; anything else bounds the epoch length.
    if max_iter != -1:
        batch_sampler = samplers.IterationBasedBatchSampler(
            batch_sampler, max_iter)
    return batch_sampler
-
-
def worker_init_fn(worker_id):
    """Seed numpy differently in each DataLoader worker process.

    Salting the seed with the current wall-clock millisecond keeps runs from
    repeating the same augmentation stream, while adding ``worker_id`` keeps
    simultaneously-started workers from sharing a seed.
    """
    millis = int(round(time.time() * 1000) % (2 ** 16))
    np.random.seed(millis + worker_id)
-
-
def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
    """Build a torch DataLoader for the train or test split.

    NOTE(review): this function is in a visibly work-in-progress state —
    the sampler / batch-sampler wiring is commented out below, so
    ``batch_size``, ``shuffle``, ``drop_last``, ``dataset_name`` and
    ``max_iter`` are computed but never reach the DataLoader, which
    therefore falls back to its defaults (batch_size=1, no shuffling).
    Presumably batching is handled by the custom collator instead — confirm
    before re-enabling the commented code.
    """
    if is_train:
        batch_size = cfg.train.batch_size
        # shuffle = True
        shuffle = cfg.train.shuffle
        drop_last = False
    else:
        batch_size = cfg.test.batch_size
        shuffle = True if is_distributed else False
        drop_last = False

    dataset_name = cfg.train.dataset if is_train else cfg.test.dataset

    transforms = make_transforms(cfg, is_train)
    dataset = make_dataset(cfg, dataset_name, transforms, is_train)
    # sampler = make_data_sampler(dataset, shuffle, is_distributed, is_train)
    # batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size,
    #                                         drop_last, max_iter, is_train)
    # NOTE(review): cfg.train.num_workers is used even for the test split.
    num_workers = cfg.train.num_workers
    collator = make_collator(cfg, is_train)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              # batch_sampler=batch_sampler,
                                              num_workers=num_workers,
                                              collate_fn=collator,
                                              # time-salted per-worker numpy seed
                                              worker_init_fn=worker_init_fn)

    return data_loader
|