Same as above.
@@ -6,6 +6,7 @@
- [nn.functional](#jump5)
- [torch.linalg](#jump6)
- [torch.utils.data](#jump7)
- [torch.distributed](#jump8)
## <span id="jump1">接口约束列表</span>
@@ -246,6 +247,7 @@
| nn.RNN | 在图模式下,`input`不支持PackedSequence类型 |
| nn.GRU | 在图模式下,`input`不支持PackedSequence类型 |
| nn.CrossEntropyLoss | `target`类型为int64时,有溢出风险 |
| nn.parallel.DistributedDataParallel | 不支持`device_ids`, `output_device`, `dim`, `find_unused_parameters`, `check_reduction`, `gradient_as_bucket_view`, `static_graph` |
### <span id="jump5">nn.functional</span>
| MSAdapter接口 | 约束条件 |
@@ -310,3 +312,9 @@
| RandomSampler | 暂不支持传入Generator |
| SubsetRandomSampler | 暂不支持传入Generator |
| WeightedRandomSampler | 暂不支持传入Generator |
### <span id="jump8">torch.distributed</span>
| MSAdapter接口 | 约束条件 |
| --------------- | ----------- |
| init_process_group | 不支持`timeout`, `rank`, `store`, `group_name`, `pg_options`;部分支持`init_method`:以环境变量模式配置初始化 |
| new_group | 不支持`timeout`, `pg_options`;`backend`部分支持(nccl) |
@@ -7,6 +7,7 @@ English | [简体中文](ConstraintList.md)
- [nn.functional](#jump5)
- [torch.linalg](#jump6)
- [torch.utils.data](#jump7)
- [torch.distributed](#jump8)
## <span id="jump1">API Constraints List</span>
@@ -247,6 +248,7 @@ English | [简体中文](ConstraintList.md)
| nn.RNN | Under GRAPH mode, `input` does not support the PackedSequence type |
| nn.GRU | Under GRAPH mode, `input` does not support the PackedSequence type |
| nn.CrossEntropyLoss | There is a risk of overflow when the `target` type is int64 |
| nn.parallel.DistributedDataParallel | `device_ids`, `output_device`, `dim`, `find_unused_parameters`, `check_reduction`, `gradient_as_bucket_view`, `static_graph` are not supported |
### <span id="jump5">nn.functional</span>
| MSAdapter APIs | Constraint conditions |
@@ -311,3 +313,9 @@ English | [简体中文](ConstraintList.md)
| RandomSampler | Passing a Generator is currently not supported |
| SubsetRandomSampler | Passing a Generator is currently not supported |
| WeightedRandomSampler | Passing a Generator is currently not supported |
### <span id="jump8">torch.distributed</span>
| MSAdapter APIs | Constraint conditions |
| --------------- | ----------- |
| init_process_group | `timeout`, `rank`, `store`, `group_name`, `pg_options` are not supported; `init_method` is partly supported: initialization can be configured only via environment variables |
| new_group | `timeout`, `pg_options` are not supported; `backend` is partly supported (nccl only) |
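Since only environment-variable style initialization is honored, a minimal usage sketch could look like the following (the script name is hypothetical; it is based on the `msadapter.pytorch.distributed` implementation added later in this PR and must be launched under `mpirun` so MindSpore can derive the rank and device count):

```python
# Minimal sketch: launch with `mpirun --allow-run-as-root -n 3 python init_demo.py`.
# rank/store/timeout-style arguments are ignored by MSAdapter; only the
# environment-variable form of initialization takes effect.
from msadapter.pytorch import distributed as dist

dist.init_process_group(backend='nccl', init_method='env://', world_size=3)
print('rank:', dist.get_rank(), 'world_size:', dist.get_world_size())
dist.destroy_process_group()
```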
Erpim commented 9 months ago:

Same as above.
@@ -9,6 +9,7 @@
- [torch.linalg](#jump6)
- [torch.optim](#jump7)
- [torch.utils.data](#jump9)
- [torch.distributed](#jump10)
### <span id="jump8">通用限制</span>
@@ -1022,6 +1023,7 @@
| nn.init.eye_ | 部分支持 | 暂不支持图模式 |
| nn.init.dirac_ | 部分支持 | 暂不支持图模式 |
| nn.init.orthogonal_ | 部分支持 | 暂不支持图模式 |
| nn.parallel.DistributedDataParallel | 部分支持 | [功能存在限制](ConstraintList.md) |
### <span id="jump5">nn.functional</span>
| MSAdapter接口 | 状态 | 约束 |
@@ -1251,4 +1253,21 @@
| BatchSampler | 支持 | |
| distributed.DistributedSampler | 支持 | |
### <span id="jump10">torch.distributed</span> | |||
<span id="jump10">分布式统一约束:</span> | |||
- 在Ascend后端上,由于设备差异,NCCL相关接口默认会被替换为HCCL相关接口。 | |||
zoulq commented 9 months ago:

The usage examples for these interfaces are not yet complete. Add a sentence here: "The torch.distributed interfaces are experimental APIs that may be modified or removed later. For migrating distributed training functionality, please refer to the examples in the user guide."

Then link to https://openi.pcl.ac.cn/OpenI/MSAdapter/src/branch/master/USER_GUIDE.md#user-content-3-3-%E4%BD%BF%E7%94%A8%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83%E5%8A%A0%E9%80%9F%E8%AE%AD%E7%BB%83
- torch.distributed相关接口为实验性API,后续可能修改或删除。分布式训练功能迁移请参考[用户手册样例描述](USER_GUIDE.md)。
| MSAdapter接口 | 状态 | 约束 |
| --------------- | ---- | ----------- |
| init_process_group | 部分支持 | [功能存在限制](ConstraintList.md) |
| get_rank | 支持 | |
| new_group | 部分支持 | [功能存在限制](ConstraintList.md) |
| get_world_size | 支持 | |
| destroy_process_group | 支持 | |
| is_available | 支持 | |
| is_initialized | 支持 | |
| is_mpi_available | 支持 | |
| is_nccl_available | 支持 | |
| get_backend | 支持 | |
| get_process_group_ranks | 支持 | |
@@ -8,6 +8,7 @@ English | [简体中文](SupportedList.md)
- [torch.linalg](#jump6)
- [torch.optim](#jump7)
- [torch.utils.data](#jump9)
- [torch.distributed](#jump10)
### <span id="jump8">General Constraint</span>
- Configuring `layout`, `device`, `requires_grad`, and `memory_format` is not supported.
@@ -1020,6 +1021,7 @@ English | [简体中文](SupportedList.md)
| nn.init.eye_ | Partly supported | Currently not supported in GRAPH mode |
| nn.init.dirac_ | Partly supported | Currently not supported in GRAPH mode |
| nn.init.orthogonal_ | Partly supported | Currently not supported in GRAPH mode |
| nn.parallel.DistributedDataParallel | Partly supported | [Function is constrained](ConstraintList_en.md) |
### <span id="jump5">nn.functional</span> | |||
| MSAdapter APIs | Status | Restrictions | | |||
@@ -1252,3 +1254,22 @@ English | [简体中文](SupportedList.md) | |||
| WeightedRandomSampler | Supported |[Function is constrained](ConstraintList_en.md)| | |||
| BatchSampler | Supported | | | |||
| distributed.DistributedSampler | Supported | | | |||
### <span id="jump10">torch.distributed</span> | |||
<span id="jump11">distributed General Constraints:</span> | |||
- In Ascend backend, NCCL related interfaces are replaced by HCCL related interfaces by default due to device differences. | |||
- The torch.distributed related interface is an experimental API that may be modified or deleted in the future. Please refer to [User Manual Sample Description](USER_GUIDE.md) for the migration of distributed training functions. | |||
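For example, with the `is_nccl_available` implementation included in this PR, an NCCL availability query on Ascend is answered by HCCL (a minimal sketch; the warning text comes from the implementation below):

```python
import mindspore as ms
from msadapter.pytorch import distributed as dist

ms.set_context(device_target='Ascend')
# On Ascend, the NCCL query falls back to HCCL and a warning is logged.
print(dist.is_nccl_available())   # same value as dist.is_hccl_available()
print(dist.is_hccl_available())
```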
| MSAdapter APIs | Status | Restrictions |
| --------------- | ---- | ----------- |
| init_process_group | Partly supported | [Function is constrained](ConstraintList_en.md) |
| get_rank | Supported | |
| new_group | Partly supported | [Function is constrained](ConstraintList_en.md) |
| get_world_size | Supported | |
| destroy_process_group | Supported | |
| is_available | Supported | |
| is_initialized | Supported | |
| is_mpi_available | Supported | |
| is_nccl_available | Supported | |
| get_backend | Supported | |
| get_process_group_ranks | Supported | |
@@ -0,0 +1,18 @@
from .distributed_c10d import (init_process_group, get_rank, new_group, get_world_size, ProcessGroup,
                               destroy_process_group, is_available, is_initialized, is_mpi_available,
                               is_nccl_available, is_hccl_available, get_backend, get_process_group_ranks)

__all__ = [
    'init_process_group',
    'get_rank',
    'new_group',
    'get_world_size',
    'ProcessGroup',
    'destroy_process_group',
    'is_available',
    'is_initialized',
    'is_mpi_available',
    'is_nccl_available',
    'is_hccl_available',
    'get_backend',
    'get_process_group_ranks',
]
@@ -0,0 +1,334 @@
from datetime import timedelta
from typing import Any, Dict, Optional, Union, overload, List

from mindspore.context import ParallelMode
from mindspore.communication.management import init, create_group, GlobalComm, destroy_group
from mindspore.communication._comm_helper import (_is_available, _is_initialized, _get_backend, _is_hccl_available,
                                                  _is_nccl_available, _is_mpi_available, _get_group_ranks)
from mindspore import Tensor
import mindspore as ms
from mindspore import log as logger
from msadapter.utils import unsupported_attr

BACKEND_DEVICE_TARGET_DICT = {
    'nccl': 'GPU',
    'hccl': 'Ascend',
}
class ReduceOp:
    SUM = ...
    PRODUCT = ...
    MIN = ...
    MAX = ...
    BAND = ...
    BOR = ...
    BXOR = ...
    PREMUL_SUM = ...
    UNUSED = ...
class ProcessGroup:
    # TODO: implement the following methods once the underlying operators are supported.
    @overload
    def broadcast(
        self,
        tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def broadcast(
        self,
        tensor: Tensor,
        root: int,
    ) -> None: ...

    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def allreduce(
        self,
        tensors: List[Tensor],
        op=ReduceOp.SUM,
    ) -> None: ...

    @overload
    def allreduce(
        self,
        tensor: Tensor,
        op=ReduceOp.SUM,
    ) -> None: ...

    @overload
    def reduce(
        self,
        tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def reduce(
        self,
        tensor: Tensor,
        root: int,
        op=ReduceOp.SUM,
    ) -> None: ...

    @overload
    def allgather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def allgather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
    ) -> None: ...

    @overload
    def gather(
        self,
        output_tensors: List[List[Tensor]],
        input_tensors: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def gather(
        self,
        output_tensors: List[Tensor],
        input_tensor: Tensor,
        root: int,
    ) -> None: ...

    @overload
    def scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=None,
    ) -> None: ...

    @overload
    def scatter(
        self,
        output_tensor: Tensor,
        input_tensors: List[Tensor],
        root: int,
    ) -> None: ...

    @overload
    def reduce_scatter(
        self,
        output_tensors: List[Tensor],
        input_tensors: List[List[Tensor]],
        opts=None,
    ) -> None: ...

    @overload
    def reduce_scatter(
        self,
        output_tensors: Tensor,
        input_tensor: List[Tensor],
    ) -> None: ...

    @overload
    def alltoall_base(
        self,
        output_tensor: Tensor,
        input_tensor: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts=None,
    ) -> None: ...

    @overload
    def alltoall_base(
        self,
        output: Tensor,
        input: Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
    ) -> None: ...

    @overload
    def alltoall(
        self,
        output_tensor: List[Tensor],
        input_tensor: List[Tensor],
        opts=None,
    ) -> None: ...

    @overload
    def alltoall(
        self,
        output: List[Tensor],
        input: List[Tensor],
    ) -> None: ...

    @overload
    def send(
        self,
        tensors: List[Tensor],
        dstRank: int,
        tag: int,
    ) -> None: ...

    @overload
    def recv(
        self,
        tensors: List[Tensor],
        srcRank: int,
        tag: int,
    ) -> None: ...
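# Maps each ProcessGroup handle to the name of its underlying MindSpore communication group.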
_pg_map: Dict[ProcessGroup, str] = {}
def init_process_group(
    backend: str,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Any] = None,
    group_name: str = "",
    pg_options: Optional[Any] = None,
):
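    # Validate that the requested backend matches MindSpore's device_target, initialize the
    # communication service, and switch MindSpore into data-parallel mode.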
    global _pg_map
    if backend not in BACKEND_DEVICE_TARGET_DICT:
        raise ValueError('{} is not supported.'.format(backend))
    device_target = ms.get_context('device_target')
    backend_device_target = BACKEND_DEVICE_TARGET_DICT[backend]
    if device_target != backend_device_target:
        raise ValueError('If backend is {}, the device_target must be {}.'.format(backend, backend_device_target))
    init()
    device_num = get_world_size()
    # world_size defaults to -1, meaning "derive it from the communication runtime".
    if world_size != -1 and world_size != device_num:
        raise ValueError('The world_size:{} is not equal to the device_num:{}.'.format(world_size, device_num))
    ms.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num,
                                 parameter_broadcast=True)
    pg = ProcessGroup()
    _pg_map[pg] = GlobalComm.WORLD_COMM_GROUP
    unsupported_attr(init_method)
    unsupported_attr(timeout)
    unsupported_attr(rank)
    unsupported_attr(store)
    unsupported_attr(group_name)
    unsupported_attr(pg_options)
def new_group(ranks: Optional[List[int]] = None,
              timeout: Optional[timedelta] = None,
              backend: Optional[str] = None,
              pg_options: Optional[Any] = None):
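    # Mirror torch semantics: ranks that are not members of the new group receive None,
    # and a group spanning every rank falls back to the default world group (also None here).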
    global _pg_map
    if ranks is None:
        return None
    if not isinstance(ranks, list):
        raise TypeError("The dtype of ranks must be `list`, but got `{}`".format(type(ranks)))
    if len(ranks) == get_world_size():
        return None
    for i, rank in enumerate(ranks):
        if not isinstance(rank, int):
            raise TypeError("The dtype of ranks[{}] must be `int`, but got `{}`".format(i, type(rank)))
    rank = get_rank()
    if rank not in ranks:
        return None
    pg = ProcessGroup()
    name = 'group_{}'.format(len(_pg_map))
    create_group(name, ranks)
    _pg_map[pg] = name
    unsupported_attr(timeout)
    unsupported_attr(backend)
    unsupported_attr(pg_options)
    return pg
def get_rank(group: Optional[ProcessGroup] = None):
    group = GlobalComm.WORLD_COMM_GROUP if group is None else _pg_map[group]
    return ms.communication.get_rank(group)


def get_world_size(group: Optional[ProcessGroup] = None):
    group = GlobalComm.WORLD_COMM_GROUP if group is None else _pg_map[group]
    return ms.communication.get_group_size(group)


def destroy_process_group(group: Optional[ProcessGroup] = None):
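    # When no group is given, tear down every user-created group but keep the
    # default world communication group alive.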
    global _pg_map
    if group is None:
        del_pg_list = list()
        for pg in _pg_map:
            name = _pg_map[pg]
            if name != GlobalComm.WORLD_COMM_GROUP:
                destroy_group(name)
                del_pg_list.append(pg)
        for pg in del_pg_list:
            del _pg_map[pg]
    else:
        if group in _pg_map:
            name = _pg_map[group]
            if name != GlobalComm.WORLD_COMM_GROUP:
                destroy_group(name)
            del _pg_map[group]
def _get_pg_name(group: Union[ProcessGroup, None]):
    if group is None:
        return GlobalComm.WORLD_COMM_GROUP
    if isinstance(group, ProcessGroup):
        if group in _pg_map:
            return _pg_map[group]
        raise ValueError("The `group` does not exist.")
    raise TypeError('The dtype of `group` must be `ProcessGroup`, but got {}'.format(type(group)))
def is_available():
    return _is_available()


def is_initialized():
    return _is_initialized()


def is_mpi_available():
    return _is_mpi_available()


def is_nccl_available():
    device_target = ms.get_context('device_target')
    if device_target == 'Ascend':
        logger.warning("On Ascend, the result of is_hccl_available() is returned. "
                       "If you do not want to see this log, please call that API directly.")
        return _is_hccl_available()
    return _is_nccl_available()


def is_hccl_available():
    return _is_hccl_available()


def get_backend():
    return _get_backend()
def get_process_group_ranks(group: Union[ProcessGroup, None]):
    # _get_pg_name validates the handle and falls back to the world group for None.
    return _get_group_ranks(_get_pg_name(group))
@@ -6,3 +6,4 @@ from .parameter import Parameter, ParameterTuple
from . import init
from . import functional
from . import utils
from . import parallel
@@ -0,0 +1,5 @@
from .distributed import DistributedDataParallel

__all__ = [
    'DistributedDataParallel'
]
@@ -0,0 +1,80 @@
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import GlobalComm
from mindspore import ops
import mindspore as ms
from mindspore import log as logger
from msadapter.pytorch import distributed as dist
from msadapter.pytorch.distributed.distributed_c10d import _get_pg_name
from msadapter.pytorch.nn.modules.module import Module
from msadapter.utils import unsupported_attr
class DistributedDataParallel(Module):
    def __init__(
        self,
        module,
        device_ids=None,
        output_device=None,
        dim=0,
        broadcast_buffers=True,
        process_group=None,
        bucket_cap_mb=25,
        find_unused_parameters=False,
        check_reduction=False,
        gradient_as_bucket_view=False,
        static_graph=False,
    ):
        super(DistributedDataParallel, self).__init__()
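        # bucket_cap_mb is mapped onto MindSpore's size-based allreduce fusion;
        # the setting only takes effect when running in graph mode.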
        ms.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "size", "config": bucket_cap_mb}})
        if ms.get_context('mode') == ms.PYNATIVE_MODE:
            logger.warning("`bucket_cap_mb` takes effect only in graph mode.")
        self.network = module
        device_num = dist.get_world_size(process_group)
        pg_name = GlobalComm.WORLD_COMM_GROUP if process_group is None else _get_pg_name(process_group)
        self.grad_reducer = DistributedGradReducer(module.trainable_params(), degree=device_num, group=pg_name)
        self.modules_buffers = list()
        self.broadcast_buffers = broadcast_buffers
        if broadcast_buffers:
            for param in module.get_parameters():
                if not param.requires_grad:
                    self.modules_buffers.append(param)
            self.broadcast = ops.Broadcast(0, pg_name)
        unsupported_attr(device_ids)
        unsupported_attr(output_device)
        unsupported_attr(dim)
        unsupported_attr(find_unused_parameters)
        unsupported_attr(check_reduction)
        unsupported_attr(gradient_as_bucket_view)
        unsupported_attr(static_graph)
    def will_sync_module_buffers(self):
        return self.broadcast_buffers and len(self.modules_buffers) > 0

    def _sync_buffers(self):
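        # Broadcast buffers (non-trainable parameters) from rank 0 so every replica
        # holds identical values, e.g. BatchNorm running statistics.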
        for buffer in self.modules_buffers:
            # ops.Broadcast consumes and returns a tuple of tensors.
            remote_buffer = self.broadcast((buffer,))[0]
            buffer.set_data(remote_buffer)
    def forward(self, *inputs, **kwargs):
        if self.will_sync_module_buffers():
            self._sync_buffers()
        return self.network(*inputs, **kwargs)
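    # Average gradients across the process group; call this on the gradients returned
    # by ms.value_and_grad before feeding them to the optimizer (see ddp_impl.py below).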
    def all_reduce(self, grads):
        grads = self.grad_reducer(grads)
        return grads
    def gather(self, outputs, output_device):
        # TODO: implement this method once the underlying operators are supported.
        unsupported_attr(outputs)
        unsupported_attr(output_device)

    def scatter(self, inputs, kwargs, device_ids):
        # TODO: implement this method once the underlying operators are supported.
        unsupported_attr(inputs)
        unsupported_attr(kwargs)
        unsupported_attr(device_ids)
@@ -0,0 +1,49 @@
import numpy as np
import mindspore as ms
from mindspore import nn
from msadapter.pytorch.nn.parallel import DistributedDataParallel as DDP
from msadapter.pytorch import distributed as dist


class NetWork(nn.Cell):
    def __init__(self):
        super(NetWork, self).__init__()
        self.dense = nn.Dense(3, 3)

    def construct(self, x):
        return self.dense(x).sum()


def ddp_basic():
    # You can use the following bash command to see the printed log:
    # mpirun --allow-run-as-root -n 3 python ddp_impl.py
    dist.init_process_group(backend='nccl', rank=0, world_size=3)
    network = NetWork()
    opt = nn.Adam(network.trainable_params())
    grad_fn = ms.value_and_grad(network, None, opt.parameters, has_aux=False)
    rank = dist.get_rank()
    network = DDP(network, device_ids=[rank], find_unused_parameters=False)
    ranks = [0, 2]
    if rank in ranks:
        pg = dist.new_group(ranks)
        network_p = DDP(network, device_ids=[rank], process_group=pg)
    else:
        network_p = None
    inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32))
    for _ in range(1):
        loss, grads = grad_fn(inputs)
        grads = network.all_reduce(grads)
        opt(grads)
        if network_p is not None:
            grads = network_p.all_reduce(grads)
            opt(grads)
        print('rank:', rank, ', loss:', loss)


if __name__ == '__main__':
    ddp_basic()
@@ -0,0 +1,16 @@
import os
import subprocess

from ...utils import SKIP_ENV_CPU, set_mode_by_env_config

set_mode_by_env_config()


@SKIP_ENV_CPU(reason='`DistributedDataParallel` is not supported on CPU')
def test_ddp_basic():
    cur_dir = os.path.abspath(os.path.dirname(__file__))
    cmd = 'mpirun --allow-run-as-root -n 3 '
    cmd += 'python {}/ddp_impl.py'.format(cur_dir)
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        raise Exception(e.output.decode())
@@ -0,0 +1,36 @@
from msadapter.pytorch.distributed.distributed_c10d import _get_pg_name
from msadapter.pytorch import distributed as dist


def dist_basic():
    # You can use the following bash command to see the printed log:
    # mpirun --allow-run-as-root -n 3 python dist_impl.py
    dist.init_process_group(backend='nccl', world_size=3)
    pg0 = dist.new_group([0, 1])
    pg1 = dist.new_group([1, 2])
    pg2 = dist.new_group([0, 1, 2])
    print('rank:', dist.get_rank())
    print('world_size:', dist.get_world_size())
    print('pg0:', _get_pg_name(pg0))
    print('pg1:', _get_pg_name(pg1))
    print('pg2:', _get_pg_name(pg2))
    print('pg0 rank:', dist.get_rank(pg0))
    print('pg1 rank:', dist.get_rank(pg1))
    print('pg2 rank:', dist.get_rank(pg2))
    print('is_available:', dist.is_available())
    print('is_initialized:', dist.is_initialized())
    print('is_mpi_available:', dist.is_mpi_available())
    print('is_nccl_available:', dist.is_nccl_available())
    print('is_hccl_available:', dist.is_hccl_available())
    print('get_backend:', dist.get_backend())
    print('pg0 get_process_group_ranks:', dist.get_process_group_ranks(pg0))
    print('pg1 get_process_group_ranks:', dist.get_process_group_ranks(pg1))
    print('pg2 get_process_group_ranks:', dist.get_process_group_ranks(pg2))
    dist.destroy_process_group(pg0)
    dist.destroy_process_group(pg1)
    dist.destroy_process_group(pg2)


if __name__ == '__main__':
    dist_basic()
@@ -0,0 +1,16 @@
import os
import subprocess

from ...utils import SKIP_ENV_CPU, set_mode_by_env_config

set_mode_by_env_config()


@SKIP_ENV_CPU(reason='`distributed` is not supported on CPU')
def test_dist_basic():
    cur_dir = os.path.abspath(os.path.dirname(__file__))
    cmd = 'mpirun --allow-run-as-root -n 3 '
    cmd += 'python {}/dist_impl.py'.format(cur_dir)
    try:
        subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as e:
        raise Exception(e.output.decode())
Isn't PyTorch itself unaware of HCCL?

The strategy could be described at the very top of the torch.distributed section, e.g. that on Ascend, nccl falls back to hccl under the hood by default, or it could be written into the general constraints.