#895 update autograd.Function and distribute api

Merged
zoulq merged 37 commits from frelam/MSAdapter:master0319 into master 2 weeks ago
frelam commented 1 month ago
zoulq reviewed 1 month ago
@@ -29,3 +58,3 @@
super(Function, self).__init__()
self.ctx = FunctionCtx()

def apply(self, *args, **kwargs):
zoulq commented 1 month ago
Previously, calling this interface raised an error telling users to use the corresponding MindSpore interface. Now there is no such prompt, but an error is raised at the bprop argument stage instead, which users are unlikely to understand. So the custom operator chapter in the user documentation needs to be updated, and an example should be added to the FAQ.
frelam commented 3 weeks ago
By dynamically generating a function, bprop is now auto-generated during __init__, so the usage can now be the same as in PyTorch.
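For reference, a minimal sketch of the dynamic-bprop pattern described above (the Cell bprop signature and the stub FunctionCtx are illustrative assumptions, not the PR's actual code):

import mindspore as ms

class FunctionCtx:  # trivial stand-in for the real ctx object in function.py
    pass

class Function(ms.nn.Cell):
    def __init__(self):
        super().__init__()
        self.ctx = FunctionCtx()

        def _generated_bprop(*args):
            # MindSpore Cell bprop convention: (*inputs, out, dout);
            # route the incoming gradient dout to the user-defined static backward.
            dout = args[-1]
            return self.backward(self.ctx, dout)

        # attach the generated bprop so it is used for this instance
        self.bprop = _generated_bprop

    def construct(self, *args):
        return self.forward(self.ctx, *args)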
zoulq commented 2 weeks ago
This Cell capability will be optimized later.
zoulq reviewed 1 month ago
@@ -37,0 +70,4 @@
return input_ms, _origin_dtype

# should use before cast_to_adapter_tensor
def _recorver_dtype_on_ascend(output_ms, _origin_dtype):
zoulq commented 1 month ago
In which scenarios are these two dtype conversion interfaces used?
frelam commented 1 month ago
On Ascend, this dtype conversion is used when the dtype of the user's input tensor is not supported by the MindSpore communication operators.
zoulq reviewed 1 month ago
@@ -492,2 +556,3 @@
_bc_op = _get_cache_prim(ms.ops.Broadcast)(src, _group_name)
tensor.data = _bc_op((tensor,))[0]
if get_backend(group) == "hccl":
cast_tensor, _origin_dtype = _check_and_convert_dtype_on_ascend(tensor)
zoulq commented 1 month ago
In which scenario is the dtype conversion needed here? It looks like it affects performance.
frelam commented 3 weeks ago
On Ascend, if the input dtype is not among the types supported by the communication operators (int8, int32, float16, float32, bfloat16), the dtype conversion is triggered; in all other cases it is effectively a pass-through.
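A rough sketch of the check-and-restore flow described here (the supported dtype set is taken from the comment above; the function bodies are illustrative, not the PR's implementation):

import mindspore as ms

# dtypes accepted by the Ascend (HCCL) communication operators, per the comment above
_ASCEND_COMM_DTYPES = (ms.int8, ms.int32, ms.float16, ms.float32, ms.bfloat16)

def _check_and_convert_dtype_on_ascend(input_ms):
    # cast only when the dtype is unsupported; otherwise this is a pass-through
    _origin_dtype = input_ms.dtype
    if _origin_dtype not in _ASCEND_COMM_DTYPES:
        input_ms = input_ms.astype(ms.float32)
    return input_ms, _origin_dtype

def _recorver_dtype_on_ascend(output_ms, _origin_dtype):
    # cast back to the original dtype after the communication op, if it was changed
    if output_ms.dtype != _origin_dtype:
        output_ms = output_ms.astype(_origin_dtype)
    return output_ms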
zoulq reviewed 1 month ago
mindtorch/torch/optim/optimizer.py
@@ -255,6 +255,13 @@ class _Optimizer:
loss = None
if closure is not None:
loss = closure()
if grads is None:
zoulq commented 1 month ago
The parameter is currently mandatory; this check only makes sense if it is also changed to optional at the same time, right?
frelam commented 1 month ago
Yes. This can be fixed together in #890; the change there is correct.
zoulq reviewed 1 month ago
mindtorch/torch/optim/optimizer.py
@@ -258,0 +260,4 @@
for param_group in self.param_groups:
for param in param_group['param']:
_grad = param.grad if param.grad is not None else ms.ops.zeros_like(param)
grads.append(_grad)
zoulq commented 1 month ago
Can this grad-collection step be done ahead of time?
frelam commented 3 weeks ago
It cannot be done ahead of time, because grad is a new tensor object every time, so it has to be re-fetched on each call.
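Illustrative sketch only (not the PR's code): because each backward pass rebinds param.grad to a fresh Tensor, the gradient list has to be rebuilt inside every step() call rather than cached once:

import mindspore as ms

def _collect_grads(param_groups):
    # must run on every step(); a cached list would reference stale grad tensors
    grads = []
    for param_group in param_groups:
        for param in param_group['param']:  # key name follows the diff above
            grads.append(param.grad if param.grad is not None
                         else ms.ops.zeros_like(param))
    return grads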
zoulq reviewed 1 month ago
mindtorch/torch/autograd/function.py
@@ -42,0 +77,4 @@
def backward(ctx, *grad_outputs):
pass

def bprop(self, *args, **kwargs):
zoulq commented 1 month ago
Test cases need to be added.
frelam commented 3 weeks ago
Added. Two known limitations: 1. automatic reduction is not supported when the returned gradient shape differs from the parameter shape; 2. Function.apply is not supported in graph mode.
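A minimal PyTorch-style example along the lines of the added tests (illustrative: the import path and ctx method names are assumed to mirror the PyTorch API, gradient shapes match the input shapes, and gradients are taken with MindSpore's grad in PyNative mode so it stays within the two limitations above):

import mindspore as ms
import mindtorch.torch as ms_torch  # import path assumed

class MyMul(ms_torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # gradient shapes equal the input shapes, so no automatic reduction is needed
        return grad_output * y, grad_output * x

def func(x, y):
    return MyMul.apply(x, y).sum()

x = ms_torch.tensor([1.0, 2.0])
y = ms_torch.tensor([3.0, 4.0])
grads = ms.grad(func, grad_position=(0, 1))(x, y)
print(grads)  # expected: ([3., 4.], [1., 2.])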
frelam reviewed 1 month ago
testing/ut/pytorch/cuda/test_stream.py
@@ -8,6 +8,8 @@ import mindspore as ms
import numpy as np
from mindspore import jit, grad

user_stream1 = ms_torch.cuda.Stream()
frelam commented 1 month ago
Reusing the stream saves device memory in the test cases. Verified that the test cases pass.
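Sketch of the reuse pattern (illustrative test bodies; the import path and stream context manager are assumed to mirror the PyTorch API):

import mindtorch.torch as ms_torch  # import path assumed

# created once at module level and shared by all test cases in the file
user_stream1 = ms_torch.cuda.Stream()

def test_stream_record():
    with ms_torch.cuda.stream(user_stream1):
        a = ms_torch.ones(4) * 2  # work launched on the shared stream

def test_stream_sync():
    with ms_torch.cuda.stream(user_stream1):
        b = ms_torch.zeros(4) + 1
    user_stream1.synchronize()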
frelam reviewed 1 month ago
mindtorch/torch/autograd/function.py
@@ -36,0 +80,4 @@
"your custom autograd.Function to use it with backward "
"mode AD.")

def bprop(self, *args, **kwargs):
frelam commented 1 month ago
Currently bprop does not support variable-length inputs.
zoulq commented 2 weeks ago
An optimization has been planned.
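To make the limitation concrete (illustrative only): a Function whose forward takes a variable number of tensors is the pattern that currently cannot be handled, e.g.:

import mindtorch.torch as ms_torch  # import path assumed

class SumAll(ms_torch.autograd.Function):
    @staticmethod
    def forward(ctx, *tensors):  # variable-length inputs: not yet supported by the generated bprop
        ctx.num_inputs = len(tensors)
        return sum(tensors)

    @staticmethod
    def backward(ctx, grad_output):
        return (grad_output,) * ctx.num_inputs

# SumAll.apply(a, b, c) is the usage that hits this limitation.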
frelam changed title from [WIP]update some api to update some api 3 weeks ago
frelam changed title from update some api to update autograd.Function and distribute api 3 weeks ago
Erpim reviewed 2 weeks ago
@@ -5,0 +25,4 @@
self.dirty_tensors = args

def mark_non_differentiable(self, *args):
self.non_differentiable = args
Erpim commented 2 weeks ago
Do these features actually have no effect?
frelam commented 2 weeks ago
Correct. Warnings or errors have been added according to the impact.
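A rough sketch of the behavior described (the warning text is illustrative, not the PR's exact wording):

import warnings

class FunctionCtx:
    def mark_dirty(self, *args):
        # recorded but not acted upon, so tell the user the call is currently a no-op
        warnings.warn("ctx.mark_dirty() currently has no effect in mindtorch.")
        self.dirty_tensors = args

    def mark_non_differentiable(self, *args):
        warnings.warn("ctx.mark_non_differentiable() currently has no effect in mindtorch.")
        self.non_differentiable = args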
zoulq merged commit 64f60e8971 into master 2 weeks ago