|
|
@@ -1,6 +1,6 @@ |
|
|
|
import torch |
|
|
|
import pickle as pkl |
|
|
|
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler |
|
|
|
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, distributed |
|
|
|
from torchvision import transforms |
|
|
|
import cv2 |
|
|
|
|
|
|
def get_dataset(params):
    """Build the train/eval datasets, distributed samplers, and DataLoaders.

    Args:
        params: dict-like config; keys read here: 'train_image_path',
            'train_label_path', 'eval_image_path', 'eval_label_path',
            'batch_size', 'workers'.

    NOTE(review): `words` is a name defined elsewhere in this module (a
    vocabulary object passed to the dataset) — not visible in this hunk.
    NOTE(review): requires torch.distributed to be initialized before this
    is called, since DistributedSampler queries the process group.
    """
    train_dataset = HYBTr_Dataset(params, params['train_image_path'], params['train_label_path'], words)
    eval_dataset = HYBTr_Dataset(params, params['eval_image_path'], params['eval_label_path'], words)

    # DistributedSampler shards each dataset across ranks so every process
    # sees a disjoint subset (replaces the single-process RandomSampler).
    # Bug fixed: the previous code additionally did
    #   train_sampler = RandomSampler(train_sampler)
    # which samples indices into the *sampler's* index sequence instead of
    # the dataset — the DistributedSampler must be passed to the loader
    # directly (it already shuffles per-epoch via set_epoch()).
    train_sampler = distributed.DistributedSampler(train_dataset)
    eval_sampler = distributed.DistributedSampler(eval_dataset)

    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], sampler=train_sampler,
                              num_workers=params['workers'], collate_fn=train_dataset.collate_fn, pin_memory=False)
    # Eval uses batch_size=1 (variable-size samples; collate_fn pads within a batch).
    eval_loader = DataLoader(eval_dataset, batch_size=1, sampler=eval_sampler,
                             num_workers=params['workers'], collate_fn=eval_dataset.collate_fn, pin_memory=False)

    print(f'train dataset: {len(train_dataset)} train steps: {len(train_loader)} '
          f'eval dataset: {len(eval_dataset)} eval steps: {len(eval_loader)}')
    # NOTE(review): the function continues past this hunk (presumably
    # `return train_loader, eval_loader`) — not visible here, so not guessed.
|
|
|