|
|
@@ -47,11 +47,8 @@ import argparse |
|
|
|
def set_save_ckpt_dir(): |
|
|
|
"""set save ckpt dir""" |
|
|
|
ckpt_save_dir = config.save_checkpoint_path |
|
|
|
if config.enable_modelarts and config.run_distribute: |
|
|
|
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank_id()) + "/" |
|
|
|
else: |
|
|
|
if config.run_distribute: |
|
|
|
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/" |
|
|
|
if config.run_distribute: |
|
|
|
ckpt_save_dir = ckpt_save_dir + "/ckpt_" + str(get_rank()) + "/" |
|
|
|
return ckpt_save_dir |
|
|
|
|
|
|
|
|
|
|
@@ -155,7 +152,9 @@ def run_train(): |
|
|
|
config.device_id = get_device_id() |
|
|
|
# TODO lpf rank和 group_size 的设置是否有问题 |
|
|
|
rank = get_rank() |
|
|
|
group_size = get_group_size() |
|
|
|
group_size = get_group_size() |
|
|
|
print('rank ccc =',rank) |
|
|
|
print('group_size ccc =',group_size) |
|
|
|
# ms.set_context(device_id=config.device_id) |
|
|
|
if device_num > 1: |
|
|
|
context.reset_auto_parallel_context() |
|
|
@@ -168,7 +167,7 @@ def run_train(): |
|
|
|
print('config.train_data_path =',config.train_data_path ) |
|
|
|
flownet_train_gen = config.training_dataset_class(config.crop_type, config.crop_size, config.eval_size, |
|
|
|
config.train_data_path) |
|
|
|
sampler = datasets.DistributedSampler(flownet_train_gen, rank=rank, group_size=group_size, shuffle=True) |
|
|
|
sampler = datasets.DistributedSampler(flownet_train_gen, rank=0, group_size=1, shuffle=True) |
|
|
|
print('sampler') |
|
|
|
train_dataset = ds.GeneratorDataset(flownet_train_gen, ["images", "flow"], |
|
|
|
sampler=sampler, num_parallel_workers=config.num_parallel_workers) |
|
|
|