|
|
@@ -174,7 +174,8 @@ def run_train(): |
|
|
|
sampler = datasets.DistributedSampler(flownet_train_gen, rank=rank, group_size=group_size, shuffle=True) |
|
|
|
print('sampler') |
|
|
|
train_dataset = ds.GeneratorDataset(flownet_train_gen, ["images", "flow"], |
|
|
|
sampler=sampler, num_parallel_workers=config.num_parallel_workers) |
|
|
|
sampler=sampler, num_parallel_workers=config.num_parallel_workers |
|
|
|
,num_shards=device_num, shard_id=rank) |
|
|
|
train_dataset = train_dataset.batch(config.batch_size, num_parallel_workers=config.num_parallel_workers) |
|
|
|
step_size = train_dataset.get_dataset_size() |
|
|
|
print("Step size: ", step_size,flush=True) |
|
|
|