#895 update autograd.Function and distributed API

Merged
zoulq merged 37 commits from frelam/MSAdapter:master0319 into master 1 month ago
  1. +95 -10  mindtorch/torch/autograd/function.py
  2. +6 -0  mindtorch/torch/distributed/__init__.py
  3. +190 -32  mindtorch/torch/distributed/distributed_c10d.py
  4. +37 -0  testing/st/pytorch/distributed/all_reduce_dtype_impl.py
  5. +32 -0  testing/st/pytorch/distributed/broadcast_impl_ascend_cast_and_async_impl.py
  6. +6 -5  testing/st/pytorch/distributed/ddp_impl_ascend.py
  7. +6 -5  testing/st/pytorch/distributed/ddp_impl_gpu.py
  8. +27 -0  testing/st/pytorch/distributed/reduce_scatter_tensor_impl.py
  9. +33 -0  testing/st/pytorch/distributed/test_dist_interface.py
  10. +145 -22  testing/ut/pytorch/autograd/test_autograd_function.py

+95 -10  mindtorch/torch/autograd/function.py

@@ -1,7 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from types import FunctionType, MethodType

from mindtorch.utils import unsupported_attr
from mindtorch.torch.nn import Module
from mindtorch.torch.tensor import Tensor, cast_to_adapter_tensor, cast_to_ms_tensor
from mindtorch.torch.logging import warning


class FunctionCtx:
def save_for_backward(self, *tensors):
self.to_save = tensors

def save_for_forward(self, *tensors):
for tensor in tensors:
if not isinstance(tensor, Tensor) or tensor is None:
raise TypeError(
"save_for_forward expects all arguments to be tensors; you should "
"save non-tensors as attributes on ctx."
)

self.saved_for_forward = tensors

def mark_dirty(self, *args):
warning("ctx.mark_dirty do not actually take effect now.")
self.dirty_tensors = args

Erpim commented 1 month ago
Review
Do these functions actually have no effect?
frelam commented 1 month ago
Review
Yes. Warnings or errors were added according to the impact of each one.
def mark_non_differentiable(self, *args):
raise NotImplementedError("ctx.mark_non_differentiable not support yet.")

def set_materialize_grads(self, value):
if not value:
warning("ctx.set_materialize_grads(false) not actually take effect now.")
self.materialize_grads = value

@property
def saved_tensors(self):
return self.to_save


class Function(Module):
@@ -27,15 +62,65 @@ class Function(Module):
unsupported_attr(args)
unsupported_attr(kwargs)
super(Function, self).__init__()
self.ctx = FunctionCtx()
def compile_bprop(num_input):
input_args = ""
for i in range(num_input):
if i < num_input - 1:
input_args += f"input{i},"
else:
input_args += f"input{i}"
input_args += ",out,dout"
input_args_with_self = "self," + input_args
code = f"""def bprop({input_args_with_self}): return self._backward_wrapper({input_args})"""
code = compile(code, "<string>", "exec")
func = FunctionType(code.co_consts[0], globals(), "bprop")
return func

def apply(self, *args, **kwargs):
zoulq commented 2 months ago
Review
Previously, calling this interface raised an error telling the user to use the corresponding MindSpore interface. Now there is no such prompt, but an error is raised at the bprop arguments, which users are unlikely to understand. So the custom operator chapter in the user documentation needs to be updated, and an example should be added to the FAQ.
frelam commented 1 month ago
Review
By generating the function dynamically, bprop is now created automatically in __init__, so the usage can be the same as in PyTorch.
zoulq commented 1 month ago
Review
This Cell behavior will be optimized later.
"""
# Do not use it by calling the apply method.
"""
unsupported_attr(args)
if not hasattr(self, 'bprop'):
# num_input should exclude the ctx argument, so subtract 1.
num_input = self.forward.__code__.co_argcount - 1
self.bprop = MethodType(compile_bprop(num_input), self)

@staticmethod
def forward(ctx, *args, **kwargs):
raise NotImplementedError("You must implement the forward function for custom"
" autograd.Function.")

@classmethod
def apply(cls, *args, **kwargs):
obj = cls()
return obj(*args, **kwargs)

def construct(self, *args, **kwargs):
return self.forward(self.ctx, *args, **kwargs)

def _run_construct(self, cast_inputs, kwargs):
return self.forward(self.ctx, *cast_inputs, **kwargs)

@staticmethod
def backward(ctx, *grad_outputs):
raise NotImplementedError("You must implement either the backward method for "
"your custom autograd.Function to use it with backward "
"mode AD.")

def _backward_wrapper(self, *args, **kwargs):
unsupported_attr(kwargs)
raise RuntimeError("To create a custom autograd.Function, please use 'def forward(self, ...)' and "
"'def bprop(self, ..., out, dout)' instead of 'forward()' and 'backward()' static methods. "
"Then, use it as normal module class, do not call the class method 'apply'."
"Please refer to the following example: https://openi.pcl.ac.cn/OpenI/MSAdapter/src/"
"branch/master/doc/torch/USER_GUIDE.md#user-content-4-2-1-%E8%87%AA%E5%AE%9A%E4%B9%89module")
# The previous node may be a MindSpore bprop node, so grad_outputs may be a MindSpore Tensor.
# But in backward, the computation treats it as a MindTorch Tensor.
# So add "cast_to_adapter_tensor" to ensure self.backward gets a MindTorch Tensor.
grad_outputs = cast_to_adapter_tensor(args[-1])
if isinstance(grad_outputs, (list, tuple)):
res = self.backward(self.ctx, *grad_outputs)
else:
res = self.backward(self.ctx, grad_outputs)

# The next node may be a MindSpore bprop node, so use "cast_to_ms_tensor"
# to ensure the next node gets a MindSpore Tensor.
if res is None:
res = 0
elif isinstance(res, (list, tuple)):
res = tuple(0 if x is None else cast_to_ms_tensor(x) for x in res)
else:
res = cast_to_ms_tensor(res)
return res

+6 -0  mindtorch/torch/distributed/__init__.py

@@ -1,2 +1,8 @@
from .distributed_c10d import *
from ._distributed_c10d import *
from .distributed_c10d import (
_backend,
_all_gather_base,
_reduce_scatter_base,
_rank_not_in_group,
)

+190 -32  mindtorch/torch/distributed/distributed_c10d.py

@@ -7,7 +7,9 @@ from mindspore.communication._comm_helper import (_is_available, _is_initialized
import mindspore as ms
from mindspore.ops._primitive_cache import _get_cache_prim

from mindtorch.utils import unsupported_attr, graph_mode_condition
from mindtorch.torch.common.dtype import int8, int32, float16, float32, bfloat16, all_complex_type
from mindtorch.utils import unsupported_attr, graph_mode_condition, is_under_ascend_context, \
is_under_gpu_context
from mindtorch.torch.logging import warning
from mindtorch.torch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor, Tensor
from mindtorch.torch.distributed._distributed_c10d import ( # pylint: disable=W0611
@@ -34,6 +36,68 @@ from mindtorch.torch.distributed._distributed_c10d import ( # pylint: disable=W0
_group_count
)

_ascend_support_dtype = (int8, int32, float16, float32, bfloat16)

_ascend_dtype_convert_map = {
'bool': int8,
'int64': int32,
'float64': float32,
}

# should be used after cast_to_ms_tensor
def _check_and_convert_dtype_on_ascend(input_ms):
def _convert(tensor):
_origin_dtype = tensor.dtype
if _origin_dtype in all_complex_type:
raise TypeError("Not support communication of complex tensor yet.")
if _origin_dtype not in _ascend_support_dtype:
tensor = tensor.astype(
_ascend_dtype_convert_map[str(_origin_dtype).split('.')[-1]]
)
else:
_origin_dtype = None
return tensor, _origin_dtype

if isinstance(input_ms, Tensor):
return _convert(input_ms)
elif isinstance(input_ms, tuple):
input_ms = list(input_ms)
elif not isinstance(input_ms, list):
raise TypeError("input_ms must be type of Tensor, tuple or list")

_origin_dtype_list = []
for i, tensor in enumerate(input_ms):
converted_tensor, origin_dtype = _convert(tensor)
input_ms[i] = converted_tensor
_origin_dtype_list.append(origin_dtype)

zoulq commented 2 months ago
Review
In what scenarios are these two dtype conversion interfaces used?
frelam commented 1 month ago
Review
On Ascend, when the dtype of the user's input tensor is not supported by the MindSpore communication operators, this dtype conversion is applied.
return input_ms, _origin_dtype_list

# should be used before cast_to_adapter_tensor
def _recorver_dtype_on_ascend(output_ms, _origin_dtype):
def _recover(tensor, origin_dtype):
if origin_dtype is not None:
tensor = tensor.astype(origin_dtype)
return tensor

if isinstance(output_ms, Tensor):
return _recover(output_ms, _origin_dtype)
elif isinstance(output_ms, tuple):
output_ms = list(output_ms)
elif not isinstance(output_ms, list):
raise TypeError(f"output_ms must be type of Tensor, tuple or list, but got {type(output_ms)}.")

if not isinstance(_origin_dtype, list):
raise TypeError(f"_origin_dtype must be type of list, but got {type(_origin_dtype)}.")
if len(output_ms) != len(_origin_dtype):
raise ValueError("length of output_ms not equal to _origin_dtype")

for i, (output, dtype) in enumerate(zip(output_ms, _origin_dtype)):
if dtype is not None:
output_ms[i] = _recover(output, dtype)

return output_ms
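
To illustrate the review exchange above from the caller's side, a minimal sketch mirroring all_reduce_dtype_impl.py added in this PR (assumes the default process group is already initialized with HCCL on Ascend): an int64 tensor is transparently cast to int32 for the collective and cast back before the result reaches the user.

```python
import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

# Assumes dist.init_process_group('hccl', ...) has already run on Ascend.
rank = dist.get_rank()
t = torch.tensor([1, 2], dtype=torch.int64).to(rank)
dist.all_reduce(t)              # internally: int64 -> int32 -> AllReduce -> int64
assert t.dtype == torch.int64   # the caller-visible dtype is preserved
```
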


BACKEND_DEVICE_TARGET_DICT = {
'mccl': 'CPU',
@@ -51,6 +115,8 @@ def _make_nccl_premul_sum(factor):
raise NotImplementedError('distributed._make_nccl_premul_sum not support yet.'
'Please manually scale the tensor before reduce.')

_stub_work = Work()

class Backend:
UNDEFINED = "undefined"
HCCL = "hccl"
@@ -75,6 +141,9 @@ class Backend:
raise NotImplementedError("For distributed.Backend, register_backend has not been supported yet. "
"For now, only 'hccl', 'nccl', 'mccl' are supported.")

_backend = Backend.UNDEFINED
dist_backend = Backend

# jit_class for _World to support graph mode
@ms.jit_class
class _World:
@@ -225,6 +294,11 @@ def init_process_group(
init()
else:
backend = Backend(backend)
# Help users automatically switch between the Ascend and GPU backends without changing user code.
if backend == 'nccl' and is_under_ascend_context():
backend = 'hccl'
elif backend == 'hccl' and is_under_gpu_context():
backend = 'nccl'
init(backend)

name = GlobalComm.WORLD_COMM_GROUP
@@ -256,15 +330,12 @@ def new_group(ranks=None,
timeout=None,
backend=None,
pg_options=None):
unsupported_attr(backend)
global _world

if timeout is not None:
raise NotImplementedError("distributed.new_group: timeout is not supported")

if backend is not None:
raise NotImplementedError("distributed.new_group: backend is not supported. "
"Heterogeneity in a single process is not supported.")

if pg_options is not None:
raise NotImplementedError("distributed.new_group: pg_options is not supported."
"Heterogeneity in a single process is not supported.")
@@ -299,8 +370,8 @@ def new_group(ranks=None,
pg = ProcessGroup(name=name)
create_group(name, ranks)

if not backend:
backend = get_backend()
# TODO: after the `backend` arg is supported, use 'backend' rather than 'get_backend()'.
backend = get_backend()
_world.pg_names[pg] = name
_world.pg_map[pg] = (backend,)

@@ -392,6 +463,9 @@ def is_nccl_available():
def is_hccl_available():
return _is_hccl_available()

def is_gloo_available():
return False

def get_process_group_ranks(group):
if _rank_not_in_group(group):
_warn_not_in_group("get_process_group_ranks")
@@ -425,7 +499,7 @@ def get_global_rank(group, group_rank):

def all_reduce_not_inplace(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.all_reduce not support async yet.")
warning("all_reduce: 'async_op' not actually supported now. Run as sync op")

if _rank_not_in_group(group):
# Graph mode not support code below.
@@ -446,7 +520,7 @@ def all_reduce_not_inplace(tensor, op=ReduceOp.SUM, group=None, async_op=False):

def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.all_reduce not support async yet.")
warning("all_reduce: 'async_op' not actually supported now. Run as sync op")
_inplace_raise_error_graph_mode('all_reduce', 'all_reduce_not_inplace')

_check_single_tensor(tensor, "tensor")
@@ -454,7 +528,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("all_reduce")
return
return None
if group is None:
group = _get_default_group()

@@ -465,12 +539,21 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op)
else:
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name)
result = _reduce_op(tensor)
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
result = _reduce_op(cast_tensor)
result = _recorver_dtype_on_ascend(result, _origin_dtype)
else:
result = _reduce_op(tensor)
tensor.data = result
if async_op:
return _stub_work
else:
return None

def broadcast(tensor, src, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.broadcast not support async yet.")
warning("broadcast: 'async_op' not actually supported now. Run as sync op")

_inplace_raise_error_graph_mode('broadcast', 'broadcast_not_inplace')

@@ -479,7 +562,7 @@ def broadcast(tensor, src, group=None, async_op=False):
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("broadcast")
return
return None

if group is None:
group = _get_default_group()
@@ -490,11 +573,21 @@ def broadcast(tensor, src, group=None, async_op=False):
else:
src = get_group_rank(group, src)
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src, _group_name)
tensor.data = _bc_op((tensor,))[0]
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
result = _bc_op((cast_tensor,))[0]
result = _recorver_dtype_on_ascend(result, _origin_dtype)
else:
result = _bc_op((tensor,))[0]
tensor.data = result
if async_op:
return _stub_work
else:
return None

def all_gather_into_tensor_not_inplace(input_tensor, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.broadcast not support async yet.")
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op")

if _rank_not_in_group(group):
# Graph mode not support code below.
@@ -515,7 +608,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
_inplace_raise_error_graph_mode('all_gather_into_tensor', 'all_gather_into_tensor_not_inplace')

if async_op:
raise NotImplementedError("distributed.broadcast not support async yet.")
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(input_tensor, "input_tensor")
_check_single_tensor(output_tensor, "output_tensor")
@@ -523,7 +616,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("broadcast")
return
return None

if group is None:
group = _get_default_group()
@@ -548,6 +641,10 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
result = ms.ops.split(result, split_size)
result = ms.ops.stack(result)
output_tensor.data = result
if async_op:
return _stub_work
else:
return None

def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
return all_gather_into_tensor(output_tensor, input_tensor, group, async_op)
@@ -569,14 +666,18 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("all_gather")
return
return None
result = all_gather_not_inplace(tensor, group, async_op)
for i, _tensor in enumerate(tensor_list):
_tensor.data = result[i]
if async_op:
return _stub_work
else:
return None

def barrier(group=None, async_op=False, device_ids=None):
if async_op:
raise NotImplementedError("distributed.barrier not support async yet.")
warning("barrier: 'async_op' not actually supported now. Run as sync op")

if device_ids:
raise NotImplementedError("distributed.barrier not support device_ids yet.")
@@ -584,7 +685,7 @@ def barrier(group=None, async_op=False, device_ids=None):
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("barrier")
return
return None

if group is None:
group = _get_default_group()
@@ -595,15 +696,19 @@ def barrier(group=None, async_op=False, device_ids=None):
else:
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)(_group_name)
_barrier_op()
if async_op:
return _stub_work
else:
return None

def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.barrier not support async yet.")
warning("all_to_all: 'async_op' not actually supported now. Run as sync op")

if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("all_to_all_single")
return
return None

if group is None:
group = _get_default_group()
@@ -626,6 +731,10 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
out = ms.ops.split(out, _spilit_size)
for i, output in enumerate(output_tensor_list):
output.data = out[i]
if async_op:
return _stub_work
else:
return None

def all_to_all_single(
output,
@@ -636,7 +745,7 @@ def all_to_all_single(
async_op=False,
):
if async_op:
raise NotImplementedError("distributed.barrier not support async yet.")
warning("all_to_all_single: 'async_op' not actually supported now. Run as sync op")

if output_split_sizes is not None or input_split_sizes is not None:
raise NotImplementedError("all_to_all_single not support output_split_sizes and input_split_sizes now.")
@@ -644,7 +753,7 @@ def all_to_all_single(
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("all_to_all_single")
return
return None

_split_count = input.shape[0]
_split_dim = 0
@@ -662,17 +771,21 @@ def all_to_all_single(
input_ms = cast_to_ms_tensor(input)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None

def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.barrier not support async yet.")
warning("reduce: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(tensor, "tensor")

if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("reduce")
return
return None

if group is None:
group = _get_default_group()
@@ -690,6 +803,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
out = _reduce_op(tensor)
if dst == get_rank():
tensor.data = out
if async_op:
return _stub_work
else:
return None

def send(tensor, dst, group=None, tag=0):
if get_rank() == dst:
@@ -753,7 +870,7 @@ def recv(tensor, src=None, group=None, tag=0):

def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.barrier not support async yet.")
warning("reduce_scatter: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(output, "output")
_check_tensor_list(input_list, "input_list")
@@ -761,7 +878,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("reduce_scatter")
return
return None

if group is None:
group = _get_default_group()
@@ -776,10 +893,43 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
input_ms = ms.ops.concat(input_list)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None

def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
if async_op:
warning("reduce_scatter_tensor: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(output, "output")
_check_single_tensor(input, "input")

if _rank_not_in_group(group):
# Graph mode not support code below.
# _warn_not_in_group("reduce_scatter_tensor")
return None

_group_name = _get_pg_name(group)
if _group_name == GlobalComm.WORLD_COMM_GROUP:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op)
else:
_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name)

input_ms = cast_to_ms_tensor(input)
out = _op(input_ms)
output.data = out
if async_op:
return _stub_work
else:
return None

def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
return reduce_scatter_tensor(output, input, op, group, async_op)

def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.gather not support async yet.")
warning("gather: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(tensor, "tensor")

@@ -790,7 +940,7 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):

if _rank_not_in_group(group):
# _warn_not_in_group("gather")
return
return None

my_rank = get_rank()
_validate_output_list_for_rank(my_rank, dst, gather_list)
@@ -812,10 +962,14 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
out = ms.ops.split(out, _spilit_size)
for i, output in enumerate(gather_list):
output.data = out[i]
if async_op:
return _stub_work
else:
return None

def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
if async_op:
raise NotImplementedError("distributed.gather not support async yet.")
warning("scatter: 'async_op' not actually supported now. Run as sync op")

_check_single_tensor(tensor, "tensor")

@@ -827,7 +981,7 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):

if _rank_not_in_group(group):
# _warn_not_in_group("scatter")
return
return None

if group is None:
group = _get_default_group()
@@ -858,6 +1012,10 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
input_ms = ms.ops.zeros((group_size,) + tensor.shape, dtype=tensor.dtype)
out = _op(input_ms)[0]
tensor.data = out
if async_op:
return _stub_work
else:
return None

@contextlib.contextmanager
def _coalescing_manager(group, device, reqs):
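
Since `async_op` is now downgraded from an error to a warning, the collectives above return the stub `_stub_work` when `async_op=True` (and `None` otherwise), so PyTorch-style call sites keep working even though the op has already completed synchronously. A small caller-side sketch, mirroring broadcast_impl_ascend_cast_and_async_impl.py below (assumes the process group is already initialized):

```python
import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

t = torch.ones(2)
work = dist.all_reduce(t, async_op=True)   # warns and runs synchronously
if work is not None:
    work.wait()                            # no-op on the stub Work; t is already reduced
```
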


+37 -0  testing/st/pytorch/distributed/all_reduce_dtype_impl.py

@@ -0,0 +1,37 @@
import sys
import numpy as np

import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

def func(backend):
_total_device = 2

dist.init_process_group(backend, world_size=_total_device)

rank = dist.get_rank()

w = torch.tensor([[1, 2., 3, 4.], [5, 6, 7, 8]])
x = torch.tensor([2., 2, 2, 2])

w_device = w[:, w.size(1) // _total_device * rank : w.size(1) // _total_device * (rank + 1)].to(rank)
x_device = x[x.size(0) // _total_device * rank : x.size(0) // _total_device * (rank + 1)].to(rank)

result = torch.matmul(w_device, x_device)

_enum_dtype = {torch.bool, torch.int64}

for dtype in _enum_dtype:
_tmp = result.to(dtype)
dist.all_reduce(_tmp)
assert _tmp.dtype == dtype
if dtype == torch.bool:
assert np.allclose(_tmp.cpu().numpy(), np.array([True, True]))
else:
expected_result = np.array([20, 52.])
assert np.allclose(_tmp.cpu().numpy(), expected_result)
assert _tmp.shape == expected_result.shape

if __name__ == '__main__':
backend = sys.argv[1]
func(backend)

+32 -0  testing/st/pytorch/distributed/broadcast_impl_ascend_cast_and_async_impl.py

@@ -0,0 +1,32 @@
import sys
import numpy as np

import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

def func(backend):
dist.init_process_group(backend)

rank = dist.get_rank()

if rank == 0:
data = torch.tensor([1, 2.]).to(f'cuda:{rank}')
else:
data = torch.zeros(2).to(f'cuda:{rank}')

_enum_dtype = {torch.bool, torch.int64}

for dtype in _enum_dtype:
_tmp = data.to(dtype)
work = dist.broadcast(_tmp, 0, async_op=True)
work.wait()
assert _tmp.dtype == dtype
if dtype == torch.bool:
assert np.allclose(_tmp.cpu().numpy(), np.array([True, True]))
else:
assert np.allclose(_tmp.cpu().numpy(), np.array([1, 2.]))
assert data.shape == (2,)

if __name__ == '__main__':
backend = sys.argv[1]
func(backend)

+6 -5  testing/st/pytorch/distributed/ddp_impl_ascend.py

@@ -3,16 +3,17 @@ import numpy as np
import mindspore as ms
from mindspore import nn

import mindtorch.torch as torch
from mindtorch.torch.nn.parallel import DistributedDataParallel as DDP
from mindtorch.torch import distributed as dist


class NetWork(nn.Cell):
class NetWork(torch.nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.dense = nn.Dense(3, 3)
self.dense = torch.nn.Linear(3, 3)

def construct(self, x):
def forward(self, x):
return self.dense(x).sum()


@@ -22,7 +23,7 @@ def ddp_basic():
dist.init_process_group(backend='hccl', rank=-1, world_size=2)

network = NetWork()
opt = nn.Adam(network.trainable_params())
opt = torch.optim.Adam(network.parameters())
grad_fn = ms.value_and_grad(network, None, opt.parameters, has_aux=False)

rank = dist.get_rank()
@@ -34,7 +35,7 @@ def ddp_basic():
else:
network_p = None

inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32))
inputs = torch.tensor(np.random.random((2, 3)).astype(np.float32))
for _ in range(1):
loss, grads = grad_fn(inputs)
grads = network.all_reduce(grads)


+6 -5  testing/st/pytorch/distributed/ddp_impl_gpu.py

@@ -3,16 +3,17 @@ import numpy as np
import mindspore as ms
from mindspore import nn

import mindtorch.torch as torch
from mindtorch.torch.nn.parallel import DistributedDataParallel as DDP
from mindtorch.torch import distributed as dist


class NetWork(nn.Cell):
class NetWork(torch.nn.Module):
def __init__(self):
super(NetWork, self).__init__()
self.dense = nn.Dense(3, 3)
self.dense = torch.nn.Linear(3, 3)

def construct(self, x):
def forward(self, x):
return self.dense(x).sum()


@@ -22,7 +23,7 @@ def ddp_basic():
dist.init_process_group(backend='nccl', rank=-1, world_size=3)

network = NetWork()
opt = nn.Adam(network.trainable_params())
opt = torch.optim.Adam(network.parameters())
grad_fn = ms.value_and_grad(network, None, opt.parameters, has_aux=False)

rank = dist.get_rank()
@@ -34,7 +35,7 @@ def ddp_basic():
else:
network_p = None

inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32))
inputs = torch.tensor(np.random.random((2, 3)).astype(np.float32))
for _ in range(1):
loss, grads = grad_fn(inputs)
grads = network.all_reduce(grads)


+27 -0  testing/st/pytorch/distributed/reduce_scatter_tensor_impl.py

@@ -0,0 +1,27 @@
import sys
import numpy as np

import mindtorch.torch as torch
import mindtorch.torch.distributed as dist

def func(backend):
dist.init_process_group(backend)

rank = dist.get_rank()

# mindspore reduce_scatter does not support int64
tensor = torch.arange(2, dtype=torch.float32).to(f'cuda:{rank}') + 2 * rank
output = torch.empty(1, dtype=torch.float32).to(f'cuda:{rank}')

dist.reduce_scatter_tensor(output, tensor)

if rank == 0:
assert np.allclose(output.cpu().numpy(), np.array([2]).astype(np.float32))
assert output.shape == (1,)
else:
assert np.allclose(output.cpu().numpy(), np.array([4]).astype(np.float32))
assert output.shape == (1,)

if __name__ == '__main__':
backend = sys.argv[1]
func(backend)

+33 -0  testing/st/pytorch/distributed/test_dist_interface.py

@@ -339,3 +339,36 @@ def test_allgather_grad():
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
raise Exception(e.output.decode())

@SKIP_ENV_GRAPH_MODE(reason='mindtorch distributed does not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU')
def test_reduce_scatter_tensor():
cur_dir = os.path.abspath(os.path.dirname(__file__))
cmd = 'mpirun --allow-run-as-root -n 2 '
cmd += 'python {}/reduce_scatter_tensor_impl.py {}'.format(cur_dir, backend)
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
raise Exception(e.output.decode())

@SKIP_ENV_GRAPH_MODE(reason='mindtorch distributed does not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU')
def test_broadcast_cast_async():
cur_dir = os.path.abspath(os.path.dirname(__file__))
cmd = 'mpirun --allow-run-as-root -n 2 '
cmd += 'python {}/broadcast_impl_ascend_cast_and_async_impl.py {}'.format(cur_dir, backend)
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
raise Exception(e.output.decode())

@SKIP_ENV_GRAPH_MODE(reason='mindtorch distributed does not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU')
def test_allreduce_dtype():
cur_dir = os.path.abspath(os.path.dirname(__file__))
cmd = 'mpirun --allow-run-as-root -n 2 '
cmd += 'python {}/all_reduce_dtype_impl.py {}'.format(cur_dir, backend)
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
except subprocess.CalledProcessError as e:
raise Exception(e.output.decode())

+145 -22  testing/ut/pytorch/autograd/test_autograd_function.py

@@ -7,7 +7,7 @@ from mindspore import context
import mindtorch.torch as ms_torch
from ...utils import set_mode_by_env_config
from mindtorch.torch.autograd import Function
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, enable_backward, param_compare

set_mode_by_env_config()

@@ -56,6 +56,29 @@ def torch_autograd_function():
grad_out = x.grad, y.grad
return out, grad_out

# TODO: backward does not support custom bprop yet.
# def adapter_autograd_function_backward():
# class Net(ms_torch.autograd.Function):
# @staticmethod
# def forward(ctx, x, y):
# result = ms_torch.matmul(x, y)
# ctx.save_for_backward(x, y)
# return result

# @staticmethod
# def backward(ctx, grad_output):
# x, y = ctx.saved_tensors
# dx = x + 1
# dy = y + 1
# return dx, dy

# x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True)
# y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True)
# with enable_backward():
# out = Net.apply(x, y)
# out.backward()
# grad_out = x.grad, y.grad
# return out, grad_out

def test_autograd_funciton():
ms_out, ms_grad_out = adapter_autograd_function()
@@ -64,29 +87,129 @@ def test_autograd_funciton():
assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy())
assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy())

@SKIP_ENV_GRAPH_MODE(reason="Error testing, unnecessary for Graph mode")
def test_autograd_funciton_error():
class Net(Function):
def forward(ctx, x, y):
result = torch.matmul(x, y)
ctx.save_for_backward(x, y)
return result

def backward(ctx, grad_output):
x, y = ctx.saved_tensors
dx = x + 1
dy = y + 1
return dx, dy

x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True)
y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True)
try:
out = Net.apply(x, y)
except Exception as e:
assert "To create a custom autograd.Function" in str(e)
# TODO: backward does not support custom bprop yet.
# @SKIP_ENV_GRAPH_MODE(reason="tensor.backward not support graphmode")
# def test_autograd_function_bprop_backward():
# ms_out, ms_grad_out = adapter_autograd_function_backward()
# pt_out, pt_grad_out = torch_autograd_function()
# assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy())
# assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy())
# assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy())

@SKIP_ENV_GRAPH_MODE(reason="Funtion.apply not support graphmode")
def test_autograd_function_grad():
def adapter_autograd_function_ms_grad():
class Net(ms_torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
result = ms_torch.matmul(x, y)
ctx.save_for_backward(x, y)
return result

@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
dx = x + 1
dy = y + 1
return dx, dy

x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True)
y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True)
def _func(a, b):
out = Net.apply(a, b)
return out
out, grad_out = ms.value_and_grad(_func, grad_position=(0, 1))(x, y)
return out, grad_out
ms_out, ms_grad_out = adapter_autograd_function_ms_grad()
pt_out, pt_grad_out = torch_autograd_function()
assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy())
assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy())
assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy())

@SKIP_ENV_GRAPH_MODE(reason="Funtion.apply not support graphmode")
def test_autograd_function_grad_bias_None():
def ms_torch_func():
class TestFunction(ms_torch.autograd.Function):
@staticmethod
def forward(ctx, input1, input2, bias, has_bias):
result = ms_torch.matmul(input1, input2)
ctx.save_for_backward(result, bias)
ctx.has_bias = has_bias

if has_bias:
result = result + bias

return result.sum()

@staticmethod
def backward(ctx, grad_outputs):
result, bias = ctx.saved_tensors
has_bias = ctx.has_bias
result = grad_outputs * result
if has_bias:
result = result + bias
# TODO: bias gradient auto-reducesum is not supported yet.
# return result + 1, result + 2, result + 3, None
return result + 1, result + 2, (result + 3).sum(dim=1), None

input1 = ms_torch.ones([8, 8])
input2 = ms_torch.ones([8, 8])
bias = ms_torch.ones([8])

input1.requires_grad = True
input2.requires_grad = True
bias.requires_grad = True

def _func(x, y, z):
return TestFunction.apply(x, y, z, True)
grads = ms.grad(_func, grad_position=(0, 1, 2))(input1, input2, bias)
return grads

def torch_fun():
class TestFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input1, input2, bias, has_bias):
result = torch.matmul(input1, input2)
ctx.save_for_backward(result, bias)
ctx.has_bias = has_bias

if has_bias:
result = result + bias

return result.sum()

@staticmethod
def backward(ctx, grad_outputs):
result, bias = ctx.saved_tensors
has_bias = ctx.has_bias
result = grad_outputs * result
if has_bias:
result = result + bias
# return result + 1, result + 2, result + 3, None
return result + 1, result + 2, (result + 3).sum(dim=1), None

input1 = torch.ones([8, 8])
input2 = torch.ones([8, 8])
bias = torch.ones([8])

input1.requires_grad = True
input2.requires_grad = True
bias.requires_grad = True

def _func(x, y, z):
return TestFunction.apply(x, y, z, True)

res = _func(input1, input2, bias)
res.backward()
return input1.grad, input2.grad, bias.grad

ms_grad = ms_torch_func()
torch_grad = torch_fun()
param_compare(ms_grad, torch_grad)

if __name__ == '__main__':
set_mode_by_env_config()
test_autograd_funciton()
test_autograd_funciton_error()
# test_autograd_function_bprop_backward()
test_autograd_function_grad()
test_autograd_function_grad_bias_None()
