#891 autograd_v3

Merged
zoulq merged 1 commit from erpim_0311 into master 1 month ago
Erpim commented 2 months ago
Erpim changed title from autograd_v3 to [WIP]autograd_v3 2 months ago
frelam reviewed 1 month ago
mindtorch/__init__.py
@@ -8,0 +8,4 @@
from mindspore._c_expression import jit_mode_pi_enable

_BACKWARD_ENV = os.environ.get('ENABLE_BACKWARD')
if _BACKWARD_ENV != "0":
frelam commented 1 month ago
If 'ENABLE_BACKWARD' is not configured, _BACKWARD_ENV is None, so `_BACKWARD_ENV != "0"` evaluates to True and the feature is enabled by default.
Erpim commented 1 month ago
The current code enables it by default to make verification easier.
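For context, a minimal sketch of the two check styles under discussion; the explicit opt-in variant is hypothetical, not the merged code:

```python
import os

# As in the hunk above: an unset variable yields None, and
# None != "0" is True, so backward is enabled by default.
if os.environ.get('ENABLE_BACKWARD') != "0":
    backward_enabled = True

# Hypothetical opt-in variant: disabled unless explicitly set to "1".
if os.environ.get('ENABLE_BACKWARD', '0') == '1':
    backward_enabled = True
```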
Erpim changed title from [WIP]autograd_v3 to autograd_v3 1 month ago
zoulq reviewed 1 month ago
@@ -4,3 +4,4 @@
| ----------------- | ------------------------------------------------------------ | ---- | ------------------------------------------------------------ |
| MSA_LOG | Controls the MindTorch log level.<br />(MindSpore's log level can be configured via GLOG_v; see [Log environment variables](https://www.mindspore.cn/docs/zh-CN/r2.1/note/env_var_list.html#%E6%97%A5%E5%BF%97).)<br />Note: this is an experimental environment variable and may be changed or removed later. | Integer | 0: DEBUG<br />1: INFO<br />2: WARNING<br />3: ERROR<br />4: CRITICAL<br />Default: 2. Once a level is specified, log messages at or above that level are output.<br /> |
| ENABLE_FORK_UTILS | Specifies how subprocesses are created. By default, multiprocessing uses spawn; when this variable is configured, fork is used instead.<br />Note: this is an experimental environment variable and may be changed or removed later. | Integer | 0: create subprocesses with spawn;<br />1: create subprocesses with fork;<br />Default: 1<br />Note: on Windows, only spawn is available. |
| ENABLE_BACKWARD | Enables tensor.backward(), allowing users to run automatic differentiation with PyTorch-style usage.<br />Note: this is an experimental environment variable and may be changed or removed later. | Integer | 0: backward disabled;<br />1: backward enabled;<br />Default: 0 |
zoulq commented 1 month ago
Document that when the switch is off, migration should use MindSpore's differentiation scheme, and add a link to it.
Erpim commented 1 month ago
done
zoulq reviewed 1 month ago
@@ -56,3 +56,3 @@

#### Method 2: Using the Tensor.backward-compatible interface (**experimental feature**)
MindTorch is currently developing a `Tensor.backward()`-compatible interface; users do not need to modify the pre-migration torch source code, which makes migration more efficient. Note that this feature is currently experimental and has the following usage constraints:
MindTorch is developing a `Tensor.backward()`-compatible interface; users do not need to modify the pre-migration torch source code, which makes migration more efficient. Note that this feature is currently experimental and must be explicitly enabled by setting the environment variable `export ENABLE_BACKWARD=1`. The following usage constraints currently apply:
zoulq commented 1 month ago
The environment variable configuration can be listed among the constraints.
Erpim commented 1 month ago
done
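A usage sketch of the experimental path described above. The `mindtorch.torch` import path is assumed for illustration; per the `mindtorch/__init__.py` hunk earlier in this review, the switch is read at import time, so it must be set before the import:

```python
import os
# Must be set before mindtorch is imported
# (or via `export ENABLE_BACKWARD=1` in the shell).
os.environ['ENABLE_BACKWARD'] = '1'

import mindtorch.torch as torch  # import alias assumed for illustration

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x * x).sum()
y.backward()       # PyTorch-style call, no source changes needed
print(x.grad)      # expected gradient of sum(x^2): [4.0, 6.0]
```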
zoulq reviewed 1 month ago
mindtorch/torch/common/_inner.py
@@ -57,1 +75,4 @@
output.requires_grad = requires_grad
return cast_to_adapter_tensor(output)

if requires_grad is not None:
zoulq commented 1 month ago
Duplicated code; it can be moved into the shared _out_assign_with_output function. Also, graph mode does not need these operations, so put them under the pynative branch of _out_assign_with_output.
Erpim commented 1 month ago
done
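A rough sketch of the consolidation requested above; `_out_assign_with_output`'s real signature is not shown in this diff, so the shape below is an assumption:

```python
import mindspore as ms

def _out_assign_with_output(out, output, requires_grad=None):
    # Sketch only: the actual signature in mindtorch/torch/common/_inner.py
    # may differ. Graph mode skips the requires_grad handling entirely.
    if ms.get_context('mode') == ms.PYNATIVE_MODE and requires_grad is not None:
        output.requires_grad = requires_grad
    # cast_to_adapter_tensor comes from the surrounding module.
    return cast_to_adapter_tensor(output)
```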
zoulq reviewed 1 month ago
mindtorch/torch/nn/modules/linear.py
@@ -149,3 +149,3 @@
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(0, 0, False)
super().__init__(1, 1, False) # Currently, Parameter does not support containing a zero dimension.
zoulq commented 1 month ago
Does this change affect functionality? The subsequent parameter assignment involves a shape change, and the framework currently has a bug with parameters whose shape changes multiple times.
Erpim commented 1 month ago
This change affects the param's shape before a Lazy-type interface determines the real shape, but the param in the lazy state is generally just a placeholder and is not used directly, so the impact on real scenarios is small. Previously the placeholder was a param backed by an empty tensor that was later converted to a concretely shaped param; now it is a (1, 1)-shaped param converted to a concretely shaped param. Both involve changing the param's shape. The bug with parameters whose shape changes multiple times has not been reproduced so far.
zoulq reviewed 1 month ago
@@ -65,9 +65,13 @@ class Parameter(ms.Parameter):
Parameter, (data, self.requires_grad, self.name, self.layerwise_parallel))

def __init__(self, data, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True):
self.adapter_flag = True
zoulq commented 1 month ago
What is the reason for moving this later? It used to be placed earlier because some graph-mode scenarios need it.
Erpim commented 1 month ago
It was only moved to keep the attribute assignments in one place; no impact on graph mode was identified. It has been reverted.
zoulq reviewed 1 month ago
mindtorch/torch/nn/parameter.py
@@ -68,3 +68,3 @@
self.adapter_flag = True
super().__init__(default_input=data, name=name, requires_grad=requires_grad,
layerwise_parallel=layerwise_parallel, parallel_optimizer=parallel_optimizer)
self._grad = None
zoulq commented 1 month ago
At which layer is parameter.grad defined now?
Erpim commented 1 month ago
parameter.grad inherits the tensor.grad property, but param._grad is defined inside Parameter.
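A simplified sketch of the layering described here; the class names match the discussion, everything else is illustrative (the real Parameter extends ms.Parameter):

```python
class Tensor:
    @property
    def grad(self):
        # The property lives on Tensor; Parameter inherits it unchanged.
        return self._grad

class Parameter(Tensor):
    def __init__(self, data, requires_grad=True):
        self._grad = None  # the backing _grad slot is defined on Parameter
        self.data = data
        self.requires_grad = requires_grad
```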
zoulq reviewed 1 month ago
mindtorch/torch/tensor.py
@@ -264,0 +271,4 @@
self.tensor = input_data
self._grad = input_data._grad
self._grad_fn = input_data._grad_fn
self._requires_grad = input_data._requires_grad
zoulq commented 1 month ago
Is grad now attached to tensor or to stubtensor? Or to both?
Erpim commented 1 month ago
Currently it is attached to both.
zoulq reviewed 1 month ago
mindtorch/torch/tensor.py
@@ -604,1 +621,4 @@

@property
def grad_fn(self):
return self._grad_fn
zoulq commented 1 month ago
If the switch is off, or this interface is called in graph mode, is it expected to return None or to raise an error?
Erpim commented 1 month ago
It is expected to return None.
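That expectation follows if `_grad_fn` defaults to None and is only populated by the experimental backward path; a sketch of that behavior:

```python
class Tensor:
    def __init__(self):
        # Stays None when ENABLE_BACKWARD is off or in graph mode,
        # so grad_fn reads back as None instead of raising.
        self._grad_fn = None

    @property
    def grad_fn(self):
        return self._grad_fn

assert Tensor().grad_fn is None
```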
zoulq reviewed 1 month ago
mindtorch/torch/tensor.py
@@ -605,0 +676,4 @@
_BACKWARD_ENV = os.environ.get('ENABLE_BACKWARD')
if _BACKWARD_ENV != "1":
raise NotImplementedError("If you want to use the `backward` function, please configure the environment "
"variable `export ENABLE_BACKWARD=1` to enable it first.")
zoulq commented 1 month ago
We need to verify the performance impact of this check later; backward is called once per step.
Erpim commented 1 month ago
OK.
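One hypothetical way to keep the per-step cost negligible would be to read the environment once at import time rather than on every backward() call; a sketch, not the merged code:

```python
import os

# Evaluated once at module import instead of once per training step.
_BACKWARD_ENABLED = os.environ.get('ENABLE_BACKWARD') == "1"

def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
    if not _BACKWARD_ENABLED:
        raise NotImplementedError(
            "If you want to use the `backward` function, please configure the "
            "environment variable `export ENABLE_BACKWARD=1` to enable it first.")
    ...
```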
zoulq reviewed 1 month ago
mindtorch/torch/tensor.py
@@ -605,0 +680,4 @@
unsupported_attr(retain_graph)
unsupported_attr(create_graph)
unsupported_attr(inputs)
gradient = cast_to_ms_tensor(gradient)
zoulq commented 1 month ago
If this gradient is handed directly to the framework's C++ layer, the cast conversion here is theoretically unnecessary.
Erpim commented 1 month ago
done
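After the change, the body would simply forward `gradient` untouched; a sketch under the same assumptions (`unsupported_attr` comes from the surrounding module, and the dispatch name is hypothetical):

```python
def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
    unsupported_attr(retain_graph)
    unsupported_attr(create_graph)
    unsupported_attr(inputs)
    # gradient is passed straight to the framework's C++ layer;
    # the earlier cast_to_ms_tensor(gradient) conversion is dropped.
    return self._run_backward(gradient)  # hypothetical dispatch name
```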
zoulq reviewed 1 month ago
testing/ut/pytorch/autograd/test_autograd.py
@@ -0,0 +14,4 @@
output = self.Linear(input)
return output

# TODO: Execute after zero_grad adaptation
zoulq commented 1 month ago
If zero_grad is not supported yet, the support status of the nn.Module.zero_grad interface needs to be updated accordingly.
Erpim commented 1 month ago
That adaptation is in PR #890; this test case can be re-enabled once that PR is merged.
zoulq reviewed 1 month ago
@@ -90,4 +89,0 @@
return input * 2

def ms_func(input1):
with ms_torch.no_grad():
zoulq commented 1 month ago
Does the no_grad feature work properly under the new differentiation scheme?
Erpim commented 1 month ago
Not compatible; currently neither no_grad nor enable_grad can be enabled under the new differentiation scheme, so the test cases are commented out for now.
zoulq reviewed 1 month ago
testing/ut/pytorch/nn/test_sparse.py
@@ -141,3 +100,1 @@
result_torch.backward()

param_compare(grads[0], torch_net.net.weight.grad)
# [CI] ms 2.3 0321 AUTOGRAD: There is an accuracy issue in backward
zoulq commented 1 month ago
Why can the original test case no longer run either?
Erpim commented 1 month ago
It can run; it was commented out in expectation of a quick fix, and has now been reverted.
zoulq reviewed 1 month ago
@@ -293,3 +15,1 @@
param_compare(eval_out0, wrapped_m(input))
param_compare(last_train_u, spectral_norm_m._u)
param_compare(last_train_v, spectral_norm_m._v)
# TODO: Execute after zero_grad adaptation
zoulq commented 1 month ago
Same as above: the new test case is commented out, but why wasn't the original one kept?
Erpim commented 1 month ago
zero_grad has been adapted in #890; it will be re-enabled and verified in that PR.
zoulq reviewed 1 month ago
@@ -296,0 +130,4 @@
# assert_is_orthogonal(m.weight)


# TODO: Execute after zero_grad adaptation
zoulq commented 1 month ago
Why does the original test case no longer pass?
Erpim commented 1 month ago
zero_grad has been adapted in #890; it will be re-enabled and verified in that PR.
zoulq merged commit 2505e34fc8 into master 1 month ago
The pull request has been merged as 2505e34fc8.