#976 suit-2.4.0-distributed

Merged
Erpim merged 9 commits from lhy-2.4.0-distributed into master 4 months ago
  1. mindtorch/torch/autograd/function.py (+1, -21)
  2. mindtorch/torch/distributed/_distributed_c10d.py (+48, -52)
  3. mindtorch/torch/distributed/distributed_c10d.py (+131, -187)
  4. testing/st/pytorch/distributed/all_to_all_single_impl.py (+1, -0)
  5. testing/st/pytorch/distributed/batch_isend_irecv_impl.py (+28, -0)
  6. testing/st/pytorch/distributed/test_dist_interface.py (+11, -1)
  7. testing/ut/pytorch/autograd/test_auto_grad_function.py (+215, -0)

mindtorch/torch/autograd/function.py (+1, -21)

@@ -1,7 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from types import FunctionType, MethodType

from mindtorch.utils import unsupported_attr
from mindtorch.torch.nn import Module
from mindtorch.torch.tensor import Tensor, cast_to_adapter_tensor, cast_to_ms_tensor
@@ -63,24 +61,6 @@ class Function(Module):
unsupported_attr(kwargs)
super(Function, self).__init__()
self.ctx = FunctionCtx()
def compile_bprop(num_input):
input_args = ""
for i in range(num_input):
if i < num_input - 1:
input_args += f"input{i},"
else:
input_args += f"input{i}"
input_args += ",out,dout"
input_args_with_self = "self," + input_args
code = f"""def bprop({input_args_with_self}): return self._backward_wrapper({input_args})"""
code = compile(code, "<string>", "exec")
func = FunctionType(code.co_consts[0], globals(), "bprop")
return func

if not hasattr(self, 'bprop'):
# num_input should remove ctx input, so minus 1.
num_input = self.forward.__code__.co_argcount - 1
self.bprop = MethodType(compile_bprop(num_input), self)

@staticmethod
def forward(ctx, *args, **kwargs):
@@ -104,7 +84,7 @@ class Function(Module):
"your custom autograd.Function to use it with backward "
"mode AD.")

def _backward_wrapper(self, *args, **kwargs):
def bprop(self, *args, **kwargs):
unsupported_attr(kwargs)
# Prev node may be a mindspore bprop node, and type of grad_outputs may be MindSpore Tensor.
# But in backward, the computation will treat it as a MindTorch Tensor.
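
The change above drops the runtime-compiled, exact-arity bprop wrapper and instead exposes the variadic _backward_wrapper directly under the name bprop. The minimal sketch below (plain Python, illustrative class names only, not the MindTorch source) shows why the two dispatch styles are interchangeable for MindSpore's cell-bprop convention, where bprop receives the forward inputs followed by out and dout:

from types import FunctionType, MethodType

class _OldStyle:
    # old approach: synthesize "def bprop(self, input0, ..., out, dout)" per forward arity
    def __init__(self, num_input):
        args = ",".join(f"input{i}" for i in range(num_input)) + ",out,dout"
        src = f"def bprop(self,{args}): return self._backward_wrapper({args})"
        code = compile(src, "<string>", "exec")
        self.bprop = MethodType(FunctionType(code.co_consts[0], globals(), "bprop"), self)

    def _backward_wrapper(self, *args):
        return args  # stand-in for the real gradient computation

class _NewStyle:
    # new approach: one variadic method covers every forward arity
    def bprop(self, *args, **kwargs):
        return args  # stand-in for the real gradient computation

# both receive (input0, ..., out, dout) when MindSpore invokes the cell's bprop
assert _OldStyle(2).bprop(1, 2, "out", "dout") == _NewStyle().bprop(1, 2, "out", "dout")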


mindtorch/torch/distributed/_distributed_c10d.py (+48, -52)

@@ -4,7 +4,7 @@ from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.communication.management import GlobalComm
from mindspore.communication._comm_helper import _get_group_ranks

from mindtorch.torch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor
from mindtorch.torch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor, Tensor
from mindtorch.utils import unsupported_attr

__all__ = ['ReduceOp', 'ProcessGroup', 'ProcessGroupNCCL', 'BroadcastOptions', 'AllreduceOptions',
@@ -68,16 +68,17 @@ def _get_pg_name(group=None):

@ms.jit_class
class Work:
def __init__(self):
def __init__(self, handle=None):
# Do nothings here for now. After mindspore stream support, use stream to create work.
self._result = None
self._handle = handle

def start(self):
# Do nothings here for now. After mindspore stream support, use stream to create work.
...
def wait(self):
# Do nothings here for now. After mindspore stream support, use stream to create work.
...
if self._handle:
self._handle.wait()

def result(self):
return self._result
@@ -204,11 +205,9 @@ class ProcessGroup(str):
op = _get_str_from_reduceop(op)
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
tensor = tensors[0]
work = _create_work()
work.start()
result = _reduce_op(tensor)
result, handle = _reduce_op(tensor)
tensor.data = result
return work
return Work(handle)

def reduce(
self,
@@ -223,13 +222,11 @@ class ProcessGroup(str):
op = root.reduceOp
root = root.rootRank
_group_name = _get_pg_name(self)
_reduce_op = _get_cache_prim(ms.ops.operations._inner_ops.Reduce)(root, op, _group_name)
work = _create_work()
work.start()
_reduce_op = _get_cache_prim(ms.ops.Reduce)(root, op, _group_name)
tensor = tensors[0]
out = _reduce_op(tensor)
tensor.data = out
return work
out, handle = _reduce_op(tensor)
_comm_inplace_assign(tensor, out)
return Work(handle)

def _allgather_base(
self,
@@ -241,11 +238,9 @@ class ProcessGroup(str):
unsupported_attr(opts)
_group_name = _get_pg_name(self)
_ag_op = _get_cache_prim(ms.ops.AllGather)(_group_name)
work = _create_work()
work.start()
result = _ag_op(input)
output.data = result
return work
result, handle = _ag_op(input)
_comm_inplace_assign(output, result)
return Work(handle)

def allgather(
self,
@@ -258,14 +253,11 @@ class ProcessGroup(str):
_group_size = ms.communication.get_group_size(_group_name)
_ag_op = _get_cache_prim(ms.ops.AllGather)(_group_name)
_split_op = _get_cache_prim(ms.ops.Split)(0, _group_size)
work = _create_work()
work.start()
input_tensor = input_tensor[0]
result = _ag_op(input_tensor)
result, handle = _ag_op(input_tensor)
result = _split_op(result)
for i, _tensor in enumerate(output_tensors):
_tensor.data = result[i]
return work
_comm_inplace_assign(output_tensors, result)
return Work(handle)

def gather(
self,
@@ -278,7 +270,7 @@ class ProcessGroup(str):
if isinstance(root, GatherOptions):
root = root.rootRank
_group_name = _get_pg_name(self)
_op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveGather)(root, _group_name)
_op = _get_cache_prim(ms.ops.CollectiveGather)(root, _group_name)
work = _create_work()
work.start()
out = _op(input_tensor)
@@ -303,7 +295,7 @@ class ProcessGroup(str):
if isinstance(root, ScatterOptions):
root = root.rootRank
_group_name = _get_pg_name(self)
_op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveScatter)(root, _group_name)
_op = _get_cache_prim(ms.ops.CollectiveScatter)(root, _group_name)
work = _create_work()
work.start()

@@ -338,20 +330,18 @@ class ProcessGroup(str):
op = _get_str_from_reduceop(op)
_group_name = _get_pg_name(self)
_reduce_scatter_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name)
work = _create_work()
work.start()
input_tensor = cast_to_ms_tensor(input_tensor)
for i, tensor in enumerate(input_tensor):
if tensor.ndim == 0:
_zero_ndim = 1
input_tensor[i] = tensor.expand_dims(0)
input_ms = ms.ops.concat(input_tensor)
out = _reduce_scatter_op(input_ms)
out, handle = _reduce_scatter_op(input_ms)
if not _zero_ndim:
output_tensors.data = out
else:
output_tensors.data = out[0]
return work
return Work(handle)

def _reduce_scatter_base(
self,
@@ -364,11 +354,9 @@ class ProcessGroup(str):
input_ms = cast_to_ms_tensor(input)
if input_ms.ndim <= 1:
input_ms.expand_dims(-1)
work = _create_work()
work.start()
out = _reduce_scatter_op(input_ms)
output.data = out
return work
out, handle = _reduce_scatter_op(input_ms)
_comm_inplace_assign(output, out)
return Work(handle)


def alltoall_base(
@@ -398,11 +386,9 @@ class ProcessGroup(str):
if len(tensors) > 1:
raise NotImplementedError('ProcessGroup.send not support list of tensors yet.')
_group_name = _get_pg_name(self)
_send_op = _get_cache_prim(ms.ops.operations._inner_ops.Send)(tag, dstRank, _group_name)
work = _create_work()
work.start()
_send_op(tensors[0])
return work
_send_op = _get_cache_prim(ms.ops.Send)(tag, dstRank, _group_name)
_, handle = _send_op(tensors[0])
return Work(handle)

def recv(
self,
@@ -414,16 +400,14 @@ class ProcessGroup(str):
if len(tensors) > 1:
raise NotImplementedError('ProcessGroup.recv not support list of tensors yet.')
tensor = tensors[0]
_recv_op = _get_cache_prim(ms.ops.operations._inner_ops.Receive)(tag,
srcRank,
list(tensor.shape),
tensor.dtype,
_group_name)
work = _create_work()
work.start()
out = _recv_op(tensor)
tensor.data = out
return work
_recv_op = _get_cache_prim(ms.ops.Receive)(tag,
srcRank,
list(tensor.shape),
tensor.dtype,
_group_name)
out, handle = _recv_op(tensor)
_comm_inplace_assign(tensor, out)
return Work(handle)

def barrier(self,
opts=BarrierOptions(),
@@ -433,7 +417,7 @@ class ProcessGroup(str):
work = _create_work()
work.start()
_group_name = _get_pg_name(self)
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)(_group_name)
_barrier_op = _get_cache_prim(ms.ops.Barrier)(_group_name)
_barrier_op()
return work

@@ -460,3 +444,15 @@ class ProcessGroupNCCL(ProcessGroup):
@staticmethod
def _group_end():
...

def _comm_inplace_assign(output, result):
if isinstance(output, Tensor):
output.assign_value(result)
else:
for out, res in zip(output, result):
out.assign_value(res)

def _comm_process_handle(handle, async_op):
if async_op:
return Work(handle)
return None
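
The two helpers added at the end of the file capture the pattern this change applies to every collective: MindSpore 2.4 communication primitives return an (output, handle) pair, the output is assigned back into the caller's tensor(s), and the handle is wrapped in a Work so that wait() actually blocks. A hedged sketch of the intended call pattern, reusing the module's existing imports (ms, _get_cache_prim, _get_pg_name); the wrapper function itself is illustrative and not part of the diff:

def _example_allreduce(self, tensors, op="sum", async_op=False):
    # same flow as ProcessGroup.allreduce above, spelled out step by step
    _group_name = _get_pg_name(self)
    _reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
    tensor = tensors[0]
    result, handle = _reduce_op(tensor)            # 2.4 primitive returns (output, handle)
    _comm_inplace_assign(tensor, result)           # in-place semantics for the caller
    return _comm_process_handle(handle, async_op)  # Work(handle) if async, else None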

mindtorch/torch/distributed/distributed_c10d.py (+131, -187)

@@ -1,4 +1,5 @@
import contextlib
import copy

from mindspore.communication.management import init, create_group, GlobalComm, destroy_group
from mindspore.communication._comm_helper import (_is_available, _is_initialized, _is_hccl_available,
@@ -6,6 +7,21 @@ from mindspore.communication._comm_helper import (_is_available, _is_initialized

import mindspore as ms
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.communication.comm_func import all_to_all_with_output_shape as _all_to_all_ms, \
all_to_all_single_with_output_shape as _all_to_all_single_ms, \
P2POp as _P2POp_ms, \
batch_isend_irecv as _batch_isend_irecv_ms, \
isend as _isend_ms, \
irecv as _irecv_ms, \
barrier as _barrier_ms, \
broadcast as _broadcast_ms, \
all_reduce as _all_reduce_ms, \
reduce as _reduce_ms, \
reduce_scatter_tensor as _reduce_scatter_ms, \
all_gather_into_tensor as _all_gather_into_tensor_ms, \
send as send_ms, \
recv as recv_ms


from mindtorch.torch.common.dtype import int8, int32, float16, float32, bfloat16, all_complex_type
from mindtorch.utils import unsupported_attr, graph_mode_condition, is_under_ascend_context, \
@@ -33,7 +49,9 @@ from mindtorch.torch.distributed._distributed_c10d import ( # pylint: disable=W0
_get_str_from_reduceop,
_pg_map,
_pg_names,
_group_count
_group_count,
_comm_inplace_assign,
_comm_process_handle
)

_ascend_support_dtype = (int8, int32, float16, float32, bfloat16)
@@ -510,18 +528,10 @@ def all_reduce_not_inplace(tensor, op=ReduceOp.SUM, group=None, async_op=False):
op = _get_str_from_reduceop(op)

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op)
else:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
result = _reduce_op(tensor)
result = _all_reduce_ms(cast_to_ms_tensor(tensor), _group_name)
return cast_to_adapter_tensor(result)

def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
warning("all_reduce: 'async_op' not actually supported now. Run as sync op")
_inplace_raise_error_graph_mode('all_reduce', 'all_reduce_not_inplace')

_check_single_tensor(tensor, "tensor")

if _rank_not_in_group(group):
@@ -534,26 +544,17 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
op = _get_str_from_reduceop(op)

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op)
else:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
result = _reduce_op(cast_tensor)
result, handle = _all_reduce_ms(cast_to_ms_tensor(cast_tensor), op, _group_name, async_op)
result = _recorver_dtype_on_ascend(result, _origin_dtype)
else:
result = _reduce_op(tensor)
tensor.data = result
if async_op:
return _stub_work
else:
return None
result, handle = _all_reduce_ms(cast_to_ms_tensor(tensor), op, _group_name, async_op)
_comm_inplace_assign(tensor, result)

def broadcast(tensor, src, group=None, async_op=False):
if async_op:
warning("broadcast: 'async_op' not actually supported now. Run as sync op")
return _comm_process_handle(handle, async_op)

def broadcast(tensor, src, group=None, async_op=False):
_inplace_raise_error_graph_mode('broadcast', 'broadcast_not_inplace')

_check_single_tensor(tensor, "tensor")
@@ -567,48 +568,21 @@ def broadcast(tensor, src, group=None, async_op=False):
group = _get_default_group()

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src)
else:
src = get_group_rank(group, src)
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src, _group_name)
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
result = _bc_op((cast_tensor,))[0]
result = _broadcast_ms(cast_to_ms_tensor(cast_tensor), src, _group_name)
result = _recorver_dtype_on_ascend(result, _origin_dtype)
else:
result = _bc_op((tensor,))[0]
tensor.data = result
result = _broadcast_ms(cast_to_ms_tensor(tensor), src, _group_name)
_comm_inplace_assign(tensor, result)
if async_op:
return _stub_work
else:
return None

def all_gather_into_tensor_not_inplace(input_tensor, group=None, async_op=False):
if async_op:
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op")

if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("broadcast")
return None

if group is None:
group = _get_default_group()

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_ag_op = _get_cache_prim(ms.ops.AllGather)()
else:
_ag_op = _get_cache_prim(ms.ops.AllGather)(_group_name)
return cast_to_adapter_tensor(_ag_op(input_tensor))

def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
_inplace_raise_error_graph_mode('all_gather_into_tensor', 'all_gather_into_tensor_not_inplace')

if async_op:
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(input_tensor, "input_tensor")
_check_single_tensor(output_tensor, "output_tensor")

@@ -621,12 +595,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
group = _get_default_group()

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_ag_op = _get_cache_prim(ms.ops.AllGather)()
else:
_ag_op = _get_cache_prim(ms.ops.AllGather)(_group_name)

result =_ag_op(input_tensor)
result, handle = _all_gather_into_tensor_ms(cast_to_ms_tensor(input_tensor), _group_name, async_op)

_out_tensor_prim_size = output_tensor.shape[0]
_result_prim_size = result.shape[0]
@@ -639,25 +609,12 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
split_size = _result_prim_size // _out_tensor_prim_size
result = ms.ops.split(result, split_size)
result = ms.ops.stack(result)
output_tensor.data = result
if async_op:
return _stub_work
else:
return None
_comm_inplace_assign(output_tensor, result)
return _comm_process_handle(handle, async_op)

def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
return all_gather_into_tensor(output_tensor, input_tensor, group, async_op)

def all_gather_not_inplace(tensor, group=None, async_op=False):
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("all_gather")
return None
result = all_gather_into_tensor_not_inplace(tensor, group, async_op)
group_size = _get_group_size(group)
_split_op = _get_cache_prim(ms.ops.Split)(0, group_size)
return cast_to_adapter_tensor(_split_op(result))

def all_gather(tensor_list, tensor, group=None, async_op=False):
_check_tensor_list(tensor_list, "tensor_list")
_check_single_tensor(tensor, "tensor")
@@ -666,13 +623,17 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
# Graph mode not support code below.
# _warn_not_in_group("all_gather")
return None
result = all_gather_not_inplace(tensor, group, async_op)
for i, _tensor in enumerate(tensor_list):
_tensor.data = result[i]
if async_op:
return _stub_work
else:
return None

if group is None:
group = _get_default_group()
_group_name = _get_pg_name(group)
result, handle = _all_gather_into_tensor_ms(cast_to_ms_tensor(tensor), _group_name, async_op)

group_size = _get_group_size(group)
_split_op = _get_cache_prim(ms.ops.Split)(0, group_size)
result = _split_op(result)
_comm_inplace_assign(tensor_list, result)
return _comm_process_handle(handle, async_op)

def barrier(group=None, async_op=False, device_ids=None):
if async_op:
@@ -690,50 +651,31 @@ def barrier(group=None, async_op=False, device_ids=None):
group = _get_default_group()

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)()
else:
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)(_group_name)
_barrier_op()
_barrier_ms(_group_name)
if async_op:
return _stub_work
else:
return None

def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
if async_op:
warning("all_to_all: 'async_op' not actually supported now. Run as sync op")

if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("all_to_all_single")
# _warn_not_in_group("all_to_all")
return None

if group is None:
group = _get_default_group()

_split_dim = 0
_concat_dim = 0
_split_count = len(input_tensor_list)
_group_name = _get_pg_name(group)

input_tensor_list = cast_to_ms_tensor(input_tensor_list)
_input_tensor = ms.ops.stack(input_tensor_list)
result, handle = _all_to_all_ms(cast_to_ms_tensor(output_tensor_list),
cast_to_ms_tensor(input_tensor_list),
_group_name,
async_op)

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim)
else:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim, _group_name)
_comm_inplace_assign(output_tensor_list, result)

out = _op(_input_tensor)
_spilit_size = out.shape[0] // _split_count
out = ms.ops.split(out, _spilit_size)
for i, output in enumerate(output_tensor_list):
output.data = out[i]
if async_op:
return _stub_work
else:
return None
return _comm_process_handle(handle, async_op)

def all_to_all_single(
output,
@@ -743,9 +685,6 @@ def all_to_all_single(
group=None,
async_op=False,
):
if async_op:
warning("all_to_all_single: 'async_op' not actually supported now. Run as sync op")

if output_split_sizes is not None or input_split_sizes is not None:
raise NotImplementedError("all_to_all_single not support output_split_sizes and input_split_sizes now.")

@@ -754,26 +693,17 @@ def all_to_all_single(
# _warn_not_in_group("all_to_all_single")
return None

_split_count = input.shape[0]
_split_dim = 0
_concat_dim = 0

if group is None:
group = _get_default_group()

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim)
else:
_op = _get_cache_prim(ms.ops.AlltoAll)(_split_count, _split_dim, _concat_dim, _group_name)

input_ms = cast_to_ms_tensor(input)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None
result, handle = _all_to_all_single_ms(cast_to_ms_tensor(output),
cast_to_ms_tensor(input),
output_split_sizes,
input_split_sizes,
_group_name,
async_op)

_comm_inplace_assign(output, result)
return _comm_process_handle(handle, async_op)

def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
@@ -792,16 +722,9 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
op = _get_str_from_reduceop(op)

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_reduce_op = _get_cache_prim(ms.ops.operations._inner_ops.Reduce)(dst, op)
else:
# translate to world rank to group rank.
group_dst_rank = get_group_rank(group, dst)
_reduce_op = _get_cache_prim(ms.ops.operations._inner_ops.Reduce)(group_dst_rank, op, _group_name)

out = _reduce_op(tensor)
out = _reduce_ms(cast_to_ms_tensor(tensor), dst, op, _group_name)
if dst == get_rank():
tensor.data = out
_comm_inplace_assign(tensor, out)
if async_op:
return _stub_work
else:
@@ -824,16 +747,9 @@ def send(tensor, dst, group=None, tag=0):
if group is None:
group = _get_default_group()

# `dst` is global rank, but mindspore only accept local_rank
dst = get_group_rank(group, dst)
# TODO: ms.ops.operations._inner_ops.Send(group='hccl_world_group'), default is under Ascend.
# So have to pass '_group_name' as arg to support both GPU and Ascend.
_group_name = _get_pg_name(group)
_send_op = _get_cache_prim(ms.ops.operations._inner_ops.Send)(tag, dst, _group_name, _group_name)
out = _send_op(tensor)

# Additionly return 'out' to ensure bprop is correctly running.
return out
send_ms(tensor, dst, _group_name, tag)
return None

def recv(tensor, src=None, group=None, tag=0):
_inplace_raise_error_graph_mode('recv', 'recv_not_inplace')
@@ -851,20 +767,10 @@ def recv(tensor, src=None, group=None, tag=0):
if group is None:
group = _get_default_group()

# `src` is global rank, but mindspore only accept local_rank
src = get_group_rank(group, src)
# TODO: ms.ops.operations._inner_ops.Receive(group='hccl_world_group'), default is under Ascend.
# So have to pass '_group_name' as arg to support both GPU and Ascend.
_group_name = _get_pg_name(group)
_recv_op = _get_cache_prim(ms.ops.operations._inner_ops.Receive)(tag,
src,
list(tensor.shape),
tensor.dtype,
_group_name,
_group_name)
# Additionly pass 'tensor' to _recv_op to ensure bprop is correctly running.
out = _recv_op(tensor)
tensor.data = out
out = recv_ms(tensor, src, _group_name, tag)
_comm_inplace_assign(tensor, out)
return src

def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
@@ -882,20 +788,15 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
if group is None:
group = _get_default_group()

op = _get_str_from_reduceop(op)

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op)
else:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name)

input_list = cast_to_ms_tensor(input_list)
input_ms = ms.ops.concat(input_list)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None
out, handle = _reduce_scatter_ms(input_ms, op, _group_name, async_op)
_comm_inplace_assign(output, out)
return _comm_process_handle(handle, async_op)

def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
@@ -909,19 +810,14 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
# _warn_not_in_group("reduce_scatter_tensor")
return None

op = _get_str_from_reduceop(op)

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op)
else:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name)

input_ms = cast_to_ms_tensor(input)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None
out, handle = _reduce_scatter_ms(input_ms, op, _group_name, async_op)
_comm_inplace_assign(output, out)
return _comm_process_handle(handle, async_op)

def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
return reduce_scatter_tensor(output, input, op, group, async_op)
@@ -949,18 +845,17 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveGather)(dst)
_op = _get_cache_prim(ms.ops.CollectiveGather)(dst)
else:
group_dst_rank = get_group_rank(group, dst)
_op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveGather)(group_dst_rank, _group_name)
_op = _get_cache_prim(ms.ops.CollectiveGather)(group_dst_rank, _group_name)
out = _op(tensor)

if dst == my_rank:
_split_count = len(gather_list)
_spilit_size = out.shape[0] // _split_count
out = ms.ops.split(out, _spilit_size)
for i, output in enumerate(gather_list):
output.data = out[i]
_comm_inplace_assign(gather_list, out)
if async_op:
return _stub_work
else:
@@ -987,10 +882,10 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveScatter)(src)
_op = _get_cache_prim(ms.ops.CollectiveScatter)(src)
else:
group_src_rank = get_group_rank(group, src)
_op = _get_cache_prim(ms.ops.operations._inner_ops.CollectiveScatter)(group_src_rank, _group_name)
_op = _get_cache_prim(ms.ops.CollectiveScatter)(group_src_rank, _group_name)

my_rank = get_rank()
if my_rank == src:
@@ -1010,12 +905,61 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
group_size = _get_group_size(group)
input_ms = ms.ops.zeros((group_size,) + tensor.shape, dtype=tensor.dtype)
out = _op(input_ms)[0]
tensor.data = out
_comm_inplace_assign(tensor, out)
if async_op:
return _stub_work
else:
return None

def isend(tensor, dst, group=None, tag=0):
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("isend")
return None

if group is None:
group = _get_default_group()

_group_name = _get_pg_name(group)
handle = _isend_ms(cast_to_ms_tensor(tensor), dst, _group_name, tag)
return _comm_process_handle(handle, True)


def irecv(tensor, src=None, group=None, tag=0):
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
_warn_not_in_group("irecv")
return None

if group is None:
group = _get_default_group()

_group_name = _get_pg_name(group)
result, handle = _irecv_ms(cast_to_ms_tensor(tensor), src, _group_name, tag)
_comm_inplace_assign(tensor, result)
return _comm_process_handle(handle, True)


class P2POp(_P2POp_ms):
...

def batch_isend_irecv(p2p_op_list):
# use copy to avoid being modified by mindspore in place.
p2p_op_list_copy = copy.copy(p2p_op_list)
for i, p2p_op in enumerate(p2p_op_list_copy):
if p2p_op.group is None:
p2p_op.group = GlobalComm.WORLD_COMM_GROUP
else:
p2p_op.group = _get_pg_name(p2p_op.group)

result = _batch_isend_irecv_ms(p2p_op_list_copy)
reqs = []
for i, p2p_op in enumerate(p2p_op_list):
if p2p_op.op.__name__ == "irecv":
_comm_inplace_assign(p2p_op.tensor, result[i])
reqs.append(_stub_work)
return reqs

@contextlib.contextmanager
def _coalescing_manager(group, device, reqs):
unsupported_attr(group)
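
From the user-facing API in this file, the practical effect of routing collectives through mindspore.communication.comm_func is that async_op=True now returns a Work backed by a real communication handle rather than the old stub, so wait() genuinely synchronizes. A usage sketch under the same assumptions as the st tests (two ranks launched via mpirun; tensor shapes are arbitrary):

import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

def run(backend):
    dist.init_process_group(backend)
    rank = dist.get_rank()

    t = torch.ones(2, dtype=torch.float32) * (rank + 1)
    work = dist.all_reduce(t, async_op=True)   # now returns Work(handle), not a stub
    if work is not None:
        work.wait()                            # blocks on the underlying MindSpore handle
    # t now holds the elementwise sum over all ranks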


testing/st/pytorch/distributed/all_to_all_single_impl.py (+1, -0)

@@ -18,6 +18,7 @@ def func(backend):

if rank == 0:
assert np.allclose(data.cpu().numpy(), np.array([1, 3.]))
else:
assert np.allclose(data.cpu().numpy(), np.array([2, 4.]))

if __name__ == '__main__':


testing/st/pytorch/distributed/batch_isend_irecv_impl.py (+28, -0)

@@ -0,0 +1,28 @@
import sys
import numpy as np

import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

def func(backend):
dist.init_process_group(backend)

rank = dist.get_rank()
world_size = dist.get_world_size()

send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
recv_tensor = torch.randn(2, dtype=torch.float32)
send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1)%world_size)
recv_op = dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size)%world_size)
reqs = dist.batch_isend_irecv([send_op, recv_op])
for req in reqs:
req.wait()
if rank == 0:
assert np.allclose(recv_tensor.numpy(), np.array([2, 3]).astype(np.float32))
elif rank == 1:
assert np.allclose(recv_tensor.numpy(), np.array([0, 1]).astype(np.float32))


if __name__ == '__main__':
backend = sys.argv[1]
func(backend)
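
Like the other distributed st cases, this script is driven by the test harness rather than run standalone: test_batch_isend_irecv (added to test_dist_interface.py below) launches it with two ranks via mpirun --allow-run-as-root -n 2 python batch_isend_irecv_impl.py <backend>, and the asserts verify that each rank received its neighbour's send_tensor.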

testing/st/pytorch/distributed/test_dist_interface.py (+11, -1)

@@ -235,7 +235,6 @@ def test_reduce_scatter_grad():
raise Exception(e.output.decode())

@SKIP_ENV_GRAPH_MODE(reason='mindtorch distirbute not support graph mode yet.')
@SKIP_ENV_PYNATIVE_MODE(reason='MindSpore not support alltoall under pynative mode.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU')
@SKIP_ENV_GPU(reason='MindSpore not support alltoall on GPU')
def test_all_to_all():
@@ -372,3 +371,14 @@ def test_allreduce_dtype():
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
raise Exception(e.output.decode())

@SKIP_ENV_GRAPH_MODE(reason='mindtorch distirbute not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU')
def test_batch_isend_irecv():
cur_dir = os.path.abspath(os.path.dirname(__file__))
cmd = 'mpirun --allow-run-as-root -n 2 '
cmd += 'python {}/batch_isend_irecv_impl.py {}'.format(cur_dir, backend)
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
raise Exception(e.output.decode())

testing/ut/pytorch/autograd/test_auto_grad_function.py (+215, -0)

@@ -0,0 +1,215 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import torch
import mindspore as ms
from mindspore import context
import mindtorch.torch as ms_torch
from ...utils import set_mode_by_env_config
from mindtorch.torch.autograd import Function
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, enable_backward, param_compare

set_mode_by_env_config()

def adapter_autograd_function():
class Net(ms_torch.nn.Module):
def __init__(self):
super(Net, self).__init__()

def forward(self, x, y):
out = ms_torch.matmul(x, y)
return out

# bprop: https://www.mindspore.cn/tutorials/experts/zh-CN/r1.9/network/custom_cell_reverse.html
def bprop(self, x, y, out, dout):
dx = x + 1
dy = y + 1
return dx, dy

x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32)
y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32)
net = Net()
out = net(x, y)
grad_out = ms.grad(net, grad_position=(0, 1))(x, y)
return out, grad_out


def torch_autograd_function():
class Net(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
result = torch.matmul(x, y)
ctx.save_for_backward(x, y)
return result

@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
dx = x + 1
dy = y + 1
return dx, dy

x = torch.tensor([[0.5, 0.6, 0.4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[0.01], [0.2], [3.3]], dtype=torch.float32, requires_grad=True)
out = Net.apply(x, y)
out.backward()
grad_out = x.grad, y.grad
return out, grad_out

# TODO: backward not support custom bprop yet.
# def adapter_autograd_function_backward():
# class Net(ms_torch.autograd.Function):
# @staticmethod
# def forward(ctx, x, y):
# result = ms_torch.matmul(x, y)
# ctx.save_for_backward(x, y)
# return result

# @staticmethod
# def backward(ctx, grad_output):
# x, y = ctx.saved_tensors
# dx = x + 1
# dy = y + 1
# return dx, dy

# x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True)
# y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True)
# with enable_backward():
# out = Net.apply(x, y)
# out.backward()
# grad_out = x.grad, y.grad
# return out, grad_out

def test_autograd_funciton():
ms_out, ms_grad_out = adapter_autograd_function()
pt_out, pt_grad_out = torch_autograd_function()
assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy())
assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy())
assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy())

# TODO: backward not support custom bprop yet.
# @SKIP_ENV_GRAPH_MODE(reason="tensor.backward not support graphmode")
# def test_autograd_function_bprop_backward():
# ms_out, ms_grad_out = adapter_autograd_function_backward()
# pt_out, pt_grad_out = torch_autograd_function()
# assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy())
# assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy())
# assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy())

@SKIP_ENV_GRAPH_MODE(reason="Funtion.apply not support graphmode")
def test_autograd_function_grad():
def adapter_autograd_function_ms_grad():
class Net(ms_torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
result = ms_torch.matmul(x, y)
ctx.save_for_backward(x, y)
return result

@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
dx = x + 1
dy = y + 1
return dx, dy

x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True)
y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True)
def _func(a, b):
out = Net.apply(a, b)
return out
out, grad_out = ms.value_and_grad(_func, grad_position=(0, 1))(x, y)
return out, grad_out
ms_out, ms_grad_out = adapter_autograd_function_ms_grad()
pt_out, pt_grad_out = torch_autograd_function()
assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy())
assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy())
assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy())

@SKIP_ENV_GRAPH_MODE(reason="Funtion.apply not support graphmode")
def test_autograd_function_grad_bias_None():
def ms_torch_func():
class TestFunction(ms_torch.autograd.Function):
@staticmethod
def forward(ctx, input1, input2, bias, has_bias):
result = ms_torch.matmul(input1, input2)
ctx.save_for_backward(result, bias)
ctx.has_bias = has_bias

if has_bias:
result = result + bias

return result.sum()

@staticmethod
def backward(ctx, grad_outputs):
result, bias = ctx.saved_tensors
has_bias = ctx.has_bias
result = grad_outputs * result
if has_bias:
result = result + bias
# TODO: not support bias gradient auto-reducesum.
# return result + 1, result + 2, result + 3, None
return result + 1, result + 2, (result + 3).sum(dim=1), None

input1 = ms_torch.ones([8, 8])
input2 = ms_torch.ones([8, 8])
bias = ms_torch.ones([8])

input1.requires_grad = True
input2.requires_grad = True
bias.requires_grad = True

def _func(x, y, z):
return TestFunction.apply(x, y, z, True)
grads = ms.grad(_func, grad_position=(0, 1, 2))(input1, input2, bias)
return grads

def torch_fun():
class TestFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input1, input2, bias, has_bias):
result = torch.matmul(input1, input2)
ctx.save_for_backward(result, bias)
ctx.has_bias = has_bias

if has_bias:
result = result + bias

return result.sum()

@staticmethod
def backward(ctx, grad_outputs):
result, bias = ctx.saved_tensors
has_bias = ctx.has_bias
result = grad_outputs * result
if has_bias:
result = result + bias
# return result + 1, result + 2, result + 3, None
return result + 1, result + 2, (result + 3).sum(dim=1), None

input1 = torch.ones([8, 8])
input2 = torch.ones([8, 8])
bias = torch.ones([8])

input1.requires_grad = True
input2.requires_grad = True
bias.requires_grad = True

def _func(x, y, z):
return TestFunction.apply(x, y, z, True)

res = _func(input1, input2, bias)
res.backward()
return input1.grad, input2.grad, bias.grad

ms_grad = ms_torch_func()
torch_grad = torch_fun()
param_compare(ms_grad, torch_grad)

if __name__ == '__main__':
set_mode_by_env_config()
test_autograd_funciton()
# test_autograd_function_bprop_backward()
test_autograd_function_grad()
test_autograd_function_grad_bias_None()
