Browse Source

修改一下切片

master
liupengfei 1 year ago
parent
commit
36babaf176
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      train.py

+ 2
- 1
train.py View File

@@ -174,7 +174,8 @@ def run_train():
sampler = datasets.DistributedSampler(flownet_train_gen, rank=rank, group_size=group_size, shuffle=True)
print('sampler')
train_dataset = ds.GeneratorDataset(flownet_train_gen, ["images", "flow"],
sampler=sampler, num_parallel_workers=config.num_parallel_workers)
sampler=sampler, num_parallel_workers=config.num_parallel_workers
,num_shards=device_num, shard_id=rank)
train_dataset = train_dataset.batch(config.batch_size, num_parallel_workers=config.num_parallel_workers)
step_size = train_dataset.get_dataset_size()
print("Step size: ", step_size,flush=True)


Loading…
Cancel
Save