Browse Source

rank=0, group_size=1

master
liupengfei 1 year ago
parent
commit
789566f9d3
1 changed files with 6 additions and 7 deletions
  1. +6
    -7
      train.py

+ 6
- 7
train.py View File

@@ -47,11 +47,8 @@ import argparse
def set_save_ckpt_dir():
"""set save ckpt dir"""
ckpt_save_dir = config.save_checkpoint_path
if config.enable_modelarts and config.run_distribute:
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank_id()) + "/"
else:
if config.run_distribute:
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/"
if config.run_distribute:
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/"
return ckpt_save_dir


@@ -155,7 +152,9 @@ def run_train():
config.device_id = get_device_id()
# TODO lpf rank和 group_size 的设置是否有问题
rank = get_rank()
group_size = get_group_size()
group_size = get_group_size()
print('rank ccc =',rank)
print('group_size ccc =',group_size)
# ms.set_context(device_id=config.device_id)
if device_num > 1:
context.reset_auto_parallel_context()
@@ -168,7 +167,7 @@ def run_train():
print('config.train_data_path =',config.train_data_path )
flownet_train_gen = config.training_dataset_class(config.crop_type, config.crop_size, config.eval_size,
config.train_data_path)
sampler = datasets.DistributedSampler(flownet_train_gen, rank=rank, group_size=group_size, shuffle=True)
sampler = datasets.DistributedSampler(flownet_train_gen, rank=0, group_size=1, shuffle=True)
print('sampler')
train_dataset = ds.GeneratorDataset(flownet_train_gen, ["images", "flow"],
sampler=sampler, num_parallel_workers=config.num_parallel_workers)


Loading…
Cancel
Save