#621 add ddp

Merged
zoulq merged 1 commit from wtcheng/MSAdapter:master into master 9 months ago
  1. +8 -0 ConstraintList.md
  2. +8 -0 ConstraintList_en.md
  3. +19 -0 SupportedList.md
  4. +21 -0 SupportedList_en.md
  5. +18 -0 msadapter/pytorch/distributed/__init__.py
  6. +334 -0 msadapter/pytorch/distributed/distributed_c10d.py
  7. +1 -0 msadapter/pytorch/nn/__init__.py
  8. +5 -0 msadapter/pytorch/nn/parallel/__init__.py
  9. +80 -0 msadapter/pytorch/nn/parallel/distributed.py
  10. +0 -0 testing/ut/pytorch/distributed/__init__.py
  11. +49 -0 testing/ut/pytorch/distributed/ddp_impl.py
  12. +16 -0 testing/ut/pytorch/distributed/test_ddp.py
  13. +36 -0 testing/ut/pytorch/nn/dist_impl.py
  14. +16 -0 testing/ut/pytorch/nn/test_dist.py

+ 8
- 0
ConstraintList.md View File

@@ -6,6 +6,7 @@
- [nn.functional](#jump5)
- [torch.linalg](#jump6)
- [torch.utils.data](#jump7)
- [torch.distributed](#jump8)

## <span id="jump1">接口约束列表</span>

@@ -246,6 +247,7 @@
| nn.RNN | 在图模式下,`input`不支持PackedSequence类型 |
| nn.GRU | 在图模式下,`input`不支持PackedSequence类型 |
| nn.CrossEntropyLoss | `target`类型为int64时,有溢出风险 |
| nn.parallel.DistributedDataParallel | 不支持`device_ids`, `output_device`, `dim`, `find_unused_parameters`, `check_reduction`, `gradient_as_bucket_view`, `static_graph` |

### <span id="jump5">nn.functional</span>
| MSAdapter接口 | 约束条件 |
@@ -310,3 +312,9 @@
| RandomSampler | 暂不支持传入Generator|
| SubsetRandomSampler | 暂不支持传入Generator|
| WeightedRandomSampler | 暂不支持传入Generator|

### <span id="jump8">torch.distributed</span>
| MSAdapter接口 | 约束条件 |
| --------------- |-----------------------------------------------------------------------------------------|
| init_process_group | 不支持`timeout`, `rank`, `store`, `group_name`, `pg_options`,部分支持`init_method`:以环境变量模式配置初始化 |
| new_group | 不支持`timeout`, `pg_options`,`backend`部分支持(nccl) |
panshaowu marked this conversation as resolved
Erpim commented 9 months ago
Review
Isn't pytorch unaware of hccl?
Erpim commented 9 months ago
Review
You could describe the strategy at the very top of the torch.distributed section, for example that on Ascend, nccl is mapped to hccl underneath by default, or put it into the general constraints.
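
A minimal sketch of how these constraints behave in practice, condensed from the ddp_impl.py and dist_impl.py unit tests added later in this PR; the script name and the mpirun launch with 3 processes are illustrative assumptions, not part of the documented API:

# launch with: mpirun --allow-run-as-root -n 3 python train_sketch.py
from msadapter.pytorch import distributed as dist

# `init_method` is honoured only in environment-variable mode; `timeout`,
# `rank`, `store`, `group_name` and `pg_options` are accepted but ignored.
dist.init_process_group(backend='nccl', world_size=3)

# `new_group` only takes the rank list into account; `timeout`, `pg_options`
# and backends other than nccl are not supported.
group = dist.new_group([0, 1])
print(dist.get_rank(), dist.get_world_size())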

+ 8
- 0
ConstraintList_en.md View File

@@ -7,6 +7,7 @@ English | [简体中文](ConstraintList.md)
- [nn.functional](#jump5)
- [torch.linalg](#jump6)
- [torch.utils.data](#jump7)
- [torch.distributed](#jump8)

## <span id="jump1">API Constraints List</span>

@@ -247,6 +248,7 @@ English | [简体中文](ConstraintList.md)
| nn.RNN | Under GRAPH mode, `input` not support PackedSequence type |
| nn.GRU | Under GRAPH mode, `input` not support PackedSequence type |
| nn.CrossEntropyLoss | There is risk of overflow when `target` type is int64 |
| nn.parallel.DistributedDataParallel | `device_ids`, `output_device`, `dim`, `find_unused_parameters`, `check_reduction`, `gradient_as_bucket_view`, `static_graph` are not supported|

### <span id="jump5">nn.functional</span>
| MSAdapter APIs | Constraint conditions |
@@ -311,3 +313,9 @@ English | [简体中文](ConstraintList.md)
| RandomSampler | Currently not support input Generator |
| SubsetRandomSampler | Currently not support input Generator |
| WeightedRandomSampler | Currently not support input Generator |

### <span id="jump8">torch.distributed</span>
| MSAdapter APIs | Constraint conditions |
| --------------- |-----------------------------------------------------------------------------------------|
| init_process_group | `timeout`, `rank`, `store`, `group_name`, `pg_options` are not supported, `init_method` is partly supported: initialization can be configured only in environment variable mode. |
| new_group | `timeout`, `pg_options` are not supported, `backend` is partly supported (nccl) |
panshaowu marked this conversation as resolved
Erpim commented 9 months ago
Review
Same as above.

+ 19
- 0
SupportedList.md View File

@@ -9,6 +9,7 @@
- [torch.linalg](#jump6)
- [torch.optim](#jump7)
- [torch.utils.data](#jump9)
- [torch.distributed](#jump10)


### <span id="jump8">通用限制</span>
@@ -1022,6 +1023,7 @@
| nn.init.eye_ | 部分支持 | 暂不支持图模式 |
| nn.init.dirac_ | 部分支持 | 暂不支持图模式 |
| nn.init.orthogonal_ | 部分支持 | 暂不支持图模式 |
| nn.parallel.DistributedDataParallel | 部分支持 | [功能存在限制](ConstraintList.md) |

### <span id="jump5">nn.functional</span>
| MSAdapter接口 | 状态 | 约束 |
@@ -1251,4 +1253,21 @@
| BatchSampler | 支持 | |
| distributed.DistributedSampler | 支持 | |

### <span id="jump10">torch.distributed</span>
<span id="jump10">分布式统一约束:</span>
- 在Ascend后端上,由于设备差异,NCCL相关接口默认会被替换为HCCL相关接口。
panshaowu marked this conversation as resolved
zoulq commented 9 months ago
Review
The usage examples for these interfaces are not complete yet. Add a sentence here: "The torch.distributed-related interfaces are experimental APIs and may be modified or removed later. For migrating distributed training functionality, please refer to the sample description in the user guide." Then link it to https://openi.pcl.ac.cn/OpenI/MSAdapter/src/branch/master/USER_GUIDE.md#user-content-3-3-%E4%BD%BF%E7%94%A8%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83%E5%8A%A0%E9%80%9F%E8%AE%AD%E7%BB%83
- torch.distributed相关接口为实验性API,后续可能修改或删除。分布式训练功能迁移请参考[用户手册样例描述](USER_GUIDE.md)。

| MSAdapter接口 | 状态 | 约束 |
| --------------- | ---- |------------------------------|
| init_process_group | 支持 | [功能存在限制](ConstraintList.md) |
| get_rank | 支持 | |
| new_group | 部分支持 | [功能存在限制](ConstraintList.md) |
| get_world_size | 支持 | |
| destroy_process_group | 支持 | |
| is_available | 支持 | |
| is_initialized | 支持 | |
| is_mpi_available | 支持 | |
| is_nccl_available | 支持 | |
| get_backend | 支持 | |
| get_process_group_ranks | 支持 | |
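
The Ascend note at the top of this section means the NCCL-flavoured queries answer for HCCL on that backend. A small hedged sketch of the affected calls, based on is_nccl_available() in the distributed_c10d.py diff further down:

import mindspore as ms
from msadapter.pytorch import distributed as dist

# On an Ascend device_target, is_nccl_available() reports HCCL availability
# instead and logs a warning pointing at is_hccl_available().
print('device_target:', ms.get_context('device_target'))
print('is_nccl_available:', dist.is_nccl_available())
print('is_hccl_available:', dist.is_hccl_available())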

+ 21
- 0
SupportedList_en.md View File

@@ -8,6 +8,7 @@ English | [简体中文](SupportedList.md)
- [torch.linalg](#jump6)
- [torch.optim](#jump7)
- [torch.utils.data](#jump9)
- [torch.distributed](#jump10)

### <span id="jump8">General Constraint</span>
- Not support the function of configuration `layout`, `device`, `requires_grad`, `memory_format`.
@@ -1020,6 +1021,7 @@ English | [简体中文](SupportedList.md)
| nn.init.eye_ | Partly supported | Currently not support on GRAPH mode |
| nn.init.dirac_ | Partly supported | Currently not support on GRAPH mode |
| nn.init.orthogonal_ | Partly supported | Currently not support on GRAPH mode |
| nn.parallel.DistributedDataParallel | Partly supported | [Function is constrained](ConstraintList_en.md) |

### <span id="jump5">nn.functional</span>
| MSAdapter APIs | Status | Restrictions |
@@ -1252,3 +1254,22 @@ English | [简体中文](SupportedList.md)
| WeightedRandomSampler | Supported |[Function is constrained](ConstraintList_en.md)|
| BatchSampler | Supported | |
| distributed.DistributedSampler | Supported | |

### <span id="jump10">torch.distributed</span>
<span id="jump11">distributed General Constraints:</span>
- In Ascend backend, NCCL related interfaces are replaced by HCCL related interfaces by default due to device differences.
- The torch.distributed related interface is an experimental API that may be modified or deleted in the future. Please refer to [User Manual Sample Description](USER_GUIDE.md) for the migration of distributed training functions.

| MSAdapter APIs | Status | Restrictions |
| --------------- | ---- |------------------------------|
| init_process_group | Partly supported | [Function is constrained](ConstraintList_en.md) |
| get_rank | Supported | |
| new_group | Partly supported | [Function is constrained](ConstraintList_en.md) |
| get_world_size | Supported | |
| destroy_process_group | Supported | |
| is_available | Supported | |
| is_initialized | Supported | |
| is_mpi_available | Supported | |
| is_nccl_available | Supported | |
| get_backend | Supported | |
| get_process_group_ranks | Supported | |

+ 18
- 0
msadapter/pytorch/distributed/__init__.py View File

@@ -0,0 +1,18 @@
from .distributed_c10d import (init_process_group, get_rank, new_group, get_world_size, ProcessGroup,
                               destroy_process_group, is_available, is_initialized, is_mpi_available,
                               is_nccl_available, is_hccl_available, get_backend, get_process_group_ranks)

__all__ = [
    'init_process_group',
    'get_rank',
    'new_group',
    'get_world_size',
    'ProcessGroup',
    'destroy_process_group',
    'is_available',
    'is_initialized',
    'is_nccl_available',
    'is_hccl_available',
    'get_backend',
    'get_process_group_ranks',
]
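
These exports mirror the torch.distributed surface, so a migrated script mostly only swaps its import; a sketch (the commented-out torch line stands for the user's original code and is an assumption, not something this package provides):

# import torch.distributed as dist                   # original PyTorch script (assumed)
from msadapter.pytorch import distributed as dist    # MSAdapter replacement

print(dist.is_available())                           # the names exported above are used unchanged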

+ 334
- 0
msadapter/pytorch/distributed/distributed_c10d.py View File

@@ -0,0 +1,334 @@
from datetime import timedelta
from typing import Any, Dict, Optional, Union, overload, List

from mindspore.context import ParallelMode
from mindspore.communication.management import init, create_group, GlobalComm, destroy_group
from mindspore.communication._comm_helper import (_is_available, _is_initialized, _get_backend, _is_hccl_available,
_is_nccl_available, _is_mpi_available, _get_group_ranks)
from mindspore import Tensor
import mindspore as ms
from mindspore import log as logger

from msadapter.utils import unsupported_attr

BACKEND_DEVICE_TARGET_DICT = {
    'nccl': 'GPU',
    'hccl': 'Ascend',
}


class ReduceOp:
    SUM = ...
    PRODUCT = ...
    MIN = ...
    MAX = ...
    BAND = ...
    BOR = ...
    BXOR = ...
    PREMUL_SUM = ...
    UNUSED = ...


class ProcessGroup:
    # TODO: implemented the following methods after the operators supported.
    @overload
    def broadcast(
        self,
        tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
    ) -> None: ...

    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        op=ReduceOp.SUM,
    ) -> None: ...

    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op=ReduceOp.SUM,
    ) -> None: ...

    @overload
    def reduce(
        self,
        tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op=ReduceOp.SUM,
    ) -> None: ...

    @overload
    def allgather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def allgather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
    ) -> None: ...

    @overload
    def gather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def gather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
        root: int,
    ) -> None: ...

    @overload
    def scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=None,
    ) -> None: ...

    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: List[Tensor],
        root: int,
    ) -> None: ...

    @overload
    def reduce_scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=None,
    ) -> None: ...

    @overload
    def reduce_scatter(
        self,
        output_tensors: Tensor,
        input_tensor: List[Tensor],
    ) -> None: ...

    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts=None,
    ) -> None: ...

    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
    ) -> None: ...

    @overload
    def alltoall(
        self,
        output_tensor: List[Tensor],
        input_tensor: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def alltoall(
        self,
        output: List[Tensor],
        input: List[Tensor],
    ) -> None: ...

    @overload
    def send(
        self,
        tensors: List[Tensor],
        dstRank: int,
        tag: int,
    ) -> None: ...

    @overload
    def recv(
        self,
        tensors: List[Tensor],
        srcRank: int,
        tag: int,
    ) -> None: ...


_pg_map: Dict[ProcessGroup, str] = {}


def init_process_group(
    backend: Union[str],
    init_method: Optional[str] = None,
    timeout: timedelta = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
):
    global _pg_map
    if backend not in BACKEND_DEVICE_TARGET_DICT:
        raise ValueError('{} is not supported.'.format(backend))
    device_target = ms.get_context('device_target')
    backend_device_target = BACKEND_DEVICE_TARGET_DICT[backend]
    if device_target != backend_device_target:
        raise ValueError('If backend is {}, the device_target must be {}.'.format(backend, backend_device_target))
    init()
    device_num = get_world_size()
    if world_size != device_num:
        raise ValueError('The world_size:{} is not equal to the device_num:{}.'.format(world_size, device_num))
    ms.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num,
                                 parameter_broadcast=True)

    pg = ProcessGroup()
    _pg_map[pg] = GlobalComm.WORLD_COMM_GROUP

    unsupported_attr(init_method)
    unsupported_attr(timeout)
    unsupported_attr(rank)
    unsupported_attr(store)
    unsupported_attr(group_name)
    unsupported_attr(pg_options)


def new_group(ranks: Optional[List[int]] = None,
              timeout: Optional[timedelta] = None,
              backend: Optional[str] = None,
              pg_options: Optional[Any] = None):
    global _pg_map
    if ranks is None:
        return None
    if not isinstance(ranks, list):
        raise TypeError("The dtype of ranks must be `list`, but got `{}`".format(type(ranks)))
    if len(ranks) == get_world_size():
        return None
    for i, rank in enumerate(ranks):
        if not isinstance(rank, int):
            raise TypeError("The dtype of ranks[{}] must be `int`, but got `{}`".format(i, type(rank)))
    rank = get_rank()
    if rank not in ranks:
        return None
    pg = ProcessGroup()
    name = 'group_{}'.format(len(_pg_map))
    create_group(name, ranks)
    _pg_map[pg] = name
    unsupported_attr(timeout)
    unsupported_attr(backend)
    unsupported_attr(pg_options)
    return pg


def get_rank(group: Optional[ProcessGroup] = None):
    group = GlobalComm.WORLD_COMM_GROUP if group is None else _pg_map[group]
    return ms.communication.get_rank(group)


def get_world_size(group: Optional[ProcessGroup] = None):
    group = GlobalComm.WORLD_COMM_GROUP if group is None else _pg_map[group]
    return ms.communication.get_group_size(group)


def destroy_process_group(group: Optional[ProcessGroup] = None):
    global _pg_map
    if group is None:
        del_pg_list = list()
        for pg in _pg_map:
            name = _pg_map[pg]
            if name != GlobalComm.WORLD_COMM_GROUP:
                destroy_group(name)
                del_pg_list.append(pg)
        for pg in del_pg_list:
            del _pg_map[pg]
    else:
        if group in _pg_map:
            name = _pg_map[group]
            if name != GlobalComm.WORLD_COMM_GROUP:
                destroy_group(name)
            del _pg_map[group]


def _get_pg_name(group: Union[ProcessGroup, None]):
    if group is None:
        return GlobalComm.WORLD_COMM_GROUP
    if isinstance(group, ProcessGroup):
        if group in _pg_map:
            return _pg_map[group]
        raise ValueError("The `group` is not existed.")
    raise TypeError('The dtype of `group` must be `ProcessGroup`, but got {}'.format(type(group)))


def is_available():
    return _is_available()


def is_initialized():
    return _is_initialized()


def is_mpi_available():
    return _is_mpi_available()


def is_nccl_available():
    device_target = ms.get_context('device_target')
    if device_target == 'Ascend':
        logger.warning("In Ascend, the result of is_hccl_available() is returned. "
                       "If you do not want to see this log, please use that API.")
        return _is_hccl_available()
    return _is_nccl_available()


def is_hccl_available():
    return _is_hccl_available()


def get_backend():
    return _get_backend()


def get_process_group_ranks(group: Union[ProcessGroup, None]):
    if group is None:
        return _get_group_ranks(GlobalComm.WORLD_COMM_GROUP)
    pg_name = _pg_map[group]
    return _get_group_ranks(pg_name)
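
One behavioural detail of new_group() above that is easy to miss: it returns None both when `ranks` spans the whole world and when the calling rank is not a member, and every helper here treats None as the default world group. A hedged usage sketch, assuming an mpirun launch with 3 processes:

from msadapter.pytorch import distributed as dist

dist.init_process_group(backend='nccl', world_size=3)

pg = dist.new_group([0, 2])                        # rank 1 receives None here
print('rank:', dist.get_rank(pg))                  # None falls back to the world group
print('size:', dist.get_world_size(pg))
print('ranks:', dist.get_process_group_ranks(pg))

dist.destroy_process_group(pg)                     # only named sub-groups are destroyed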

+ 1
- 0
msadapter/pytorch/nn/__init__.py View File

@@ -6,3 +6,4 @@ from .parameter import Parameter, ParameterTuple
from . import init
from . import functional
from . import utils
from . import parallel

+ 5
- 0
msadapter/pytorch/nn/parallel/__init__.py View File

@@ -0,0 +1,5 @@
from .distributed import DistributedDataParallel

__all__ = [
    'DistributedDataParallel'
]

+ 80
- 0
msadapter/pytorch/nn/parallel/distributed.py View File

@@ -0,0 +1,80 @@
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import GlobalComm
from mindspore import ops
import mindspore as ms
from mindspore import log as logger

from msadapter.pytorch import distributed as dist
from msadapter.pytorch.distributed.distributed_c10d import _get_pg_name
from msadapter.pytorch.nn.modules.module import Module
from msadapter.utils import unsupported_attr


class DistributedDataParallel(Module):
    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=25,
        find_unused_parameters=False,
        check_reduction=False,
        gradient_as_bucket_view=False,
        static_graph=False,
    ):
        super(DistributedDataParallel, self).__init__()
        ms.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "size", "config": bucket_cap_mb}})
        if ms.get_context('mode') == ms.PYNATIVE_MODE:
            logger.warning("`bucket_cap_mb` takes effect only in graph mode.")

        self.network = module
        device_num = dist.get_world_size(process_group)
        pg_name = GlobalComm.WORLD_COMM_GROUP if process_group is None else _get_pg_name(process_group)
        self.grad_reducer = DistributedGradReducer(module.trainable_params(), degree=device_num, group=pg_name)

        self.modules_buffers = list()
        self.broadcast_buffers = broadcast_buffers
        if broadcast_buffers:
            for param in module.get_parameters():
                if not param.requires_grad:
                    self.modules_buffers.append(param)
            self.broadcast = ops.Broadcast(0, pg_name)

        unsupported_attr(device_ids)
        unsupported_attr(output_device)
        unsupported_attr(dim)
        unsupported_attr(find_unused_parameters)
        unsupported_attr(check_reduction)
        unsupported_attr(gradient_as_bucket_view)
        unsupported_attr(static_graph)

    def will_sync_module_buffers(self):
        return self.broadcast_buffers and len(self.modules_buffers) > 0

    def _sync_buffers(self):
        for buffer in self.modules_buffers:
            remote_buffer = self.broadcast(buffer)
            buffer.set_data(remote_buffer)

    def forward(self, *inputs, **kwargs):
        if self.will_sync_module_buffers():
            self._sync_buffers()
        self.network(*inputs, **kwargs)

    def all_reduce(self, grads):
        grads = self.grad_reducer(grads)
        return grads

    def gather(self, outputs, output_device):
        # TODO: implemented the method after the operators supported.
        unsupported_attr(outputs)
        unsupported_attr(output_device)

    def scatter(self, inputs, kwargs, device_ids):
        # TODO: implemented the method after the operators supported.
        unsupported_attr(inputs)
        unsupported_attr(kwargs)
        unsupported_attr(device_ids)
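
Unlike torch's DistributedDataParallel, gradient synchronization here is explicit: the wrapper exposes all_reduce(), which runs MindSpore's DistributedGradReducer over the gradient tuple, and buffers are broadcast from rank 0 before forward when broadcast_buffers is True. A training-step sketch condensed from the ddp_impl.py test in this PR; the tiny Dense network and the mpirun launch with 3 processes are assumptions of that test, not requirements of the wrapper:

import numpy as np
import mindspore as ms
from mindspore import nn

from msadapter.pytorch import distributed as dist
from msadapter.pytorch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend='nccl', world_size=3)

net = nn.Dense(3, 3)
opt = nn.Adam(net.trainable_params())
grad_fn = ms.value_and_grad(lambda x: net(x).sum(), None, opt.parameters)

ddp_net = DDP(net)                                  # wraps the module and its grad reducer
inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32))
loss, grads = grad_fn(inputs)
grads = ddp_net.all_reduce(grads)                   # explicit gradient synchronization
opt(grads)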

+ 0
- 0
testing/ut/pytorch/distributed/__init__.py View File


+ 49
- 0
testing/ut/pytorch/distributed/ddp_impl.py View File

@@ -0,0 +1,49 @@
import numpy as np

import mindspore as ms
from mindspore import nn

from msadapter.pytorch.nn.parallel import DistributedDataParallel as DDP
from msadapter.pytorch import distributed as dist


class NetWork(nn.Cell):
    def __init__(self):
        super(NetWork, self).__init__()
        self.dense = nn.Dense(3, 3)

    def construct(self, x):
        return self.dense(x).sum()


def ddp_basic():
    # you can use the following bash command to see the print log
    # mpirun --allow-run-as-root -n 3 python ddp_impl.py
    dist.init_process_group(backend='nccl', rank=0, world_size=3)

    network = NetWork()
    opt = nn.Adam(network.trainable_params())
    grad_fn = ms.value_and_grad(network, None, opt.parameters, has_aux=False)

    rank = dist.get_rank()
    network = DDP(network, device_ids=[rank], find_unused_parameters=False)
    ranks = [0, 2]
    if rank in ranks:
        pg = dist.new_group(ranks)
        network_p = DDP(network, device_ids=[rank], process_group=pg)
    else:
        network_p = None

    inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32))
    for _ in range(1):
        loss, grads = grad_fn(inputs)
        grads = network.all_reduce(grads)
        opt(grads)
        if network_p is not None:
            grads = network_p.all_reduce(grads)
            opt(grads)
        print('rank:', rank, ', loss:', loss)


if __name__ == '__main__':
    ddp_basic()

+ 16
- 0
testing/ut/pytorch/distributed/test_ddp.py View File

@@ -0,0 +1,16 @@
import os
import subprocess

from ...utils import SKIP_ENV_CPU, set_mode_by_env_config
set_mode_by_env_config()


@SKIP_ENV_CPU(reason='`DistributedDataParallel` is not supported on CPU')
def test_ddp_basic():
    cur_dir = os.path.abspath(os.path.dirname(__file__))
    cmd = 'mpirun --allow-run-as-root -n 3 '
    cmd += 'python {}/ddp_impl.py'.format(cur_dir)
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        raise Exception(e.output.decode())

+ 36
- 0
testing/ut/pytorch/nn/dist_impl.py View File

@@ -0,0 +1,36 @@
from msadapter.pytorch.distributed.distributed_c10d import _get_pg_name
from msadapter.pytorch import distributed as dist


def dist_basic():
    # you can use the following bash command to see the print log
    # mpirun --allow-run-as-root -n 3 python dist_impl.py
    dist.init_process_group(backend='nccl', world_size=3)
    pg0 = dist.new_group([0, 1])
    pg1 = dist.new_group([1, 2])
    pg2 = dist.new_group([0, 1, 2])
    print('rank:', dist.get_rank())
    print('world_size:', dist.get_world_size())
    print('pg0:', _get_pg_name(pg0))
    print('pg1:', _get_pg_name(pg1))
    print('pg2:', _get_pg_name(pg2))
    print('pg0 rank:', dist.get_rank(pg0))
    print('pg1 rank:', dist.get_rank(pg1))
    print('pg2 rank:', dist.get_rank(pg2))
    print('is_available:', dist.is_available())
    print('is_initialized:', dist.is_initialized())
    print('is_mpi_available:', dist.is_mpi_available())
    print('is_nccl_available:', dist.is_nccl_available())
    print('is_hccl_available:', dist.is_hccl_available())
    print('get_backend:', dist.get_backend())
    print('pg0 get_process_group_ranks:', dist.get_process_group_ranks(pg0))
    print('pg1 get_process_group_ranks:', dist.get_process_group_ranks(pg1))
    print('pg2 get_process_group_ranks:', dist.get_process_group_ranks(pg2))

    dist.destroy_process_group(pg0)
    dist.destroy_process_group(pg1)
    dist.destroy_process_group(pg2)


if __name__ == '__main__':
    dist_basic()

+ 16
- 0
testing/ut/pytorch/nn/test_dist.py View File

@@ -0,0 +1,16 @@
import os
import subprocess

from ...utils import SKIP_ENV_CPU, set_mode_by_env_config
set_mode_by_env_config()


@SKIP_ENV_CPU(reason='`distributed` is not supported on CPU')
def test_dist_basic():
    cur_dir = os.path.abspath(os.path.dirname(__file__))
    cmd = 'mpirun --allow-run-as-root -n 3 '
    cmd += 'python {}/dist_impl.py'.format(cur_dir)
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        raise Exception(e.output.decode())
