#890 adapter to new autograd

Merged
Erpim merged 45 commits from frelam/MSAdapter:master0311 into master 1 month ago
  1. mindtorch/torch/cuda/amp/grad_scaler.py  +45 -15
  2. mindtorch/torch/distributed/utils.py  +39 -0
  3. mindtorch/torch/nn/modules/module.py  +14 -2
  4. mindtorch/torch/nn/parallel/distributed.py  +29 -7
  5. mindtorch/torch/nn/utils/clip_grad.py  +35 -4
  6. mindtorch/torch/optim/optimizer.py  +25 -9
  7. testing/ut/pytorch/amp/test_clip_grad.py  +61 -4
  8. testing/ut/pytorch/amp/test_grad_scaler.py  +99 -2
  9. testing/ut/pytorch/cuda/test_stream.py  +22 -20
  10. testing/ut/pytorch/optim/test_optim.py  +208 -1

+45 -15  mindtorch/torch/cuda/amp/grad_scaler.py

@@ -4,8 +4,9 @@ from enum import Enum
import mindspore as ms
from mindspore.amp import DynamicLossScaler, all_finite
import mindspore.ops as ops
from mindspore.common import mutable
from mindtorch.torch.nn.parameter import Parameter
from mindtorch.torch.tensor import tensor
from mindtorch.torch.tensor import tensor, cast_to_ms_tensor
from mindtorch.torch.common.dtype import float32, int32
from mindtorch.torch.logging import warning
from mindtorch.utils import graph_mode_condition
@@ -22,7 +23,7 @@ def _assign(x1, x2):
return x1.assign_value(x2)

_hypermap = ops.HyperMap()
_partial = ops.Partial()
class GradScaler(DynamicLossScaler):
def __init__(self,
init_scale=2.**16,
@@ -61,26 +62,46 @@ class GradScaler(DynamicLossScaler):
def _check_inf(self, grads):
return {'all': ms.ops.logical_not(all_finite(grads))}

def _loss_scale(self, scale, loss):
return loss * scale.astype(loss.dtype)

def _loss_scale_map(self, scale_value, inputs):
return _hypermap(_partial(self._loss_scale, scale_value), inputs)

def scale(self, outputs):
if not self._enabled:
return outputs
return DynamicLossScaler.scale(self, outputs)
frelam commented 1 month ago
Review
MindSpore's DynamicLossScaler.scale is implemented with jit, but the new autograd cannot run under jit, so a jit-free version has to be reimplemented here.
outputs = mutable(outputs)
return self._loss_scale_map(self.scale_value, outputs)

def unscale_(self, optimizer, grads):
def unscale_(self, optimizer, grads=None):
if not self._enabled:
return

if graph_mode_condition():
raise RuntimeError("Under graph mode, GradScalar not support unscale_(), please use unscale(). "
"Example: change 'scaler.unscale_(optimizer)' to "
"'grads = scaler.unscale(optimizer, grads)'")

optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError("unscale_() has already been called on this optimizer since the last update().")
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")

if grads is None:
grads = [cast_to_ms_tensor(p.grad) for p in optimizer.parameters if p.grad is not None]
if len(grads) == 0:
return
grads = tuple(grads)
optimizer_state['found_inf_per_device'] = self._check_inf(grads)
new_grads = DynamicLossScaler.unscale(self, grads)
for i, p in enumerate(optimizer.parameters):
if p.grad is not None:
p.grad = new_grads[i]
return

optimizer_state['found_inf_per_device'] = self._check_inf(grads)
if graph_mode_condition():
raise RuntimeError("Under graph mode, GradScalar not support unscale_(), please use unscale(). "
frelam commented 1 month ago
Review
Moved this error check earlier.
"Example: change 'scaler.unscale_(optimizer, grads)' to "
"'grads = scaler.unscale(optimizer, grads)'")
_hypermap(_assign, grads, DynamicLossScaler.unscale(self, grads))
optimizer_state["stage"] = OptState.UNSCALED

@@ -93,15 +114,15 @@ class GradScaler(DynamicLossScaler):
optimizer_state["stage"] = OptState.UNSCALED
return DynamicLossScaler.unscale(self, grads)

def _maybe_opt_step(self, optimizer, grads, optimizer_state, *args, **kwargs):
frelam commented 1 month ago
Review
Internal interface: grads is moved to the back and detected through *args. After the move the signature matches the PyTorch definition.
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
if not sum(v.asnumpy().tolist() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(grads, *args, **kwargs)
retval = optimizer.step(*args, **kwargs)
return retval

def step(self, optimizer, grads, *args, **kwargs):
def step(self, optimizer, *args, **kwargs):
if not self._enabled:
return optimizer.step(grads)
return optimizer.step(*args, **kwargs)

if "closure" in kwargs:
raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
@@ -128,7 +149,7 @@ class GradScaler(DynamicLossScaler):
found_inf = optimizer_state["found_inf_per_device"]
optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler
optimizer.found_inf = found_inf
retval = optimizer.step(grads, *args, **kwargs_)
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
@@ -136,9 +157,15 @@ class GradScaler(DynamicLossScaler):
return retval

if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer, grads)
# Check whether grads were passed in.
if len(args) > 0 and isinstance(args[0], tuple) and \
zoulq commented 1 month ago
Review
What problem would keeping the original grads parameter cause?
frelam commented 1 month ago
Review
If the call is step(optimizer, grad) there is no impact; it stays compatible, and this code detects whether grads were passed in. If the call is step(optimizer, grads=grad), it would break after this change.
len(args[0]) > 0 and isinstance(args[0][0], ms.Tensor):
grads = args[0]
self.unscale_(optimizer, grads)
else:
self.unscale_(optimizer)

retval = self._maybe_opt_step(optimizer, grads, optimizer_state, *args, **kwargs)
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

optimizer_state["stage"] = OptState.STEPPED
return retval
@@ -175,6 +202,9 @@ class GradScaler(DynamicLossScaler):
found_infs = [found_inf
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()]
if len(found_infs) == 0:
raise ValueError("No inf checks were recorded prior to update."
"Maybe no grad has been unscaled in 'unscale_' process.")
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
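
With scale() reimplemented without jit and unscale_()/step() made grads-optional, the scaler can be driven the same way as PyTorch's once gradients land on p.grad. A minimal sketch of that flow, assuming PYNATIVE mode and that backward() is enabled (the UTs wrap it in enable_backward()); the model and optimizer here are only illustrative:

```python
import mindtorch.torch as torch  # MSAdapter's torch-compatible namespace

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(init_scale=2.**8)

loss = model(torch.ones(2, 4)).sum()
scaler.scale(loss).backward()   # scaled loss -> scaled p.grad on every parameter
scaler.unscale_(optimizer)      # grads omitted: unscale p.grad in place, record inf check
scaler.step(optimizer)          # optimizer.step() is skipped if an inf/nan was found
scaler.update()
```

The old positional form scaler.step(optimizer, grads) is still recognized through the isinstance check on args[0]; only the keyword form grads=... is no longer supported, as discussed in the review thread above.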


+39 -0  mindtorch/torch/distributed/utils.py

@@ -0,0 +1,39 @@
from mindtorch.torch.distributed.distributed_c10d import broadcast
from mindtorch.utils import unsupported_attr

def _sync_module_states(
module,
process_group,
broadcast_bucket_size,
src,
params_and_buffers_to_ignore,
):
module_states = []
for name, param in module.named_parameters():
if name not in params_and_buffers_to_ignore:
module_states.append(param)

for name, buffer in module.named_buffers():
if name not in params_and_buffers_to_ignore:
module_states.append(buffer)

_sync_params_and_buffers(
process_group,
module_states,
broadcast_bucket_size,
src
)

def _sync_params_and_buffers(
process_group,
module_states,
broadcast_bucket_size,
src,
):
unsupported_attr(broadcast_bucket_size)
if len(module_states) > 0:
for state in module_states:
_state = state.detach()
_state_dtype = state.dtype
broadcast(_state.float(), src, process_group)
state.assign_value(_state.astype(_state_dtype))
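
This new helper mirrors torch.distributed.utils._sync_module_states: rank src broadcasts every parameter and buffer that is not in the ignore set. A hedged sketch of a direct call, matching how distributed.py below uses it; it assumes the process group has already been initialized elsewhere:

```python
import mindtorch.torch as torch
from mindtorch.torch.distributed import distributed_c10d as c10d
from mindtorch.torch.distributed.utils import _sync_module_states

model = torch.nn.Linear(4, 4)
_sync_module_states(
    module=model,
    process_group=c10d._get_default_group(),
    broadcast_bucket_size=250 * 1024 * 1024,  # accepted but unused: no bucketing yet
    src=0,                                    # rank 0 is the source of truth
    params_and_buffers_to_ignore=set(),
)
```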

+14 -2  mindtorch/torch/nn/modules/module.py

@@ -12,7 +12,7 @@ from mindspore import Tensor as ms_Tensor
from mindtorch.torch.overrides import is_tensor_like
from mindtorch.torch.tensor import Tensor, _dtypeDict, cast_to_ms_tensor
from mindtorch.torch.nn.parameter import Parameter
from mindtorch.utils import unsupported_attr
from mindtorch.utils import unsupported_attr, graph_mode_condition
from mindtorch.torch.types import device as device_class
from mindtorch.torch.functional import empty_like
from mindtorch.torch.logging import warning
@@ -1038,4 +1038,16 @@ class Module(Cell):
return sorted(keys)

def zero_grad(self, set_to_none=True):
unsupported_attr(set_to_none)
if graph_mode_condition():
return

for p in self.parameters():
if p.grad is not None:
zoulq commented 1 month ago
Review
If the new autograd scheme is not used and the user attaches grad themselves, will the code below raise an error?
frelam commented 1 month ago
Review
No. That scenario is covered by the test case "test_sgd_step_no_grads".
if set_to_none:
p.grad = None
else:
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
p.grad.assign_value(ms.ops.zeros_like(p.grad))
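
With this change Module.zero_grad behaves like PyTorch's: by default it drops the .grad references, and with set_to_none=False it zeroes them in place (under graph mode it returns early). A minimal sketch, assuming the usual torch-compatible tensor API; grads are attached by hand here, as in the scenario discussed above:

```python
import mindtorch.torch as torch

m = torch.nn.Linear(3, 1)
for p in m.parameters():
    p.grad = torch.ones_like(p)   # e.g. set by backward() or manually

m.zero_grad()                     # default set_to_none=True: p.grad becomes None
assert all(p.grad is None for p in m.parameters())

for p in m.parameters():
    p.grad = torch.ones_like(p)
m.zero_grad(set_to_none=False)    # keep the grad tensors, fill them with zeros
```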

+29 -7  mindtorch/torch/nn/parallel/distributed.py

@@ -4,7 +4,8 @@ from mindspore import ops
import mindspore as ms

from mindtorch.torch import distributed as dist
from mindtorch.torch.distributed.distributed_c10d import _get_pg_name
from mindtorch.torch.distributed.distributed_c10d import _get_pg_name, all_reduce, _get_default_group
from mindtorch.torch.distributed.utils import _sync_module_states
from mindtorch.torch.nn.modules.module import Module
from mindtorch.utils import unsupported_attr
from mindtorch.torch.logging import warning
@@ -28,18 +29,34 @@ class DistributedDataParallel(Module):
ms.set_auto_parallel_context(comm_fusion={"allreduce": {"mode": "size", "config": bucket_cap_mb}})
if ms.get_context('mode') == ms.PYNATIVE_MODE:
warning("`bucket_cap_mb` takes effect only in graph mode.")

self.network = module
device_num = dist.get_world_size(process_group)
pg_name = GlobalComm.WORLD_COMM_GROUP if process_group is None else _get_pg_name(process_group)
self.grad_reducer = DistributedGradReducer(module.trainable_params(), degree=device_num, group=pg_name)
self.broadcast = ops.Broadcast(0, pg_name)

self.module = module
if process_group is None:
self.process_group = _get_default_group()
else:
self.process_group = process_group

self.modules_buffers = list()
self.broadcast_buffers = broadcast_buffers
if broadcast_buffers:
for buffer in module.buffers():
self.modules_buffers.append(buffer)
self.broadcast = ops.Broadcast(0, pg_name)

self.broadcast_bucket_size = int(250 * 1024 * 1024)
Erpim commented 1 month ago
Review
How was this value derived? Is it tied to the hardware environment, or is there a common industry default?
frelam commented 1 month ago
Review
There is a common industry default; PyTorch uses the same default value.
frelam commented 1 month ago
Review
The bucketing feature is not active yet, so this parameter currently has no real effect. Bucketed communication needs further investigation later.
# TODO: not support 'parameters_to_ignore' now, because it is used by 'delay_all_reduce_named_params',
# but 'delay_all_reduce_named_params' relies on Parameter's hook, which is not support yet.
self.parameters_to_ignore = set()
_sync_module_states(
frelam commented 1 month ago
Review
Same as PyTorch: added synchronization of buffers and params across devices. Verified on ResNet50; the final accuracy improves and is close to PyTorch's after 90 epochs.
module=self.module,
process_group=self.process_group,
broadcast_bucket_size=self.broadcast_bucket_size,
src=0,
params_and_buffers_to_ignore=self.parameters_to_ignore,
)

unsupported_attr(device_ids)
unsupported_attr(output_device)
@@ -62,10 +79,15 @@ class DistributedDataParallel(Module):
def forward(self, *inputs, **kwargs):
if self.will_sync_module_buffers():
self._sync_buffers()
return self.network(*inputs, **kwargs)
return self.module(*inputs, **kwargs)
frelam commented 1 month ago
Review
Renamed the member variable to match PyTorch.

def all_reduce(self, grads):
grads = self.grad_reducer(grads)
def all_reduce(self, grads=None):
if grads is None:
for p in self.module.parameters():
if p.grad is not None:
all_reduce(p.grad, group=self.process_group)
else:
grads = self.grad_reducer(grads)
return grads

def gather(self, outputs, output_device):


+35 -4  mindtorch/torch/nn/utils/clip_grad.py

@@ -2,7 +2,8 @@ import mindspore as ms
from mindspore.ops.function.clip_func import get_square_sum, apply_global_norm
from mindspore import _checkparam as Validator
from mindtorch.utils import unsupported_attr, graph_mode_condition
from mindtorch.torch.tensor import cast_to_adapter_tensor
from mindtorch.torch.tensor import cast_to_adapter_tensor, Tensor, cast_to_ms_tensor, tensor
from mindtorch.torch.nn.parameter import Parameter

__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value', 'clip_grad_value_']

@@ -50,12 +51,29 @@ def clip_grad_norm(parameters, max_norm, grads, norm_type=2.0, error_if_nonfinit
new_grads, total_norm = _ClipByGlobalNorm(max_norm, None)(grads)
return new_grads, cast_to_adapter_tensor(total_norm)

def clip_grad_norm_(parameters, max_norm, grads, norm_type=2.0, error_if_nonfinite=False, foreach=None):
def clip_grad_norm_(parameters, max_norm, norm_type=2.0,
error_if_nonfinite=False, foreach=None, grads=None):
if graph_mode_condition():
raise RuntimeError("Under graph mode, adapter not support in-place operation. "
"So please use 'clip_grad_norm' to replace 'clip_grad_norm_'")

new_grads, total_norm = clip_grad_norm(parameters, max_norm, grads, norm_type, error_if_nonfinite, foreach)
if isinstance(parameters, (Tensor, Parameter)):
parameters = [parameters]
if grads is None:
_param = list(parameters)
grads = [p.grad for p in _param if p.grad is not None]
if len(grads) == 0:
return tensor(0.)
grads = cast_to_ms_tensor(grads)
grads = tuple(grads)
new_grads, total_norm = clip_grad_norm(parameters, max_norm, grads, norm_type,
error_if_nonfinite, foreach)
for i, p in enumerate(_param):
p.grad = new_grads[i]
return total_norm

new_grads, total_norm = clip_grad_norm(parameters, max_norm, grads, norm_type,
error_if_nonfinite, foreach)
_hypermap(_assign, grads, new_grads)
return total_norm

@@ -72,10 +90,23 @@ def clip_grad_value(parameters, clip_value, grads, foreach=None):
grads = ms.ops.clip_by_value(grads, _clip_value_min, _clip_value_max)
return grads

def clip_grad_value_(parameters, clip_value, grads, foreach=None):
def clip_grad_value_(parameters, clip_value, foreach=None, grads=None):
if graph_mode_condition():
raise RuntimeError("Under graph mode, adapter not support in-place operation. "
"So please use 'clip_grad_value' to replace 'clip_grad_value_'")

if isinstance(parameters, (Tensor, Parameter)):
parameters = [parameters]
if grads is None:
_param = list(parameters)
grads = [p.grad for p in _param if p.grad is not None]
if len(grads) == 0:
return
grads = cast_to_ms_tensor(grads)
grads = tuple(grads)
new_grads = clip_grad_value(parameters, clip_value, grads, foreach)
for i, p in enumerate(_param):
p.grad = new_grads[i]
return
new_grads = clip_grad_value(parameters, clip_value, grads, foreach)
_hypermap(_assign, grads, new_grads)
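
Summarizing the new signatures: clip_grad_norm_(parameters, max_norm, norm_type=2.0, ..., grads=None) and clip_grad_value_(parameters, clip_value, foreach=None, grads=None) now default to reading and rewriting p.grad, and only work in PYNATIVE mode. A small sketch, with grads attached by hand for illustration:

```python
import mindtorch.torch as torch

m = torch.nn.Linear(4, 4)
for p in m.parameters():
    p.grad = torch.ones_like(p) * 10.

# In-place variants read p.grad, clip, and write the result back to p.grad.
total_norm = torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_value_(m.parameters(), clip_value=0.5)
```

The previous explicit-grads form is still reachable through the grads= keyword, which is what the updated UTs in test_clip_grad.py below use for the non-autograd comparison path.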

+25 -9  mindtorch/torch/optim/optimizer.py

@@ -6,7 +6,7 @@ from itertools import chain
import mindspore as ms
from mindspore.experimental.optim import Optimizer as Optimizer_MS
from mindtorch.torch.tensor import Tensor, tensor, cast_to_ms_tensor
from mindtorch.utils import unsupported_attr
from mindtorch.utils import unsupported_attr, graph_mode_condition

def _warn_differentiable(differentiable):
if differentiable:
@@ -251,21 +251,37 @@ class _Optimizer:
else:
_load_from_pt(ms_params, name)

def step(self, grads, closure=None):
def step(self, grads=None, closure=None):
loss = None
if closure is not None:
loss = closure()
if grads is None:
grads = [param.grad if param.grad is not None
else ms.ops.zeros_like(param) for param in self.parameters]
# Has to turn 'grads' to tuple type before sending to 'construct'
# Otherwise, it will cause recompiling every step, which will lead to poor performance.
grads = tuple(grads)
ret = self.construct(grads)
if closure is not None:
ret = loss
return ret

def zero_grad(self):
raise NotImplementedError("'zero_grad' not support yet because of different autograd mechanism "
"between MindSpore and PyTorch. Actually we usually don't need to "
"call 'zero_grad' in MindTorch, because 'mindspore.grad' or 'value_and_grad' always "
"return the new grad without accumulation, so there is no need to clear "
"the grad.")
def zero_grad(self, set_to_none=True):
if graph_mode_condition():
return

for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
if set_to_none:
p.grad = None
else:
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
p.grad.assign_value(ms.ops.zeros_like(p.grad))


class _OptimizerMeta(abc.ABCMeta, type(Optimizer_MS)):
"""
@@ -287,7 +303,7 @@ class Optimizer(_Optimizer, Optimizer_MS, metaclass=_OptimizerMeta):
return True
return NotImplemented

def step(self, grads, closure=None):
def step(self, grads=None, closure=None):
raise NotImplementedError

def _is_tensor(obj):
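
The optimizer now accepts both calling conventions: the legacy step(grads) with an explicit tuple, and the PyTorch-style step(), which collects param.grad (substituting zeros for missing grads), tuples the result to avoid recompiling, and passes it to construct(); zero_grad() gained the set_to_none flag and is a no-op under graph mode. A minimal sketch mirroring test_sgd_step_no_grads below:

```python
import mindtorch.torch as torch

w = torch.nn.Parameter(torch.tensor([1., 2.]))
opt = torch.optim.SGD([w], lr=0.1)

w.grad = torch.tensor([3., 4.])   # set by backward() under the new autograd, or by hand
opt.step()                        # grads omitted: collected from param.grad
opt.zero_grad()                   # set_to_none=True by default -> w.grad is None

w.grad = torch.tensor([3., 4.])
opt.zero_grad(set_to_none=False)  # keep the grad tensor, fill it with zeros
```

The legacy call opt.step(grads) with a tuple of MindSpore tensors keeps working, since grads is still the first positional argument.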


+61 -4  testing/ut/pytorch/amp/test_clip_grad.py

@@ -1,7 +1,7 @@
import mindspore as ms
import torch

from ...utils import param_compare, SKIP_ENV_GRAPH_MODE
from ...utils import param_compare, SKIP_ENV_GRAPH_MODE, enable_backward
import mindtorch.torch as ms_torch

@@ -12,7 +12,7 @@ def test_clip_grad_norm_():

l = ms_torch.nn.Linear(2, 2)
ms_grads = ms.ops.arange(1., 5).view(2, 2), ms.ops.ones(2).div(1000)
ms_total_norm = ms_torch.nn.utils.clip_grad_norm_(l.parameters(), max_norm, ms_grads,
ms_total_norm = ms_torch.nn.utils.clip_grad_norm_(l.parameters(), max_norm, grads=ms_grads,
norm_type=norm_type)

l = torch.nn.Linear(2, 2)
@@ -38,7 +38,7 @@ def test_clip_grad_value_():
def test_case(value):
l = ms_torch.nn.Linear(10, 10)
ms_grads = ms.ops.arange(-50., 50).view(10, 10).div(5), ms.ops.ones(10).mul(2)
ms_torch.nn.utils.clip_grad_value_(l.parameters(), value, ms_grads)
ms_torch.nn.utils.clip_grad_value_(l.parameters(), value, grads=ms_grads)

l = torch.nn.Linear(10, 10)
grads = torch.arange(-50., 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
@@ -53,7 +53,64 @@ def test_clip_grad_value_():
for value in [2.5, -2.5]:
test_case(value)

@SKIP_ENV_GRAPH_MODE(reason="clip_grad_norm_ not support graph mode.")
def test_clip_grad_norm_autograd():
with enable_backward():
max_norm = 2
norm_type = 2.0

l = ms_torch.nn.Linear(2, 2)
ms_grads = ms.ops.arange(1., 5).view(2, 2), ms.ops.ones(2).div(1000)
for p, g in zip(l.parameters(), ms_grads):
p.grad = g
ms_total_norm = ms_torch.nn.utils.clip_grad_norm_(l.parameters(), max_norm,
norm_type=norm_type)
_ms_param = list(l.parameters())
ms_grad_norm1 = ms_torch.norm(ms_torch.cast_to_adapter_tensor(_ms_param[0].grad))
ms_grad_norm2 = ms_torch.norm(ms_torch.cast_to_adapter_tensor(_ms_param[1].grad))

l = torch.nn.Linear(2, 2)
grads = torch.arange(1., 5).view(2, 2), torch.ones(2).div(1000)
for p, g in zip(l.parameters(), grads):
p.grad = g.clone().view_as(p.data)
pt_total_norm = torch.nn.utils.clip_grad_norm_(l.parameters(), max_norm,
norm_type=norm_type)
_param = list(l.parameters())
pt_grad_norm1 = torch.norm(_param[0].grad)
pt_grad_norm2 = torch.norm(_param[1].grad)

param_compare(ms_total_norm, pt_total_norm)
param_compare(ms_grad_norm1, pt_grad_norm1)
param_compare(ms_grad_norm2, pt_grad_norm2)
param_compare(_ms_param[0].grad, _param[0].grad)
param_compare(_ms_param[1].grad, _param[1].grad)

@SKIP_ENV_GRAPH_MODE(reason="clip_grad_norm_ not support graph mode.")
def test_clip_grad_value_autograd():
def test_case(value):
l = ms_torch.nn.Linear(10, 10)
ms_grads = ms.ops.arange(-50., 50).view(10, 10).div(5), ms.ops.ones(10).mul(2)
for p, g in zip(l.parameters(), ms_grads):
p.grad = g
ms_torch.nn.utils.clip_grad_value_(l.parameters(), value)
ms_param = list(l.parameters())

l = torch.nn.Linear(10, 10)
grads = torch.arange(-50., 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
for p, g in zip(l.parameters(), grads):
p.grad = g.clone().view_as(p.data)
torch.nn.utils.clip_grad_value_(l.parameters(), value)
_param = list(l.parameters())

param_compare(ms_param[0].grad, _param[0].grad)
param_compare(ms_param[1].grad, _param[1].grad)

with enable_backward():
for value in [2.5, -2.5]:
test_case(value)

if __name__ == '__main__':
test_clip_grad_norm_()
test_clip_grad_value_()
test_clip_grad_norm_autograd()
test_clip_grad_value_autograd()

+99 -2  testing/ut/pytorch/amp/test_grad_scaler.py

@@ -5,7 +5,7 @@ import mindspore as ms
from mindspore import context
import mindtorch.torch as ms_torch
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, param_compare, \
SKIP_ENV_CPU, SKIP_ENV_ASCEND
SKIP_ENV_CPU, SKIP_ENV_ASCEND, enable_backward

set_mode_by_env_config()

@@ -292,6 +292,102 @@ def test_gradscaler_disable():
ms_scaler.update()
assert pt_scaler.get_scale() == ms_scaler.get_scale()

# @SKIP_ENV_CPU(reason="torch only support GradScaler on GPU.")
@SKIP_ENV_ASCEND(reason="torch only support GradScaler on GPU.")
@SKIP_ENV_GRAPH_MODE(reason="unscale_() not support in GraphMode")
def test_grad_scalar_autograd():
_inputs = np.random.randn(3, 3).astype(np.float32)
_target = 2.0

def torch_scaler():
class Model(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.a = torch.nn.Parameter(torch.tensor(2.0).to(torch.float32))

def forward(self, inputs):
return (inputs * self.a).sum()

class Cri(torch.nn.Module):
def forward(self, out, target):
return out - target

model = Model().cuda()
# model = Pt_Model()
critirion = Cri()

inputs = torch.tensor(_inputs).to("cuda")
target = torch.tensor(_target).to(torch.float32).to("cuda")
# inputs = torch.tensor(_inputs)
# target = torch.tensor(_target).to(torch.float32)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scaler = torch.cuda.amp.GradScaler(init_scale=2.**8, growth_factor=1.6, growth_interval=1)
# with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
with torch.autocast(device_type="cuda", dtype=torch.float16):
out = model(inputs)
loss = critirion(out, target)

scaler.scale(loss).backward()
scaler.unscale_(optimizer) # unscale the gradients
scaler.step(optimizer) # optimizer.step()
scaler.update() # update the scaler
pt_result = model.a.cpu().detach()
pt_scale = scaler.get_scale()
return pt_result, pt_scale

# adapter
def ms_scaler():
class Model(ms_torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.a = ms_torch.nn.Parameter(ms_torch.tensor(2.0).to(ms_torch.float32))

def forward(self, inputs):
return (inputs * self.a).sum()

class Cri(ms_torch.nn.Module):
def forward(self, out, target):
return out - target

model = Model()
critirion = Cri()

inputs = ms_torch.tensor(_inputs)
target = ms_torch.tensor(_target).to(ms_torch.float32)
optimizer = ms_torch.optim.SGD(model.parameters(), lr=0.1)

scaler = ms_torch.cuda.amp.GradScaler(init_scale=2.**8, growth_factor=1.6, growth_interval=1)
class Net(ms_torch.nn.Module):
def __init__(self, model, critirion):
super().__init__()
self.model = model
self.critirion = critirion

def forward(self, inputs, target):
out = self.model(inputs)
loss = self.critirion(out, target)
return loss
net = Net(model, critirion)
net = ms.amp.auto_mixed_precision(net)
loss = net(inputs, target)

scaler.scale(loss).backward()
net.model.a.grad = ms.Tensor(500000.)
scaler.unscale_(optimizer) # unscale the gradients
scaler.step(optimizer) # optimizer.step()
scaler.update() # update the scaler

ms_result = model.a.detach()
ms_scale = scaler.get_scale()
return ms_result, ms_scale

with enable_backward():
#pt_result, pt_scale = torch_scaler()
ms_result, ms_scale = ms_scaler()

# param_compare(pt_result, ms_result)
#assert pt_scale == ms_scale

if __name__ == '__main__':
test_grad_scalar()
@@ -302,4 +398,5 @@ if __name__ == '__main__':
test_grad_inf_not_step()
test_grad_scalar()
test_one_gradscaler_two_optimizer()
test_gradscaler_disable()
test_grad_scalar_autograd()

+22 -20  testing/ut/pytorch/cuda/test_stream.py

@@ -8,6 +8,8 @@ import mindspore as ms
import numpy as np
from mindspore import jit, grad

user_stream1 = ms_torch.cuda.Stream()
user_stream2 = ms_torch.cuda.Stream()

from ...utils import set_mode_by_env_config, SKIP_ENV_CPU
set_mode_by_env_config()
@@ -31,7 +33,7 @@ def test_cuda_Stream_repr():
@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_Stream_eq():
default_stream = ms_torch.cuda.current_stream()
user_stream = ms_torch.cuda.Stream()
user_stream = user_stream1
assert ms_torch.cuda.current_stream() == default_stream
assert default_stream != user_stream
with ms_torch.cuda.stream(user_stream):
@@ -40,8 +42,8 @@ def test_cuda_Stream_eq():

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_Stream_hash():
stream1 = ms_torch.cuda.Stream()
stream2 = ms_torch.cuda.Stream()
stream1 = user_stream1
stream2 = user_stream2
stream3 = ms_torch.cuda.Stream(stream=stream2)

assert stream1 != stream2
@@ -53,7 +55,7 @@ def test_cuda_Stream_hash():
def test_cuda_Stream_query():
a = ms_torch.ones(1024, 2048, dtype=ms_torch.float32, device="cuda")
b = ms_torch.ones(2048, 4096, dtype=ms_torch.float32, device="cuda")
stream1 = ms_torch.cuda.Stream()
stream1 = user_stream1

with ms_torch.cuda.stream(stream1):
ms_torch.mm(a, b).to('cuda')
@@ -64,7 +66,7 @@ def test_cuda_Stream_query():

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_Stream_record_event():
stream1 = ms_torch.cuda.Stream()
stream1 = user_stream1
curr_stream = ms_torch.cuda.current_stream()
event = ms_torch.cuda.Event()

@@ -80,7 +82,7 @@ def test_cuda_Stream_record_event():

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_Stream_synchronize():
stream1 = ms_torch.cuda.Stream()
stream1 = user_stream1

# A with large shape to ensure it run for a long time and will be synchronized by stream1.
A = ms_torch.rand(5000, 5000, device="cuda")
@@ -92,8 +94,8 @@ def test_cuda_Stream_synchronize():

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_Stream_wait_event():
stream1 = ms_torch.cuda.Stream()
stream2 = ms_torch.cuda.Stream()
stream1 = user_stream1
stream2 = user_stream2
event = ms_torch.cuda.Event()

# A with large shape to ensure it run for a long time and will be synchronized by stream2.
@@ -113,8 +115,8 @@ def test_cuda_Stream_wait_event():

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_Stream_wait_stream():
stream1 = ms_torch.cuda.Stream()
stream2 = ms_torch.cuda.Stream()
stream1 = user_stream1
stream2 = user_stream2

A = ms_torch.rand(5000, 5000, device="cuda")
with ms_torch.cuda.stream(stream1):
@@ -130,7 +132,7 @@ def test_cuda_Stream_wait_stream():

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_jit_stream():
stream1 = ms_torch.cuda.Stream()
stream1 = user_stream1
event = ms_torch.cuda.Event()
a = ms_torch.ones([1, 2], dtype=ms_torch.float32, device="cuda")
b = ms_torch.ones([2], dtype=ms_torch.float32, device="cuda")
@@ -154,7 +156,7 @@ def test_cuda_grad_stream():
grad_fn = grad(func)

a = ms_torch.tensor([0.62, 0.29, 0.45, 0.38], dtype=ms_torch.float32, device="cuda")
stream1 = ms_torch.cuda.Stream()
stream1 = user_stream1
event = ms_torch.cuda.Event()
a *= 4
event.record()
@@ -168,8 +170,8 @@ def test_cuda_grad_stream():
@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_data_dependency_between_streams():
prev_curr_stream = ms_torch.cuda.current_stream()
stream1 = ms_torch.cuda.Stream(priority=0)
stream2 = ms_torch.cuda.Stream(priority=0)
stream1 = user_stream1
stream2 = user_stream2
event = ms_torch.cuda.Event(False, False, False)

A = ms_torch.rand(500, 500, device="cuda")
@@ -195,8 +197,8 @@ def test_cuda_data_dependency_between_streams():
@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.Stream is not supported on cpu.")
def test_cuda_multi_streams():
prev_curr_stream = ms_torch.cuda.current_stream()
stream1 = ms_torch.cuda.Stream(priority=0)
stream2 = ms_torch.cuda.Stream(priority=0)
stream1 = user_stream1
stream2 = user_stream2

A = ms_torch.rand(5000, 5000, device="cuda")
B = ms_torch.rand(5000, 5000, device="cuda")
@@ -222,15 +224,15 @@ def test_cuda_multi_streams():
def test_cuda_set_stream():
curr_stream = ms_torch.cuda.current_stream()
assert curr_stream == ms_torch.cuda.default_stream()
stream1 = ms_torch.cuda.Stream()
stream1 = user_stream1
ms_torch.cuda.set_stream(stream1)
assert stream1 == ms_torch.cuda.current_stream()
ms_torch.cuda.set_stream(ms_torch.cuda.default_stream())

@SKIP_ENV_CPU(reason="mindtorch.torch.cuda.synchronize is not supported on cpu.")
def test_cuda_synchronize():
stream1 = ms_torch.cuda.Stream()
stream2 = ms_torch.cuda.Stream()
stream1 = user_stream1
stream2 = user_stream2

A = ms_torch.rand(500, 500, device="cuda")
B = ms_torch.rand(500, 500, device="cuda")
@@ -247,7 +249,7 @@ def test_cuda_synchronize():
def test_cuda_stream():
curr_stream = ms_torch.cuda.current_stream()
default_stream = ms_torch.cuda.default_stream()
user_stream = ms_torch.cuda.Stream()
user_stream = user_stream1
is_curr_and_default_stream_same = (curr_stream == default_stream)
is_user_and_default_stream_not_same = (user_stream != default_stream)
with ms_torch.cuda.stream(user_stream):


+208 -1  testing/ut/pytorch/optim/test_optim.py

@@ -4,7 +4,7 @@ import mindspore as ms
import numpy as np
import torch

from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, SKIP_ENV_PYNATIVE_MODE
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, SKIP_ENV_PYNATIVE_MODE, enable_backward
set_mode_by_env_config()

@SKIP_ENV_GRAPH_MODE(reason='inplace op in step() not support graph mode')
@@ -88,6 +88,211 @@ def test_sgd():
torch_result2 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result2, torch_result2)

@SKIP_ENV_GRAPH_MODE("optimizer.step without grad do not support graph mode.")
def test_sgd_step_no_grads():
weight = ms_torch.nn.Parameter(ms_torch.tensor([1, 2]).to(ms_torch.float))
opt = ms_torch.optim.SGD([weight], lr=0.01)

# step 1
for group in opt.param_groups:
group['lr'] = 0.2
weight.grad = ms_torch.tensor([3, 4.])
opt.step()
opt.zero_grad()
ms_result1 = opt.param_groups[0]['params'][0].detach().numpy()
# step 2
for group in opt.param_groups:
group['lr'] = 0.5
weight.grad = ms_torch.tensor([3, 4.])
opt.step()
ms_result2 = opt.param_groups[0]['params'][0].detach().numpy()

weight = torch.nn.Parameter(torch.tensor([1, 2]).to(torch.float))
opt = torch.optim.SGD([weight], lr=0.01)
for group in opt.param_groups:
group['lr'] = 0.2
# step 1
weight.grad = torch.tensor([3, 4.])
opt.step()
opt.zero_grad()
torch_result1 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result1, torch_result1)

# step 2
for group in opt.param_groups:
group['lr'] = 0.5
weight.grad = torch.tensor([3, 4.])
opt.step()
torch_result2 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result2, torch_result2)

@SKIP_ENV_GRAPH_MODE("backward not support graph mode.")
def test_sgd_autograd():
with enable_backward():
weight = ms_torch.nn.Parameter(ms_torch.tensor([1, 2]).to(ms_torch.float))
opt = ms_torch.optim.SGD([weight], lr=0.01)
def ms_torch_func(x):
return (x * 2).sum()

# step 1
for group in opt.param_groups:
group['lr'] = 0.2
opt.zero_grad()
loss = ms_torch_func(weight)
loss.backward()
opt.step()
ms_result1 = opt.param_groups[0]['params'][0].detach().numpy()
# step 2
for group in opt.param_groups:
group['lr'] = 0.5
opt.zero_grad()
loss = ms_torch_func(weight)
loss.backward()
opt.step()
ms_result2 = opt.param_groups[0]['params'][0].detach().numpy()

weight = torch.nn.Parameter(torch.tensor([1, 2]).to(torch.float))
opt = torch.optim.SGD([weight], lr=0.01)
def torch_func(x):
return (x * 2).sum()

for group in opt.param_groups:
group['lr'] = 0.2
# step 1
opt.zero_grad()
loss = torch_func(weight)
loss.backward()
opt.step()
torch_result1 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result1, torch_result1)

# step 2
for group in opt.param_groups:
group['lr'] = 0.5
weight.grad = torch.tensor([3, 4.])
opt.zero_grad()
loss = torch_func(weight)
loss.backward()
opt.step()
torch_result2 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result2, torch_result2)

@SKIP_ENV_GRAPH_MODE(reason='retain_graph of backward not support yet.')
@SKIP_ENV_PYNATIVE_MODE(reason='retain_graph of backward not support yet.')
def test_sgd_autograd_retain_graph():
with enable_backward():
weight = ms_torch.nn.Parameter(ms_torch.tensor([1, 2]).to(ms_torch.float))
opt = ms_torch.optim.SGD([weight], lr=0.01)
def ms_torch_func(x):
return (x * 2).sum()

# step 1
for group in opt.param_groups:
group['lr'] = 0.2
opt.zero_grad()
loss = ms_torch_func(weight)
loss.backward(retain_graph=True)
loss.backward()
opt.step()
ms_result1 = opt.param_groups[0]['params'][0].detach().numpy()
# step 2
for group in opt.param_groups:
group['lr'] = 0.5
opt.zero_grad()
loss = ms_torch_func(weight)
loss.backward(retain_graph=True)
loss.backward()
opt.step()
ms_result2 = opt.param_groups[0]['params'][0].detach().numpy()

weight = torch.nn.Parameter(torch.tensor([1, 2]).to(torch.float))
opt = torch.optim.SGD([weight], lr=0.01)
def torch_func(x):
return (x * 2).sum()

for group in opt.param_groups:
group['lr'] = 0.2
# step 1
opt.zero_grad()
loss = torch_func(weight)
loss.backward(retain_graph=True)
loss.backward()
opt.step()
torch_result1 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result1, torch_result1)

# step 2
for group in opt.param_groups:
group['lr'] = 0.5
weight.grad = torch.tensor([3, 4.])
opt.zero_grad()
loss = torch_func(weight)
loss.backward(retain_graph=True)
loss.backward()
opt.step()
torch_result2 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result2, torch_result2)

@SKIP_ENV_GRAPH_MODE("backward not support graph mode.")
def test_sgd_autograd_double_backward():
with enable_backward():
weight = ms_torch.nn.Parameter(ms_torch.tensor([1, 2]).to(ms_torch.float))
opt = ms_torch.optim.SGD([weight], lr=0.01)
def ms_torch_func(x):
return (x * 2).sum()

# step 1
for group in opt.param_groups:
group['lr'] = 0.2
opt.zero_grad()
loss = ms_torch_func(weight)
loss.backward()
loss = ms_torch_func(weight)
loss.backward()
opt.step()
ms_result1 = opt.param_groups[0]['params'][0].detach().numpy()
# step 2
for group in opt.param_groups:
group['lr'] = 0.5
opt.zero_grad()
loss = ms_torch_func(weight)
loss.backward()
loss = ms_torch_func(weight)
loss.backward()
opt.step()
ms_result2 = opt.param_groups[0]['params'][0].detach().numpy()

weight = torch.nn.Parameter(torch.tensor([1, 2]).to(torch.float))
opt = torch.optim.SGD([weight], lr=0.01)
def torch_func(x):
return (x * 2).sum()

for group in opt.param_groups:
group['lr'] = 0.2
# step 1
opt.zero_grad()
loss = torch_func(weight)
loss.backward()
loss = torch_func(weight)
loss.backward()
opt.step()
torch_result1 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result1, torch_result1)

# step 2
for group in opt.param_groups:
group['lr'] = 0.5
weight.grad = torch.tensor([3, 4.])
opt.zero_grad()
loss = torch_func(weight)
loss.backward()
loss = torch_func(weight)
loss.backward()
opt.step()
torch_result2 = opt.param_groups[0]['params'][0].detach().numpy()
assert np.allclose(ms_result2, torch_result2)


# [CI] ms2.3 0327 still not fix.
@SKIP_ENV_GRAPH_MODE(reason="MindSpore has some bug at 'group['lr'] *= 0.2' situation.")
@SKIP_ENV_PYNATIVE_MODE(reason="MindSpore has some bug at 'group['lr'] *= 0.2' situation.")
@@ -1199,3 +1404,5 @@ if __name__ == '__main__':
test_sgd_multi_group()
test_adamax_jit()
test_sgd_state_dict_load_from_pytorch_int_args()
test_sgd_autograd()
test_sgd_autograd_double_backward()
