frelam/MSAdapter:master0319
into master
1 month ago
@@ -1,7 +1,42 @@ | |||
#!/usr/bin/env python | |||
# -*- coding: utf-8 -*- | |||
from types import FunctionType, MethodType | |||
from mindtorch.utils import unsupported_attr | |||
from mindtorch.torch.nn import Module | |||
from mindtorch.torch.tensor import Tensor, cast_to_adapter_tensor, cast_to_ms_tensor | |||
from mindtorch.torch.logging import warning | |||
class FunctionCtx: | |||
def save_for_backward(self, *tensors): | |||
self.to_save = tensors | |||
def save_for_forward(self, *tensors): | |||
for tensor in tensors: | |||
if not isinstance(tensor, Tensor) and tensor is not None:
raise TypeError( | |||
"save_for_forward expects all arguments to be tensors; you should " | |||
"save non-tensors as attributes on ctx." | |||
) | |||
self.saved_for_forward = tensors | |||
def mark_dirty(self, *args): | |||
warning("ctx.mark_dirty do not actually take effect now.") | |||
self.dirty_tensors = args | |||
def mark_non_differentiable(self, *args): | |||
raise NotImplementedError("ctx.mark_non_differentiable not support yet.") | |||
def set_materialize_grads(self, value): | |||
if not value: | |||
warning("ctx.set_materialize_grads(false) not actually take effect now.") | |||
self.materialize_grads = value | |||
@property | |||
def saved_tensors(self): | |||
return self.to_save | |||
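For reference, a minimal sketch of the ctx round trip, using only the imports at the top of this file (the tensor value is hypothetical, and Function normally manages the ctx object itself):

ctx = FunctionCtx()
t = Tensor([1.0, 2.0])        # hypothetical mindtorch tensor
ctx.save_for_backward(t)      # stashes the tensors on ctx.to_save
(saved,) = ctx.saved_tensors  # the property returns exactly what was saved
assert saved is t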
class Function(Module): | |||
@@ -27,15 +62,65 @@ class Function(Module): | |||
unsupported_attr(args) | |||
unsupported_attr(kwargs) | |||
super(Function, self).__init__() | |||
self.ctx = FunctionCtx() | |||
def compile_bprop(num_input): | |||
input_args = "" | |||
for i in range(num_input): | |||
if i < num_input - 1: | |||
input_args += f"input{i}," | |||
else: | |||
input_args += f"input{i}" | |||
input_args += ",out,dout" | |||
input_args_with_self = "self," + input_args | |||
code = f"""def bprop({input_args_with_self}): return self._backward_wrapper({input_args})""" | |||
code = compile(code, "<string>", "exec") | |||
func = FunctionType(code.co_consts[0], globals(), "bprop") | |||
return func | |||
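For clarity, with num_input = 2 the function compiled above is equivalent to the following (reconstructed from the f-string; nothing here is new API):

def bprop(self, input0, input1, out, dout):
    return self._backward_wrapper(input0, input1, out, dout)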
def apply(self, *args, **kwargs): | |||
zoulq commented 2 months ago
Review
Previously, using this interface raised an error telling the user to use the corresponding MindSpore interface. Now there is no such prompt, but an error is raised at the bprop argument site, which users probably won't understand. So the custom-operator chapter in the user documentation should be updated, and an example should be added to the FAQ.
frelam commented 1 month ago
Review
By generating the function dynamically, bprop is now created automatically in the __init__ phase, so the current usage is the same as PyTorch.
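For example, a definition in the usual PyTorch style, the same pattern as the tests added below, now works without modification (a sketch; x and y are mindtorch tensors):

class Net(ms_torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return ms_torch.matmul(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return x + 1, y + 1

out = Net.apply(x, y)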
zoulq commented 1 month ago
Review
This Cell feature will be optimized later.
""" | |||
# Do not use this by calling the apply method.
""" | |||
unsupported_attr(args) | |||
if not hasattr(self, 'bprop'): | |||
# num_input excludes the ctx argument, hence the minus 1.
num_input = self.forward.__code__.co_argcount - 1 | |||
self.bprop = MethodType(compile_bprop(num_input), self) | |||
@staticmethod | |||
def forward(ctx, *args, **kwargs): | |||
raise NotImplementedError("You must implement the forward function for custom" | |||
" autograd.Function.") | |||
@classmethod | |||
def apply(cls, *args, **kwargs): | |||
obj = cls() | |||
return obj(*args, **kwargs) | |||
def construct(self, *args, **kwargs): | |||
return self.forward(self.ctx, *args, **kwargs) | |||
def _run_construct(self, cast_inputs, kwargs): | |||
return self.forward(self.ctx, *cast_inputs, **kwargs) | |||
@staticmethod | |||
def backward(ctx, *grad_outputs): | |||
raise NotImplementedError("You must implement either the backward method for " | |||
"your custom autograd.Function to use it with backward " | |||
"mode AD.") | |||
def _backward_wrapper(self, *args, **kwargs): | |||
unsupported_attr(kwargs) | |||
raise RuntimeError("To create a custom autograd.Function, please use 'def forward(self, ...)' and " | |||
"'def bprop(self, ..., out, dout)' instead of 'forward()' and 'backward()' static methods. " | |||
"Then, use it as normal module class, do not call the class method 'apply'." | |||
"Please refer to the following example: https://openi.pcl.ac.cn/OpenI/MSAdapter/src/" | |||
"branch/master/doc/torch/USER_GUIDE.md#user-content-4-2-1-%E8%87%AA%E5%AE%9A%E4%B9%89module") | |||
# The previous node may be a MindSpore bprop node, so grad_outputs may arrive as MindSpore Tensors,
# while the backward computation treats them as MindTorch Tensors.
# "cast_to_adapter_tensor" therefore ensures self.backward receives MindTorch Tensors.
grad_outputs = cast_to_adapter_tensor(args[-1]) | |||
if isinstance(grad_outputs, (list, tuple)): | |||
res = self.backward(self.ctx, *grad_outputs) | |||
else: | |||
res = self.backward(self.ctx, grad_outputs) | |||
# The next node may be a MindSpore bprop node, so "cast_to_ms_tensor" ensures
# it receives a MindSpore Tensor.
if res is None: | |||
res = 0 | |||
elif isinstance(res, (list, tuple)): | |||
res = tuple(0 if x is None else cast_to_ms_tensor(x) for x in res) | |||
else: | |||
res = cast_to_ms_tensor(res) | |||
return res |
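End to end, a custom Function is currently differentiated through ms.value_and_grad rather than tensor.backward (see the TODO in the tests below). A minimal self-contained sketch; the Scale class is hypothetical and mirrors the pattern of the new tests:

import mindspore as ms
import mindtorch.torch as ms_torch

class Scale(ms_torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # saved tensors come back as MindTorch tensors
        return grad_output * 2.0  # d(2x)/dx = 2

x = ms_torch.tensor([1.0, 2.0, 3.0])
out, grad = ms.value_and_grad(lambda a: Scale.apply(a).sum())(x)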
@@ -1,2 +1,8 @@ | |||
from .distributed_c10d import * | |||
from ._distributed_c10d import * | |||
from .distributed_c10d import ( | |||
_backend, | |||
_all_gather_base, | |||
_reduce_scatter_base, | |||
_rank_not_in_group, | |||
) |
@@ -7,7 +7,9 @@ from mindspore.communication._comm_helper import (_is_available, _is_initialized | |||
import mindspore as ms | |||
from mindspore.ops._primitive_cache import _get_cache_prim | |||
from mindtorch.utils import unsupported_attr, graph_mode_condition | |||
from mindtorch.torch.common.dtype import int8, int32, float16, float32, bfloat16, all_complex_type | |||
from mindtorch.utils import unsupported_attr, graph_mode_condition, is_under_ascend_context, \ | |||
is_under_gpu_context | |||
from mindtorch.torch.logging import warning | |||
from mindtorch.torch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor, Tensor | |||
from mindtorch.torch.distributed._distributed_c10d import ( # pylint: disable=W0611 | |||
@@ -34,6 +36,68 @@ from mindtorch.torch.distributed._distributed_c10d import ( # pylint: disable=W0 | |||
_group_count | |||
) | |||
_ascend_support_dtype = (int8, int32, float16, float32, bfloat16) | |||
_ascend_dtype_convert_map = { | |||
'bool': int8, | |||
'int64': int32, | |||
'float64': float32, | |||
} | |||
# Should be used after cast_to_ms_tensor.
def _check_and_convert_dtype_on_ascend(input_ms): | |||
def _convert(tensor): | |||
_origin_dtype = tensor.dtype | |||
if _origin_dtype in all_complex_type: | |||
raise TypeError("Not support communication of complex tensor yet.") | |||
if _origin_dtype not in _ascend_support_dtype: | |||
tensor = tensor.astype( | |||
_ascend_dtype_convert_map[str(_origin_dtype).split('.')[-1]] | |||
) | |||
else: | |||
_origin_dtype = None | |||
return tensor, _origin_dtype | |||
if isinstance(input_ms, Tensor): | |||
return _convert(input_ms) | |||
elif isinstance(input_ms, tuple): | |||
input_ms = list(input_ms) | |||
elif not isinstance(input_ms, list): | |||
raise TypeError("input_ms must be type of Tensor, tuple or list") | |||
_origin_dtype_list = [] | |||
for i, tensor in enumerate(input_ms): | |||
converted_tensor, origin_dtype = _convert(tensor) | |||
input_ms[i] = converted_tensor | |||
_origin_dtype_list.append(origin_dtype) | |||
zoulq commented 2 months ago
Review
In what scenarios are these two dtype-conversion helpers used?
frelam commented 1 month ago
Review
On Ascend, the conversion is used when the dtype of the user's input tensor is not supported by the MindSpore communication operators.
return input_ms, _origin_dtype_list | |||
# Should be used before cast_to_adapter_tensor.
def _recover_dtype_on_ascend(output_ms, _origin_dtype):
def _recover(tensor, origin_dtype): | |||
if origin_dtype is not None: | |||
tensor = tensor.astype(origin_dtype) | |||
return tensor | |||
if isinstance(output_ms, Tensor): | |||
return _recover(output_ms, _origin_dtype) | |||
elif isinstance(output_ms, tuple): | |||
output_ms = list(output_ms) | |||
elif not isinstance(output_ms, list): | |||
raise TypeError(f"output_ms must be type of Tensor, tuple or list, but got {type(output_ms)}.") | |||
if not isinstance(_origin_dtype, list): | |||
raise TypeError(f"_origin_dtype must be type of list, but got {type(_origin_dtype)}.") | |||
if len(output_ms) != len(_origin_dtype): | |||
raise ValueError("length of output_ms not equal to _origin_dtype") | |||
for i, (output, dtype) in enumerate(zip(output_ms, _origin_dtype)): | |||
if dtype is not None: | |||
output_ms[i] = _recover(output, dtype) | |||
return output_ms | |||
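Together, the two helpers bracket a communication primitive on Ascend. A sketch of the intended round trip; bool_tensor and comm_op are hypothetical placeholders, not API from this diff:

input_ms = cast_to_ms_tensor(bool_tensor)                    # hypothetical input
cast, origin = _check_and_convert_dtype_on_ascend(input_ms)  # bool travels as int8
result = comm_op(cast)                                       # hypothetical primitive
result = _recover_dtype_on_ascend(result, origin)            # dtype restored to bool
output = cast_to_adapter_tensor(result)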
BACKEND_DEVICE_TARGET_DICT = { | |||
'mccl': 'CPU', | |||
@@ -51,6 +115,8 @@ def _make_nccl_premul_sum(factor): | |||
raise NotImplementedError('distributed._make_nccl_premul_sum not support yet.' | |||
'Please manually scale the tensor before reduce.') | |||
_stub_work = Work() | |||
class Backend: | |||
UNDEFINED = "undefined" | |||
HCCL = "hccl" | |||
@@ -75,6 +141,9 @@ class Backend: | |||
raise NotImplementedError("For distributed.Backend, register_backend has not been supported yet. " | |||
"For now, only 'hccl', 'nccl', 'mccl' are supported.") | |||
_backend = Backend.UNDEFINED | |||
dist_backend = Backend | |||
# jit_class for _World to support graph mode | |||
@ms.jit_class | |||
class _World: | |||
@@ -225,6 +294,11 @@ def init_process_group( | |||
init() | |||
else: | |||
backend = Backend(backend) | |||
# Automatically switch between the Ascend and GPU backends so user code does not need to change.
if backend == 'nccl' and is_under_ascend_context(): | |||
backend = 'hccl' | |||
elif backend == 'hccl' and is_under_gpu_context(): | |||
backend = 'nccl' | |||
init(backend) | |||
name = GlobalComm.WORLD_COMM_GROUP | |||
@@ -256,15 +330,12 @@ def new_group(ranks=None, | |||
timeout=None, | |||
backend=None, | |||
pg_options=None): | |||
unsupported_attr(backend) | |||
global _world | |||
if timeout is not None: | |||
raise NotImplementedError("distributed.new_group: timeout is not supported") | |||
if backend is not None: | |||
raise NotImplementedError("distributed.new_group: backend is not supported. " | |||
"Heterogeneity in a single process is not supported.") | |||
if pg_options is not None: | |||
raise NotImplementedError("distributed.new_group: pg_options is not supported." | |||
"Heterogeneity in a single process is not supported.") | |||
@@ -299,8 +370,8 @@ def new_group(ranks=None, | |||
pg = ProcessGroup(name=name) | |||
create_group(name, ranks) | |||
if not backend: | |||
backend = get_backend() | |||
# TODO: once the `backend` arg is supported, use 'backend' rather than get_backend().
backend = get_backend() | |||
_world.pg_names[pg] = name | |||
_world.pg_map[pg] = (backend,) | |||
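Usage stays torch-like; a sketch assuming a two-process job launched as in the tests below (t is a hypothetical mindtorch tensor):

import mindtorch.torch.distributed as dist

dist.init_process_group('nccl', world_size=2)  # transparently becomes 'hccl' on Ascend
pg = dist.new_group(ranks=[0, 1])              # timeout and pg_options must stay None
dist.all_reduce(t, group=pg)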
@@ -392,6 +463,9 @@ def is_nccl_available(): | |||
def is_hccl_available(): | |||
return _is_hccl_available() | |||
def is_gloo_available(): | |||
return False | |||
def get_process_group_ranks(group): | |||
if _rank_not_in_group(group): | |||
_warn_not_in_group("get_process_group_ranks") | |||
@@ -425,7 +499,7 @@ def get_global_rank(group, group_rank): | |||
def all_reduce_not_inplace(tensor, op=ReduceOp.SUM, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.all_reduce not support async yet.") | |||
warning("all_reduce: 'async_op' not actually supported now. Run as sync op") | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
@@ -446,7 +520,7 @@ def all_reduce_not_inplace(tensor, op=ReduceOp.SUM, group=None, async_op=False): | |||
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.all_reduce not support async yet.") | |||
warning("all_reduce: 'async_op' not actually supported now. Run as sync op") | |||
_inplace_raise_error_graph_mode('all_reduce', 'all_reduce_not_inplace') | |||
_check_single_tensor(tensor, "tensor") | |||
@@ -454,7 +528,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("all_reduce") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -465,12 +539,21 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): | |||
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op) | |||
else: | |||
_reduce_op = _get_cache_prim(ms.ops.AllReduce)(op, _group_name) | |||
result = _reduce_op(tensor) | |||
if get_backend(group) == "hccl": | |||
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor) | |||
result = _reduce_op(cast_tensor) | |||
result = _recover_dtype_on_ascend(result, _origin_dtype)
else: | |||
result = _reduce_op(tensor) | |||
tensor.data = result | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
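With async_op=True the collective still runs synchronously, but the returned stub keeps torch-style code working; a sketch with a hypothetical tensor t:

work = dist.all_reduce(t, async_op=True)
work.wait()  # returns immediately; the reduce already completed synchronously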
def broadcast(tensor, src, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.broadcast not support async yet.") | |||
warning("broadcast: 'async_op' not actually supported now. Run as sync op") | |||
_inplace_raise_error_graph_mode('broadcast', 'broadcast_not_inplace') | |||
@@ -479,7 +562,7 @@ def broadcast(tensor, src, group=None, async_op=False): | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("broadcast") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -490,11 +573,21 @@ def broadcast(tensor, src, group=None, async_op=False): | |||
else: | |||
src = get_group_rank(group, src) | |||
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src, _group_name) | |||
tensor.data = _bc_op((tensor,))[0] | |||
if get_backend(group) == "hccl": | |||
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor) | |||
result = _bc_op((cast_tensor,))[0] | |||
result = _recover_dtype_on_ascend(result, _origin_dtype)
else: | |||
result = _bc_op((tensor,))[0] | |||
tensor.data = result | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
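On hccl, broadcast now also transparently casts unsupported dtypes and restores them afterwards; condensed from broadcast_impl_ascend_cast_and_async_impl.py below:

_tmp = data.to(torch.bool)                     # bool is sent as int8 on Ascend
work = dist.broadcast(_tmp, 0, async_op=True)
work.wait()
assert _tmp.dtype == torch.bool                # original dtype is restored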
def all_gather_into_tensor_not_inplace(input_tensor, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.broadcast not support async yet.") | |||
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op") | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
@@ -515,7 +608,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal | |||
_inplace_raise_error_graph_mode('all_gather_into_tensor', 'all_gather_into_tensor_not_inplace') | |||
if async_op: | |||
raise NotImplementedError("distributed.broadcast not support async yet.") | |||
warning("all_gather_into_tensor: 'async_op' not actually supported now. Run as sync op") | |||
_check_single_tensor(input_tensor, "input_tensor") | |||
_check_single_tensor(output_tensor, "output_tensor") | |||
@@ -523,7 +616,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("broadcast") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -548,6 +641,10 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal | |||
result = ms.ops.split(result, split_size) | |||
result = ms.ops.stack(result) | |||
output_tensor.data = result | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False): | |||
return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) | |||
@@ -569,14 +666,18 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("all_gather") | |||
return | |||
return None | |||
result = all_gather_not_inplace(tensor, group, async_op) | |||
for i, _tensor in enumerate(tensor_list): | |||
_tensor.data = result[i] | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def barrier(group=None, async_op=False, device_ids=None): | |||
if async_op: | |||
raise NotImplementedError("distributed.barrier not support async yet.") | |||
warning("barrier: 'async_op' not actually supported now. Run as sync op") | |||
if device_ids: | |||
raise NotImplementedError("distributed.barrier not support device_ids yet.") | |||
@@ -584,7 +685,7 @@ def barrier(group=None, async_op=False, device_ids=None): | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("barrier") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -595,15 +696,19 @@ def barrier(group=None, async_op=False, device_ids=None): | |||
else: | |||
_barrier_op = _get_cache_prim(ms.ops.operations._inner_ops.Barrier)(_group_name) | |||
_barrier_op() | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.barrier not support async yet.") | |||
warning("all_to_all: 'async_op' not actually supported now. Run as sync op") | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("all_to_all_single") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -626,6 +731,10 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False | |||
out = ms.ops.split(out, _spilit_size) | |||
for i, output in enumerate(output_tensor_list): | |||
output.data = out[i] | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def all_to_all_single( | |||
output, | |||
@@ -636,7 +745,7 @@ def all_to_all_single( | |||
async_op=False, | |||
): | |||
if async_op: | |||
raise NotImplementedError("distributed.barrier not support async yet.") | |||
warning("all_to_all_single: 'async_op' not actually supported now. Run as sync op") | |||
if output_split_sizes is not None or input_split_sizes is not None: | |||
raise NotImplementedError("all_to_all_single not support output_split_sizes and input_split_sizes now.") | |||
@@ -644,7 +753,7 @@ def all_to_all_single( | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("all_to_all_single") | |||
return | |||
return None | |||
_split_count = input.shape[0] | |||
_split_dim = 0 | |||
@@ -662,17 +771,21 @@ def all_to_all_single( | |||
input_ms = cast_to_ms_tensor(input) | |||
out = _op(input_ms) | |||
output.data = out | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.barrier not support async yet.") | |||
warning("reduce: 'async_op' not actually supported now. Run as sync op") | |||
_check_single_tensor(tensor, "tensor") | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("reduce") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -690,6 +803,10 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): | |||
out = _reduce_op(tensor) | |||
if dst == get_rank(): | |||
tensor.data = out | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def send(tensor, dst, group=None, tag=0): | |||
if get_rank() == dst: | |||
@@ -753,7 +870,7 @@ def recv(tensor, src=None, group=None, tag=0): | |||
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.barrier not support async yet.") | |||
warning("reduce_scatter: 'async_op' not actually supported now. Run as sync op") | |||
_check_single_tensor(output, "output") | |||
_check_tensor_list(input_list, "input_list") | |||
@@ -761,7 +878,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("reduce_scatter") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -776,10 +893,43 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal | |||
input_ms = ms.ops.concat(input_list) | |||
out = _op(input_ms) | |||
output.data = out | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False): | |||
if async_op: | |||
warning("reduce_scatter_tensor: 'async_op' not actually supported now. Run as sync op") | |||
_check_single_tensor(output, "output") | |||
_check_single_tensor(input, "input") | |||
if _rank_not_in_group(group): | |||
# Graph mode not support code below. | |||
# _warn_not_in_group("reduce_scatter_tensor") | |||
return None | |||
_group_name = _get_pg_name(group) | |||
if _group_name == GlobalComm.WORLD_COMM_GROUP: | |||
_op = _get_cache_prim(ms.ops.ReduceScatter)(op) | |||
else: | |||
_op = _get_cache_prim(ms.ops.ReduceScatter)(op, _group_name) | |||
input_ms = cast_to_ms_tensor(input) | |||
out = _op(input_ms) | |||
output.data = out | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False): | |||
return reduce_scatter_tensor(output, input, op, group, async_op) | |||
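Example usage matching the new reduce_scatter_tensor test below (two processes, each rank keeping one reduced element):

rank = dist.get_rank()
tensor = torch.arange(2, dtype=torch.float32) + 2 * rank  # rank 0: [0, 1]; rank 1: [2, 3]
output = torch.empty(1, dtype=torch.float32)
dist.reduce_scatter_tensor(output, tensor)                # rank 0: [2.]; rank 1: [4.]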
def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.gather not support async yet.") | |||
warning("gather: 'async_op' not actually supported now. Run as sync op") | |||
_check_single_tensor(tensor, "tensor") | |||
@@ -790,7 +940,7 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): | |||
if _rank_not_in_group(group): | |||
# _warn_not_in_group("gather") | |||
return | |||
return None | |||
my_rank = get_rank() | |||
_validate_output_list_for_rank(my_rank, dst, gather_list) | |||
@@ -812,10 +962,14 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): | |||
out = ms.ops.split(out, _spilit_size) | |||
for i, output in enumerate(gather_list): | |||
output.data = out[i] | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): | |||
if async_op: | |||
raise NotImplementedError("distributed.gather not support async yet.") | |||
warning("scatter: 'async_op' not actually supported now. Run as sync op") | |||
_check_single_tensor(tensor, "tensor") | |||
@@ -827,7 +981,7 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): | |||
if _rank_not_in_group(group): | |||
# _warn_not_in_group("scatter") | |||
return | |||
return None | |||
if group is None: | |||
group = _get_default_group() | |||
@@ -858,6 +1012,10 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): | |||
input_ms = ms.ops.zeros((group_size,) + tensor.shape, dtype=tensor.dtype) | |||
out = _op(input_ms)[0] | |||
tensor.data = out | |||
if async_op: | |||
return _stub_work | |||
else: | |||
return None | |||
@contextlib.contextmanager | |||
def _coalescing_manager(group, device, reqs): | |||
@@ -0,0 +1,37 @@ | |||
import sys | |||
import numpy as np | |||
import mindtorch.torch as torch | |||
import mindtorch.torch.distributed as dist | |||
def func(backend): | |||
_total_device = 2 | |||
dist.init_process_group(backend, world_size=_total_device) | |||
rank = dist.get_rank() | |||
w = torch.tensor([[1, 2., 3, 4.], [5, 6, 7, 8]]) | |||
x = torch.tensor([2., 2, 2, 2]) | |||
w_device = w[:, w.size(1) // _total_device * rank : w.size(1) // _total_device * (rank + 1)].to(rank) | |||
x_device = x[x.size(0) // _total_device * rank : x.size(0) // _total_device * (rank + 1)].to(rank) | |||
result = torch.matmul(w_device, x_device) | |||
_enum_dtype = {torch.bool, torch.int64} | |||
for dtype in _enum_dtype: | |||
_tmp = result.to(dtype) | |||
dist.all_reduce(_tmp) | |||
assert _tmp.dtype == dtype | |||
if dtype == torch.bool: | |||
assert np.allclose(_tmp.cpu().numpy(), np.array([True, True])) | |||
else: | |||
expected_result = np.array([20, 52.]) | |||
assert np.allclose(_tmp.cpu().numpy(), expected_result) | |||
assert _tmp.shape == expected_result.shape | |||
if __name__ == '__main__': | |||
backend = sys.argv[1] | |||
func(backend) |
@@ -0,0 +1,32 @@ | |||
import sys | |||
import numpy as np | |||
import mindtorch.torch as torch | |||
import mindtorch.torch.distributed as dist | |||
def func(backend): | |||
dist.init_process_group(backend) | |||
rank = dist.get_rank() | |||
if rank == 0: | |||
data = torch.tensor([1, 2.]).to(f'cuda:{rank}') | |||
else: | |||
data = torch.zeros(2).to(f'cuda:{rank}') | |||
_enum_dtype = {torch.bool, torch.int64} | |||
for dtype in _enum_dtype: | |||
_tmp = data.to(dtype) | |||
work = dist.broadcast(_tmp, 0, async_op=True) | |||
work.wait() | |||
assert _tmp.dtype == dtype | |||
if dtype == torch.bool: | |||
assert np.allclose(_tmp.cpu().numpy(), np.array([True, True])) | |||
else: | |||
assert np.allclose(_tmp.cpu().numpy(), np.array([1, 2.])) | |||
assert data.shape == (2,) | |||
if __name__ == '__main__': | |||
backend = sys.argv[1] | |||
func(backend) |
@@ -3,16 +3,17 @@ import numpy as np | |||
import mindspore as ms | |||
from mindspore import nn | |||
import mindtorch.torch as torch | |||
from mindtorch.torch.nn.parallel import DistributedDataParallel as DDP | |||
from mindtorch.torch import distributed as dist | |||
class NetWork(nn.Cell): | |||
class NetWork(torch.nn.Module): | |||
def __init__(self): | |||
super(NetWork, self).__init__() | |||
self.dense = nn.Dense(3, 3) | |||
self.dense = torch.nn.Linear(3, 3) | |||
def construct(self, x): | |||
def forward(self, x): | |||
return self.dense(x).sum() | |||
@@ -22,7 +23,7 @@ def ddp_basic(): | |||
dist.init_process_group(backend='hccl', rank=-1, world_size=2) | |||
network = NetWork() | |||
opt = nn.Adam(network.trainable_params()) | |||
opt = torch.optim.Adam(network.parameters()) | |||
grad_fn = ms.value_and_grad(network, None, opt.parameters, has_aux=False) | |||
rank = dist.get_rank() | |||
@@ -34,7 +35,7 @@ def ddp_basic(): | |||
else: | |||
network_p = None | |||
inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32)) | |||
inputs = torch.tensor(np.random.random((2, 3)).astype(np.float32)) | |||
for _ in range(1): | |||
loss, grads = grad_fn(inputs) | |||
grads = network.all_reduce(grads) | |||
@@ -3,16 +3,17 @@ import numpy as np | |||
import mindspore as ms | |||
from mindspore import nn | |||
import mindtorch.torch as torch | |||
from mindtorch.torch.nn.parallel import DistributedDataParallel as DDP | |||
from mindtorch.torch import distributed as dist | |||
class NetWork(nn.Cell): | |||
class NetWork(torch.nn.Module): | |||
def __init__(self): | |||
super(NetWork, self).__init__() | |||
self.dense = nn.Dense(3, 3) | |||
self.dense = torch.nn.Linear(3, 3) | |||
def construct(self, x): | |||
def forward(self, x): | |||
return self.dense(x).sum() | |||
@@ -22,7 +23,7 @@ def ddp_basic(): | |||
dist.init_process_group(backend='nccl', rank=-1, world_size=3) | |||
network = NetWork() | |||
opt = nn.Adam(network.trainable_params()) | |||
opt = torch.optim.Adam(network.parameters()) | |||
grad_fn = ms.value_and_grad(network, None, opt.parameters, has_aux=False) | |||
rank = dist.get_rank() | |||
@@ -34,7 +35,7 @@ def ddp_basic(): | |||
else: | |||
network_p = None | |||
inputs = ms.Tensor(np.random.random((2, 3)).astype(np.float32)) | |||
inputs = torch.tensor(np.random.random((2, 3)).astype(np.float32)) | |||
for _ in range(1): | |||
loss, grads = grad_fn(inputs) | |||
grads = network.all_reduce(grads) | |||
@@ -0,0 +1,27 @@ | |||
import sys | |||
import numpy as np | |||
import mindtorch.torch as torch | |||
import mindtorch.torch.distributed as dist | |||
def func(backend): | |||
dist.init_process_group(backend) | |||
rank = dist.get_rank() | |||
# mindspore reduce_scatter does not support int64
tensor = torch.arange(2, dtype=torch.float32).to(f'cuda:{rank}') + 2 * rank | |||
output = torch.empty(1, dtype=torch.float32).to(f'cuda:{rank}') | |||
dist.reduce_scatter_tensor(output, tensor) | |||
if rank == 0: | |||
assert np.allclose(output.cpu().numpy(), np.array([2]).astype(np.float32)) | |||
assert output.shape == (1,) | |||
else: | |||
assert np.allclose(output.cpu().numpy(), np.array([4]).astype(np.float32)) | |||
assert output.shape == (1,) | |||
if __name__ == '__main__': | |||
backend = sys.argv[1] | |||
func(backend) |
@@ -339,3 +339,36 @@ def test_allgather_grad(): | |||
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True) | |||
except subprocess.CalledProcessError as e: | |||
raise Exception(e.output.decode()) | |||
@SKIP_ENV_GRAPH_MODE(reason='mindtorch distributed does not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU') | |||
def test_reduce_scatter_tensor(): | |||
cur_dir = os.path.abspath(os.path.dirname(__file__)) | |||
cmd = 'mpirun --allow-run-as-root -n 2 ' | |||
cmd += 'python {}/reduce_scatter_tensor_impl.py {}'.format(cur_dir, backend) | |||
try: | |||
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True) | |||
except subprocess.CalledProcessError as e: | |||
raise Exception(e.output.decode()) | |||
@SKIP_ENV_GRAPH_MODE(reason='mindtorch distributed does not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU') | |||
def test_broadcast_cast_async(): | |||
cur_dir = os.path.abspath(os.path.dirname(__file__)) | |||
cmd = 'mpirun --allow-run-as-root -n 2 ' | |||
cmd += 'python {}/broadcast_impl_ascend_cast_and_async_impl.py {}'.format(cur_dir, backend) | |||
try: | |||
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True) | |||
except subprocess.CalledProcessError as e: | |||
raise Exception(e.output.decode()) | |||
@SKIP_ENV_GRAPH_MODE(reason='mindtorch distributed does not support graph mode yet.')
@SKIP_ENV_CPU(reason='distribute op not supported on CPU') | |||
def test_allreduce_dtype(): | |||
cur_dir = os.path.abspath(os.path.dirname(__file__)) | |||
cmd = 'mpirun --allow-run-as-root -n 2 ' | |||
cmd += 'python {}/all_reduce_dtype_impl.py {}'.format(cur_dir, backend) | |||
try: | |||
subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True) | |||
except subprocess.CalledProcessError as e: | |||
raise Exception(e.output.decode()) |
@@ -7,7 +7,7 @@ from mindspore import context | |||
import mindtorch.torch as ms_torch | |||
from ...utils import set_mode_by_env_config | |||
from mindtorch.torch.autograd import Function | |||
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE | |||
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, enable_backward, param_compare | |||
set_mode_by_env_config() | |||
@@ -56,6 +56,29 @@ def torch_autograd_function(): | |||
grad_out = x.grad, y.grad | |||
return out, grad_out | |||
# TODO: backward does not support custom bprop yet.
# def adapter_autograd_function_backward(): | |||
# class Net(ms_torch.autograd.Function): | |||
# @staticmethod | |||
# def forward(ctx, x, y): | |||
# result = ms_torch.matmul(x, y) | |||
# ctx.save_for_backward(x, y) | |||
# return result | |||
# @staticmethod | |||
# def backward(ctx, grad_output): | |||
# x, y = ctx.saved_tensors | |||
# dx = x + 1 | |||
# dy = y + 1 | |||
# return dx, dy | |||
# x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True) | |||
# y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True) | |||
# with enable_backward(): | |||
# out = Net.apply(x, y) | |||
# out.backward() | |||
# grad_out = x.grad, y.grad | |||
# return out, grad_out | |||
def test_autograd_function():
ms_out, ms_grad_out = adapter_autograd_function() | |||
@@ -64,29 +87,129 @@ def test_autograd_funciton(): | |||
assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy()) | |||
assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy()) | |||
@SKIP_ENV_GRAPH_MODE(reason="Error testing, unnecessary for Graph mode") | |||
def test_autograd_function_error():
class Net(Function): | |||
def forward(ctx, x, y): | |||
result = torch.matmul(x, y) | |||
ctx.save_for_backward(x, y) | |||
return result | |||
def backward(ctx, grad_output): | |||
x, y = ctx.saved_tensors | |||
dx = x + 1 | |||
dy = y + 1 | |||
return dx, dy | |||
x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True) | |||
y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True) | |||
try: | |||
out = Net.apply(x, y) | |||
except Exception as e: | |||
assert "To create a custom autograd.Function" in str(e) | |||
# TODO: backward does not support custom bprop yet.
# @SKIP_ENV_GRAPH_MODE(reason="tensor.backward not support graphmode") | |||
# def test_autograd_function_bprop_backward(): | |||
# ms_out, ms_grad_out = adapter_autograd_function_backward() | |||
# pt_out, pt_grad_out = torch_autograd_function() | |||
# assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy()) | |||
# assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy()) | |||
# assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy()) | |||
@SKIP_ENV_GRAPH_MODE(reason="Funtion.apply not support graphmode") | |||
def test_autograd_function_grad(): | |||
def adapter_autograd_function_ms_grad(): | |||
class Net(ms_torch.autograd.Function): | |||
@staticmethod | |||
def forward(ctx, x, y): | |||
result = ms_torch.matmul(x, y) | |||
ctx.save_for_backward(x, y) | |||
return result | |||
@staticmethod | |||
def backward(ctx, grad_output): | |||
x, y = ctx.saved_tensors | |||
dx = x + 1 | |||
dy = y + 1 | |||
return dx, dy | |||
x = ms_torch.tensor([[0.5, 0.6, 0.4]], dtype=ms_torch.float32, requires_grad=True) | |||
y = ms_torch.tensor([[0.01], [0.2], [3.3]], dtype=ms_torch.float32, requires_grad=True) | |||
def _func(a, b): | |||
out = Net.apply(a, b) | |||
return out | |||
out, grad_out = ms.value_and_grad(_func, grad_position=(0, 1))(x, y) | |||
return out, grad_out | |||
ms_out, ms_grad_out = adapter_autograd_function_ms_grad() | |||
pt_out, pt_grad_out = torch_autograd_function() | |||
assert np.allclose(ms_out.asnumpy(), pt_out.detach().numpy()) | |||
assert np.allclose(ms_grad_out[0].asnumpy(), pt_grad_out[0].numpy()) | |||
assert np.allclose(ms_grad_out[1].asnumpy(), pt_grad_out[1].numpy()) | |||
@SKIP_ENV_GRAPH_MODE(reason="Funtion.apply not support graphmode") | |||
def test_autograd_function_grad_bias_None(): | |||
def ms_torch_func(): | |||
class TestFunction(ms_torch.autograd.Function): | |||
@staticmethod | |||
def forward(ctx, input1, input2, bias, has_bias): | |||
result = ms_torch.matmul(input1, input2) | |||
ctx.save_for_backward(result, bias) | |||
ctx.has_bias = has_bias | |||
if has_bias: | |||
result = result + bias | |||
return result.sum() | |||
@staticmethod | |||
def backward(ctx, grad_outputs): | |||
result, bias = ctx.saved_tensors | |||
has_bias = ctx.has_bias | |||
result = grad_outputs * result | |||
if has_bias: | |||
result = result + bias | |||
# TODO: bias gradient auto-reducesum is not supported.
# return result + 1, result + 2, result + 3, None | |||
return result + 1, result + 2, (result + 3).sum(dim=1), None | |||
input1 = ms_torch.ones([8, 8]) | |||
input2 = ms_torch.ones([8, 8]) | |||
bias = ms_torch.ones([8]) | |||
input1.requires_grad = True | |||
input2.requires_grad = True | |||
bias.requires_grad = True | |||
def _func(x, y, z): | |||
return TestFunction.apply(x, y, z, True) | |||
grads = ms.grad(_func, grad_position=(0, 1, 2))(input1, input2, bias) | |||
return grads | |||
def torch_fun(): | |||
class TestFunction(torch.autograd.Function): | |||
@staticmethod | |||
def forward(ctx, input1, input2, bias, has_bias): | |||
result = torch.matmul(input1, input2) | |||
ctx.save_for_backward(result, bias) | |||
ctx.has_bias = has_bias | |||
if has_bias: | |||
result = result + bias | |||
return result.sum() | |||
@staticmethod | |||
def backward(ctx, grad_outputs): | |||
result, bias = ctx.saved_tensors | |||
has_bias = ctx.has_bias | |||
result = grad_outputs * result | |||
if has_bias: | |||
result = result + bias | |||
# return result + 1, result + 2, result + 3, None | |||
return result + 1, result + 2, (result + 3).sum(dim=1), None | |||
input1 = torch.ones([8, 8]) | |||
input2 = torch.ones([8, 8]) | |||
bias = torch.ones([8]) | |||
input1.requires_grad = True | |||
input2.requires_grad = True | |||
bias.requires_grad = True | |||
def _func(x, y, z): | |||
return TestFunction.apply(x, y, z, True) | |||
res = _func(input1, input2, bias) | |||
res.backward() | |||
return input1.grad, input2.grad, bias.grad | |||
ms_grad = ms_torch_func() | |||
torch_grad = torch_fun() | |||
param_compare(ms_grad, torch_grad) | |||
if __name__ == '__main__': | |||
set_mode_by_env_config() | |||
test_autograd_function()
test_autograd_function_error()
# test_autograd_function_bprop_backward() | |||
test_autograd_function_grad() | |||
test_autograd_function_grad_bias_None() |
So these features don't actually take effect?
Yes. Depending on the impact, a warning or an error was added for each.