@@ -1,4 +1,5 @@
import contextlib
import copy
from mindspore.communication.management import init, create_group, GlobalComm, destroy_group
from mindspore.communication._comm_helper import (_is_available, _is_initialized, _is_hccl_available,
@@ -6,6 +7,21 @@ from mindspore.communication._comm_helper import (_is_available, _is_initialized
import mindspore as ms
from mindspore.ops._primitive_cache import _get_cache_prim
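# Functional collectives from mindspore.communication.comm_func replace the cached
# primitive ops previously built with _get_cache_prim; the collectives that accept
# async_op return the result tensor together with a communication handle.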
from mindspore.communication.comm_func import all_to_all_with_output_shape as _all_to_all_ms, \
all_to_all_single_with_output_shape as _all_to_all_single_ms, \
P2POp as _P2POp_ms, \
batch_isend_irecv as _batch_isend_irecv_ms, \
isend as _isend_ms, \
irecv as _irecv_ms, \
barrier as _barrier_ms, \
broadcast as _broadcast_ms, \
all_reduce as _all_reduce_ms, \
reduce as _reduce_ms, \
reduce_scatter_tensor as _reduce_scatter_ms, \
all_gather_into_tensor as _all_gather_into_tensor_ms, \
send as send_ms, \
recv as recv_ms
from mindtorch.torch.common.dtype import int8, int32, float16, float32, bfloat16, all_complex_type
from mindtorch.utils import unsupported_attr, graph_mode_condition, is_under_ascend_context, \
@@ -33,7 +49,9 @@ from mindtorch.torch.distributed._distributed_c10d import ( # pylint: disable=W0
_get_str_from_reduceop,
_pg_map,
_pg_names,
_group_count
_group_count,
_comm_inplace_assign,
_comm_process_handle
)
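# As used throughout this module, _comm_inplace_assign writes a communication result back
# into the caller-provided tensor(s), and _comm_process_handle maps a MindSpore handle to
# the value returned to the caller (a work handle when async_op is set, None otherwise).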
_ascend_support_dtype = (int8, int32, float16, float32, bfloat16)
@@ -510,18 +528,10 @@ def all_reduce_not_inplace(tensor, op=ReduceOp.SUM, group=None, async_op=False):
op = _get_str_from_reduceop(op)
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op)
else:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
result = _reduce_op(tensor)
    result, _ = _all_reduce_ms(cast_to_ms_tensor(tensor), op, _group_name)
return cast_to_adapter_tensor(result)
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
warning("all_reduce: 'async_op' not actually supported now. Run as sync op")
_inplace_raise_error_graph_mode('all_reduce', 'all_reduce_not_inplace')
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
@@ -534,26 +544,17 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
op = _get_str_from_reduceop(op)
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op)
else:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
        result = _reduce_op(cast_tensor)
result, handle = _all_reduce_ms(cast_to_ms_tensor(cast_tensor), op, _group_name, async_op)
result = _recorver_dtype_on_ascend(result, _origin_dtype)
else:
result = _reduce_op(tensor)
tensor.data = result
if async_op:
return _stub_work
else:
return None
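    # comm_func all_reduce returns (reduced_tensor, handle); copy the result back into
    # `tensor` and let _comm_process_handle decide what to hand back to the caller.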
result, handle = _all_reduce_ms(cast_to_ms_tensor(tensor), op, _group_name, async_op)
_comm_inplace_assign(tensor, result)
def broadcast(tensor, src, group=None, async_op=False):
if async_op:
warning("broadcast: 'async_op' not actually supported now. Run as sync op")
return _comm_process_handle(handle, async_op)
def broadcast(tensor, src, group=None, async_op=False):
_inplace_raise_error_graph_mode('broadcast', 'broadcast_not_inplace')
_check_single_tensor(tensor, "tensor")
@@ -567,48 +568,21 @@ def broadcast(tensor, src, group=None, async_op=False):
group = _get_default_group()
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src)
else:
src = get_group_rank(group, src)
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src, _group_name)
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
result = _bc_op((cast_tensor,))[0]
result = _broadcast_ms(cast_to_ms_tensor(cast_tensor), src, _group_name)
result = _recorver_dtype_on_ascend(result, _origin_dtype)
else:
result = _bc_op((tensor,))[0]
tensor.data = result
result = _broadcast_ms(cast_to_ms_tensor(tensor), src, _group_name)
_comm_inplace_assign(tensor, result)
if async_op:
return _stub_work
else:
return None
def all_gather_into_tensor_not_inplace(input_tensor, group=None, async_op=False):
if async_op:
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op")
if _rank_not_in_group(group):
        # Graph mode does not support the code below.
# _warn_not_in_group("broadcast")
return None
if group is None:
group = _get_default_group()
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_ag_op = _get_cache_prim(ms.ops.AllGather)()
else:
_ag_op = _get_cache_prim(ms.ops.AllGather)(_group_name)
return cast_to_adapter_tensor(_ag_op(input_tensor))
def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
_inplace_raise_error_graph_mode('all_gather_into_tensor', 'all_gather_into_tensor_not_inplace')
if async_op:
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op")
_check_single_tensor(input_tensor, "input_tensor")
_check_single_tensor(output_tensor, "output_tensor")
@@ -621,12 +595,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
group = _get_default_group()
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_ag_op = _get_cache_prim(ms.ops.AllGather)()
else:
_ag_op = _get_cache_prim(ms.ops.AllGather)(_group_name)
    result = _ag_op(input_tensor)
result, handle = _all_gather_into_tensor_ms(cast_to_ms_tensor(input_tensor), _group_name, async_op)
_out_tensor_prim_size = output_tensor.shape[0]
_result_prim_size = result.shape[0]
@@ -639,25 +609,12 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
split_size = _result_prim_size // _out_tensor_prim_size
result = ms.ops.split(result, split_size)
result = ms.ops.stack(result)
output_tensor.data = result
if async_op:
return _stub_work
else:
return None
_comm_inplace_assign(output_tensor, result)
return _comm_process_handle(handle, async_op)
def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
return all_gather_into_tensor(output_tensor, input_tensor, group, async_op)
def all_gather_not_inplace(tensor, group=None, async_op=False):
if _rank_not_in_group(group):
        # Graph mode does not support the code below.
# _warn_not_in_group("all_gather")
return None
result = all_gather_into_tensor_not_inplace(tensor, group, async_op)
group_size = _get_group_size(group)
_split_op = _get_cache_prim(ms.ops.Split)(0, group_size)
return cast_to_adapter_tensor(_split_op(result))
def all_gather(tensor_list, tensor, group=None, async_op=False):
_check_tensor_list(tensor_list, "tensor_list")
_check_single_tensor(tensor, "tensor")
@@ -666,13 +623,17 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
        # Graph mode does not support the code below.
# _warn_not_in_group("all_gather")
return None
result = all_gather_not_inplace(tensor, group, async_op)
for i, _tensor in enumerate(tensor_list):
_tensor.data = result[i]
if async_op:
return _stub_work
else:
return None
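    # Gather into one stacked tensor, split it along dim 0 into per-rank chunks,
    # and copy each chunk into the corresponding entry of tensor_list in place.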
if group is None:
group = _get_default_group()
_group_name = _get_pg_name(group)
result, handle = _all_gather_into_tensor_ms(cast_to_ms_tensor(tensor), _group_name, async_op)
group_size = _get_group_size(group)
_split_op = _get_cache_prim(ms.ops.Split)(0, group_size)
result = _split_op(result)
_comm_inplace_assign(tensor_list, result)
return _comm_process_handle(handle, async_op)
def barrier(group=None, async_op=False, device_ids=None):
if async_op:
@@ -690,50 +651,31 @@ def barrier(group=None, async_op=False, device_ids=None):
group = _get_default_group()
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)()
else:
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)(_group_name)
_barrier_op()
_barrier_ms(_group_name)
if async_op:
return _stub_work
else:
return None
def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
if async_op:
warning("all_to_all: 'async_op' not actually supported now. Run as sync op")
if _rank_not_in_group(group):
        # Graph mode does not support the code below.
# _warn_not_in_group("all_to_all_single ")
# _warn_not_in_group("all_to_all")
return None
if group is None:
group = _get_default_group()
_split_dim = 0
_concat_dim = 0
_split_count = len(input_tensor_list)
_group_name = _get_pg_name(group)
input_tensor_list = cast_to_ms_tensor(input_tensor_list)
_input_tensor = ms.ops.stack(input_tensor_list)
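    # comm_func all_to_all consumes and produces lists of tensors directly;
    # the exchanged tensors are copied back into output_tensor_list in place.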
result, handle = _all_to_all_ms(cast_to_ms_tensor(output_tensor_list),
cast_to_ms_tensor(input_tensor_list),
_group_name,
async_op)
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim)
else:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim, _group_name)
_comm_inplace_assign(output_tensor_list, result)
out = _op(_input_tensor)
_spilit_size = out.shape[0] // _split_count
out = ms.ops.split(out, _spilit_size)
for i, output in enumerate(output_tensor_list):
output.data = out[i]
if async_op:
return _stub_work
else:
return None
return _comm_process_handle(handle, async_op)
def all_to_all_single(
output,
@@ -743,9 +685,6 @@ def all_to_all_single(
group=None,
async_op=False,
):
if async_op:
warning("all_to_all_single: 'async_op' not actually supported now. Run as sync op")
if output_split_sizes is not None or input_split_sizes is not None:
raise NotImplementedError("all_to_all_single not support output_split_sizes and input_split_sizes now.")
@@ -754,26 +693,17 @@ def all_to_all_single(
# _warn_not_in_group("all_to_all_single")
return None
_split_count = input.shape[0]
_split_dim = 0
_concat_dim = 0
if group is None:
group = _get_default_group()
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim)
else:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim, _group_name)
input_ms = cast_to_ms_tensor(input)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None
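    # The split sizes are forwarded to the functional op and the exchanged
    # data is copied into `output` in place.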
result, handle = _all_to_all_single_ms(cast_to_ms_tensor(output),
cast_to_ms_tensor(input),
output_split_sizes,
input_split_sizes,
_group_name,
async_op)
_comm_inplace_assign(output, result)
return _comm_process_handle(handle, async_op)
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
@@ -792,16 +722,9 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
op = _get_str_from_reduceop(op)
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_reduce_op = _get_cache_prim(ms.ops.operations._inner_ops.Reduce)(dst, op)
else:
# translate to world rank to group rank.
group_dst_rank = get_group_rank(group, dst)
_reduce_op = _get_cache_prim(ms.ops.operations._inner_ops.Reduce)(group_dst_rank, op, _group_name)
out = _reduce_op(tensor)
out = _reduce_ms(cast_to_ms_tensor(tensor), dst, op, _group_name)
if dst == get_rank():
tensor.data = out
_comm_inplace_assign(tensor, out)
if async_op:
return _stub_work
else:
@@ -824,16 +747,9 @@ def send(tensor, dst, group=None, tag=0):
if group is None:
group = _get_default_group()
    # `dst` is a global rank, but MindSpore expects the rank within the group.
dst = get_group_rank(group, dst)
# TODO: ms.ops.operations._inner_ops.Send(group='hccl_world_group'), default is under Ascend.
# So have to pass '_group_name' as arg to support both GPU and Ascend.
_group_name = _get_pg_name(group)
_send_op = _get_cache_prim(ms.ops.operations._inner_ops.Send)(tag, dst, _group_name, _group_name)
out = _send_op(tensor)
# Additionly return 'out' to ensure bprop is correctly running.
return out
send_ms(tensor, dst, _group_name, tag)
return None
def recv(tensor, src=None, group=None, tag=0):
_inplace_raise_error_graph_mode('recv', 'recv_not_inplace')
@@ -851,20 +767,10 @@ def recv(tensor, src=None, group=None, tag=0):
if group is None:
group = _get_default_group()
    # `src` is a global rank, but MindSpore expects the rank within the group.
src = get_group_rank(group, src)
# TODO: ms.ops.operations._inner_ops.Receive(group='hccl_world_group'), default is under Ascend.
# So have to pass '_group_name' as arg to support both GPU and Ascend.
_group_name = _get_pg_name(group)
_recv_op = _get_cache_prim(ms.ops.operations._inner_ops.Receive)(tag,
src,
list(tensor.shape),
tensor.dtype,
_group_name,
_group_name)
# Additionly pass 'tensor' to _recv_op to ensure bprop is correctly running.
    out = _recv_op(tensor)
tensor.data = out
out = recv_ms(tensor, src, _group_name, tag)
_comm_inplace_assign(tensor, out)
return src
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
@@ -882,20 +788,15 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
if group is None:
group = _get_default_group()
op = _get_str_from_reduceop(op)
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op)
else:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name)
input_list = cast_to_ms_tensor(input_list)
input_ms = ms.ops.concat(input_list)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None
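    # Reduce-scatter the concatenated inputs in one call and copy the local
    # chunk back into `output` in place.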
out, handle = _reduce_scatter_ms(input_ms, op, _group_name, async_op)
_comm_inplace_assign(output, out)
return _comm_process_handle(handle, async_op)
def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
@@ -909,19 +810,14 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
# _warn_not_in_group("reduce_scatter_tensor")
return None
op = _get_str_from_reduceop(op)
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op)
else:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name)
input_ms = cast_to_ms_tensor(input)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None
out, handle = _reduce_scatter_ms(input_ms, op, _group_name, async_op)
_comm_inplace_assign(output, out)
return _comm_process_handle(handle, async_op)
def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
return reduce_scatter_tensor(output, input, op, group, async_op)
@@ -949,18 +845,17 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
        _op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveGather)(dst)
_op = _get_cache_prim(ms.ops.CollectiveGather)(dst)
else:
group_dst_rank = get_group_rank(group, dst)
        _op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveGather)(group_dst_rank, _group_name)
_op = _get_cache_prim(ms.ops.CollectiveGather)(group_dst_rank, _group_name)
out = _op(tensor)
if dst == my_rank:
_split_count = len(gather_list)
        _split_size = out.shape[0] // _split_count
        out = ms.ops.split(out, _split_size)
for i, output in enumerate(gather_list):
output.data = out[i]
_comm_inplace_assign(gather_list, out)
if async_op:
return _stub_work
else:
@@ -987,10 +882,10 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
        _op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveScatter)(src)
_op = _get_cache_prim(ms.ops.CollectiveScatter)(src)
else:
group_src_rank = get_group_rank(group, src)
        _op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveScatter)(group_src_rank, _group_name)
_op = _get_cache_prim(ms.ops.CollectiveScatter)(group_src_rank, _group_name)
my_rank = get_rank()
if my_rank == src:
@@ -1010,12 +905,61 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
group_size = _get_group_size(group)
input_ms = ms.ops.zeros((group_size,) + tensor.shape, dtype=tensor.dtype)
out = _op(input_ms)[0]
tensor.data = out
_comm_inplace_assign(tensor, out)
if async_op:
return _stub_work
else:
return None
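# isend/irecv are inherently asynchronous: the MindSpore handle is always wrapped via
# _comm_process_handle(handle, True), and irecv additionally copies the received data
# into `tensor` in place.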
def isend(tensor, dst, group=None, tag=0):
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("isend")
return None
if group is None:
group = _get_default_group()
_group_name = _get_pg_name(group)
handle = _isend_ms(cast_to_ms_tensor(tensor), dst, _group_name, tag)
return _comm_process_handle(handle, True)
def irecv(tensor, src=None, group=None, tag=0):
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("irecv")
return None
if group is None:
group = _get_default_group()
_group_name = _get_pg_name(group)
result, handle = _irecv_ms(cast_to_ms_tensor(tensor), src, _group_name, tag)
_comm_inplace_assign(tensor, result)
return _comm_process_handle(handle, True)
class P2POp(_P2POp_ms):
...
def batch_isend_irecv(p2p_op_list):
    # Shallow-copy the list so MindSpore cannot modify the caller's list in place.
p2p_op_list_copy = copy.copy(p2p_op_list)
for i, p2p_op in enumerate(p2p_op_list_copy):
if p2p_op.group is None:
p2p_op.group = GlobalComm.WORLD_COMM_GROUP
else:
p2p_op.group = _get_pg_name(p2p_op.group)
result = _batch_isend_irecv_ms(p2p_op_list_copy)
reqs = []
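    # Only irecv entries produce output tensors; copy them back into the tensors
    # supplied by the caller.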
for i, p2p_op in enumerate(p2p_op_list):
if p2p_op.op.__name__ == "irecv":
_comm_inplace_assign(p2p_op.tensor, result[i])
reqs.append(_stub_work)
return reqs
@contextlib.contextmanager
def _coalescing_manager(group, device, reqs):
unsupported_attr(group)