#829 use Tensor.backward

Closed
lvyufeng wants to merge 7 commits from lvyufeng/MSAdapter:master into master
  1. .gitignore (+5 -1)
  2. mindtorch/__init__.py (+6 -2)
  3. mindtorch/torch/amp/__init__.py (+41 -28)
  4. mindtorch/torch/common/_inner.py (+16 -17)
  5. mindtorch/torch/nn/__init__.py (+1 -1)
  6. mindtorch/torch/nn/functional.py (+94 -70)
  7. mindtorch/torch/nn/modules/activation.py (+2 -4)
  8. mindtorch/torch/nn/modules/container.py (+434 -542)
  9. mindtorch/torch/nn/modules/linear.py (+4 -26)
  10. mindtorch/torch/nn/modules/module.py (+1976 -780)
  11. mindtorch/torch/nn/parameter.py (+29 -199)
  12. mindtorch/torch/optim/__init__.py (+2 -2)
  13. mindtorch/torch/optim/adamw.py (+677 -19)
  14. mindtorch/torch/optim/optimizer.py (+701 -154)
  15. mindtorch/torch/optim/sgd.py (+257 -25)
  16. mindtorch/torch/tensor.py (+148 -176)
  17. mindtorch/torch/utils/data/_utils/collate.py (+0 -1)
  18. testing/__init__.py (+0 -0)
  19. testing/st/__init__.py (+0 -0)
  20. testing/st/mindtorch/__init__.py (+6 -0)
  21. testing/st/mindtorch/test_simple_linear.py (+98 -0)
  22. testing/ut/pytorch/autograd/test_backward.py (+56 -0)
  23. testing/ut/pytorch/nn/test_linear.py (+3 -1)
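The new test testing/ut/pytorch/autograd/test_backward.py (contents not shown in this excerpt) covers the Tensor.backward path this PR switches to. A minimal sketch of that usage, assuming mindtorch.torch mirrors the PyTorch autograd API (tensor creation with requires_grad, .backward(), .grad):

```python
import mindtorch.torch as torch

# Build a scalar loss from a leaf tensor and let autograd fill in .grad,
# exactly as in PyTorch. Values below are illustrative only.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (x * x).sum()
loss.backward()
print(x.grad)  # expected: [2., 4., 6.]
```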

.gitignore (+5 -1)

@@ -33,4 +33,8 @@ sdist/
var/
wheels/
#datasets/
#mnist/
#mnist/

*.ir

data/

mindtorch/__init__.py (+6 -2)

@@ -1,6 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from mindtorch import torch
from mindspore._c_expression import jit_mode_pi_enable, update_pijit_default_config
jit_mode_pi_enable()
update_pijit_default_config(auto_grad=True)

from . import torch
from mindtorch.utils import unsupported_attr, pynative_mode_condition
from mindtorch.package_info import __version__, VERSION, version
from mindtorch.package_info import __version__, VERSION, version

mindtorch/torch/amp/__init__.py (+41 -28)

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import mindspore as ms
import mindtorch.torch.nn as nn
from mindtorch.torch.nn import Module, Sequential
@@ -45,37 +46,49 @@ def auto_mixed_precision(network, amp_level="auto"):


# for mindspore auto mixed precision
ms.rewrite.symbol_tree_builder.SymbolTreeBuilder.entry_function = "forward"
ms.rewrite.parsers.class_def_parser.ClassDefParser.entry_function = "forward"
ms.rewrite.parsers.assign_parser.AssignParser.types_for_cell_container.append(Sequential)
try:
# [adapt old version ms] use 'try import' to suit mindspore 2.2
ms.rewrite.symbol_tree_builder.SymbolTreeBuilder.entry_function = "forward"
ms.rewrite.parsers.class_def_parser.ClassDefParser.entry_function = "forward"
ms.rewrite.parsers.assign_parser.AssignParser.types_for_cell_container.append(Sequential)

class ToDtype(Module):
def __init__(self):
super(ToDtype, self).__init__()
class ToDtype(Module):
def __init__(self):
super(ToDtype, self).__init__()

def forward(self, x, dtype):
return x.to(dtype)
def forward(self, x, dtype):
return x.to(dtype)

nn_modules_list = [ToDtype]
nn_modules = sys.modules['mindtorch.torch.nn']
nn_modules_dir = dir(nn_modules)
for module_name in nn_modules_dir:
module_obj = getattr(nn_modules, module_name)
if isinstance(module_obj, type) and issubclass(module_obj, Module):
nn_modules_list.append(module_obj)
nn_modules_list = [ToDtype]
nn_modules = sys.modules['mindtorch.torch.nn']
nn_modules_dir = dir(nn_modules)
for module_name in nn_modules_dir:
module_obj = getattr(nn_modules, module_name)
if isinstance(module_obj, type) and issubclass(module_obj, Module):
nn_modules_list.append(module_obj)

ms.rewrite.namespace._subtree_black_list.extend(nn_modules_list)
ms.train.amp._config_amp(enable_rewrite=True, cast_op=ToDtype)
ms.rewrite.namespace._subtree_black_list.extend(nn_modules_list)
ms.train.amp._config_amp(enable_rewrite=True, cast_op=ToDtype)

ms.train.amp.AMP_WHITE_LIST.extend([nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.Linear,
nn.LSTMCell,
nn.RNNCell,
nn.GRUCell])
ms.train.amp.AMP_WHITE_LIST.extend([nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.Linear,
nn.LSTMCell,
nn.RNNCell,
nn.GRUCell])

ms.train.amp.AMP_BLACK_LIST.extend([nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm])
ms.train.amp.AMP_BLACK_LIST.extend([nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.LayerNorm])
except AttributeError:
ms.rewrite.symbol_tree.symbol_tree_builder.SymbolTreeBuilder.entry_functions = ["forward", "construct"]
ms.rewrite.parsers.class_def_parser.ClassDefParser.entry_functions = ["forward", "construct"]
ms.rewrite.parsers.class_def_parser.ClassDefParser.final_networks = ["Cell", "Module"]
ms.rewrite.parsers.assign_parser.AssignParser.types_for_cell_container.append(Sequential)
ms.rewrite.common.namespace._ignore_third_party_paths.append(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
ms.train.amp._config_amp(enable_rewrite=True)
ms.train.amp._INNER_AMP_BLACK_LIST.extend([ms.ops.operations._inner_ops.ConvertToMsTensor,
ms.ops.operations._inner_ops.ConvertToAdapterTensor])
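For reference, the block above only configures MindSpore's rewrite machinery; the entry point remains the module-level auto_mixed_precision(network, amp_level="auto") shown in the hunk header. A hedged usage sketch, assuming it returns the converted network the way ms.train.amp.auto_mixed_precision does:

```python
import mindtorch.torch.nn as nn
from mindtorch.torch.amp import auto_mixed_precision

# Sketch only: white-listed layers (Linear, Conv*, ...) run in reduced
# precision, black-listed ones (BatchNorm*, LayerNorm) stay in float32.
net = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
net = auto_mixed_precision(net, amp_level="auto")
```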

mindtorch/torch/common/_inner.py (+16 -17)

@@ -59,20 +59,19 @@ def _out_inplace_assign(out, output, op_name):


def _inplace_assign_pynative(input, inplace, output, op_name):
if inplace is True:
if pynative_mode_condition(): # TODO: ms_function
warning(
'If you want to convert to the MindSpore static graph mode, `inplace` in `{}` should not be True, ' \
'please set inplace=False and use return value instead of `input`.'.format(op_name)
)
input.assign_value(output)
return input

raise ValueError(
'In MindSpore static graph mode, `inplace` in `{}` should not be True, ' \
'please set inplace=False and use return value instead of `input`.'.format(op_name)
)

# if inplace is True:
# if pynative_mode_condition(): # TODO: ms_function
# warning(
# 'If you want to convert to the MindSpore static graph mode, `inplace` in `{}` should not be True, ' \
# 'please set inplace=False and use return value instead of `input`.'.format(op_name)
# )
# input.tensor.assign_value(output.tensor)
# return input

# raise ValueError(
# 'In MindSpore static graph mode, `inplace` in `{}` should not be True, ' \
# 'please set inplace=False and use return value instead of `input`.'.format(op_name)
# )
return cast_to_adapter_tensor(output)


@@ -123,7 +122,7 @@ def _inplace_limit_pynative(inplace, op_name):
)

def _inplace_assign(input, inplace, output):
if inplace is True:
input.assign_value(output)
return input
# if inplace is True:
# input.assign_value(output)
# return input
return cast_to_adapter_tensor(output)
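With the in-place branch commented out, _inplace_assign and _inplace_assign_pynative now always return a fresh adapter tensor and never call assign_value on the input. A hedged illustration of the resulting behavior (assuming an op such as relu routes through _inplace_assign_pynative, as in the functional.py hunk below):

```python
import mindtorch.torch as torch
import mindtorch.torch.nn.functional as F

x = torch.randn(4)
out = F.relu(x, inplace=True)  # sketch: `x` is no longer mutated in place;
                               # callers must use the returned tensor instead
```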

mindtorch/torch/nn/__init__.py (+1 -1)

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

from .modules import *
from .parameter import Parameter, ParameterTuple
from .parameter import Parameter
from . import init
from . import functional
from . import utils


mindtorch/torch/nn/functional.py (+94 -70)

@@ -5,6 +5,7 @@ from typing import Iterable
# from functools import lru_cache
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore.ops.primitive import _primexpr
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.function.math_func import _expand, _check_same_type
@@ -17,7 +18,7 @@ from mindtorch.torch.common._inner import _inplace_assign_pynative, _nn_function
from mindtorch.torch.common.dtype import all_int_type, all_float_and_complex_type
from mindtorch.torch.nn.modules.utils import _do_pad, _pair, _quadruple, _repeat_tuple, _single, _sextuple
from mindtorch.torch.common import pi
from mindtorch.torch.nn.modules.module import Module, Parameter
from mindtorch.torch.nn.parameter import Parameter
from mindtorch.torch.logging import warning

all = [
@@ -365,8 +366,8 @@ def softshrink(input, lambd=0.5):


def relu(input, inplace=False):
input_ms = cast_to_ms_tensor(input)
out = ms.ops.relu(input_ms)
# input_ms = cast_to_ms_tensor(input)
out = ops.relu(input)
return _inplace_assign_pynative(input, inplace, out, "relu")

def relu_(input):
@@ -639,17 +640,18 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
"""
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)

input_ms = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
#TODO: mindspore currently not support int64
target_dtype = target.dtype
if target_dtype in all_int_type:
warning("cross_entropy: when target type is int64, there is risk of overflow.")
target = target.astype(ms.int32)
weight = cast_to_ms_tensor(weight)
# if weight is None:
# weight = ops.ones(input.shape[1], input.dtype)
# unsupport float64
result = ms.ops.cross_entropy(input_ms, target, weight, ignore_index, reduction, label_smoothing)
class_dim = 0 if input.ndim == 1 else 1
x = ops.log_softmax(input, class_dim)
# nll_loss_ = ops.NLLLoss(reduction, ignore_index)
# result, t_weight = nll_loss_(x, target, weight)
# print()
# cross_entropy_ = _get_cache_prim(ops.SparseSoftmaxCrossEntropyWithLogits)()
# result = cross_entropy_(input, target)
# print(result.requires_grad)
result = nll_loss(x, target, weight, ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing)
return cast_to_adapter_tensor(result)
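The rewrite decomposes cross entropy into log_softmax followed by the nll_loss defined below, relying on the identity cross_entropy(x, t) = nll_loss(log_softmax(x), t) = -x_t + log Σ_c exp(x_c). A quick NumPy check of that identity (sketch, not part of the diff):

```python
import numpy as np

x = np.array([[2.0, 0.5, -1.0]])   # logits for one sample
t = 0                              # target class
log_p = x - np.log(np.exp(x).sum(axis=1, keepdims=True))
assert np.isclose(-log_p[0, t],                        # NLL of log-softmax
                  -x[0, t] + np.log(np.exp(x).sum()))  # direct cross entropy
```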

def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
@@ -721,18 +723,77 @@ def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, r
return cast_to_adapter_tensor(rlt)

def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction="mean"):
reduce=None, reduction="mean", label_smoothing=0.0):
"""
The negative log likelihood loss.
"""
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)

input_ms = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
weight = cast_to_ms_tensor(weight)
result = ms.ops.nll_loss(input_ms, target, weight, ignore_index, reduction, label_smoothing=0.0)
return cast_to_adapter_tensor(result)
ndim = input.ndim
if ndim == 2:
ret = _nll_loss(input, target, -1, weight, ignore_index, reduction, label_smoothing)
elif ndim == 4:
ret = _nll_loss(input, target, 1, weight, ignore_index, reduction, label_smoothing)
elif ndim == 1:
ret = _nll_loss(input, target, 0, weight, ignore_index, reduction, label_smoothing)
else:
n = input.shape[0]
c = input.shape[1]
out_size = (n,) + input.shape[2:]
input = input.view((n, c, 1, -1))
target = target.view((n, 1, -1))
if reduction != 'none':
ret = _nll_loss(input, target, 1, weight, ignore_index, reduction, label_smoothing)
else:
ret = _nll_loss(input, target, 1, weight, ignore_index, label_smoothing=label_smoothing)
ret = ret.view(out_size)
return cast_to_adapter_tensor(ret)


def _nll_loss(inputs, target, target_dim=-1, weight=None, ignore_index=None, reduction='none', label_smoothing=0.0):
"""nll loss inner function"""
if target.ndim == inputs.ndim - 1:
target = target.expand_dims(target_dim)
if ignore_index is not None:
non_pad_mask = ops.equal(target, ignore_index)
target = target.masked_fill(non_pad_mask, 0)
else:
non_pad_mask = target
if weight is not None:
loss_weights = ops.gather(weight, target, 0)
orig_shape = inputs.shape
if inputs.ndim != 2:
inputs = inputs.view(orig_shape[:2] + (-1,))
weight = weight.view(weight.shape + (1,))
weighted_inputs = inputs * weight
weighted_inputs = weighted_inputs.view(orig_shape)
loss = ops.neg(ops.gather_d(weighted_inputs, target_dim, target))
smooth_loss = ops.neg(weighted_inputs.sum(axis=target_dim, keepdims=True))
else:
loss = ops.neg(ops.gather_d(inputs, target_dim, target))
smooth_loss = ops.neg(inputs.sum(axis=target_dim, keepdims=True))
loss_weights = ops.ones_like(loss)
if ignore_index is not None:
loss = loss.masked_fill(non_pad_mask, 0.)
loss_weights = loss_weights.masked_fill(non_pad_mask, 0.)
smooth_loss = smooth_loss.masked_fill(non_pad_mask, 0.)

loss = loss.squeeze(target_dim)
smooth_loss = smooth_loss.squeeze(target_dim)

if reduction == 'sum':
loss = loss.sum()
smooth_loss = smooth_loss.sum()
if reduction == 'mean':
loss = loss.sum() / loss_weights.sum()
smooth_loss = smooth_loss.sum() / loss_weights.sum()

eps_i = label_smoothing / inputs.shape[target_dim]
loss = ops.scalar_to_tensor(1. - label_smoothing, loss.dtype) * loss + ops.scalar_to_tensor(eps_i, smooth_loss.dtype) * smooth_loss

return loss
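The last line blends the per-target NLL with the uniform smoothing term: loss = (1 - ε)·nll + (ε / C)·(-Σ_c log p_c), where C = inputs.shape[target_dim]. A worked check of that formula with plain NumPy (sketch, weight=None, reduction='none'):

```python
import numpy as np

logits = np.array([2.0, 0.5, -1.0])
target, eps = 0, 0.1
log_p = logits - np.log(np.exp(logits).sum())   # log_softmax, as in cross_entropy above
nll = -log_p[target]                            # `loss` before smoothing
smooth = -log_p.sum()                           # `smooth_loss`
loss = (1 - eps) * nll + (eps / logits.size) * smooth
```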


def kl_div(input, target, size_average=None, reduce=None, reduction="mean", log_target=False):
"""
@@ -1129,15 +1190,15 @@ def cosine_embedding_loss(
loss = ms.ops.cosine_embedding_loss(input1, input2, target, margin, reduction)
return cast_to_adapter_tensor(loss)

class _PairwiseDisFun(Module):
def __init__(self, p=2.0, eps=1e-06, keepdim=False):
super(_PairwiseDisFun, self).__init__()
self.p = p
self.eps = eps
self.keepdim = keepdim
# class _PairwiseDisFun(Module):
# def __init__(self, p=2.0, eps=1e-06, keepdim=False):
# super(_PairwiseDisFun, self).__init__()
# self.p = p
# self.eps = eps
# self.keepdim = keepdim

def forward(self, x1, x2):
return pairwise_distance(x1, x2, p=self.p, eps=self.eps, keepdim=self.keepdim)
# def forward(self, x1, x2):
# return pairwise_distance(x1, x2, p=self.p, eps=self.eps, keepdim=self.keepdim)

def triplet_margin_loss(
anchor,
@@ -1769,61 +1830,24 @@ def _check_linear_shape(weight_rank, input_shape, weight_shape):
f"got input with shape {input_shape}, and weight with shape {weight_shape}.")

def linear(input, weight, bias=None):
input_ms = cast_to_ms_tensor(input)

dtype_op = _get_cache_prim(ms.ops.DType)()
rank_op = _get_cache_prim(ms.ops.Rank)()
shape_op = _get_cache_prim(ms.ops.Shape)()
reshape_op = _get_cache_prim(ms.ops.Reshape)()
bias_add_op = _get_cache_prim(ms.ops.BiasAdd)()

dtype1 = dtype_op(input_ms)
dtype2 = dtype_op(weight)
if not _check_same_type(dtype1, dtype2):
input_ms = input_ms.astype(ms.float32)
weight = weight.astype(ms.float32)

input_rank, weight_rank = rank_op(input_ms), rank_op(weight)
input_shape, weight_shape = shape_op(input_ms), shape_op(weight)
_check_linear_shape(weight_rank, input_shape, weight_shape)

# infers the shape of the output
shape_out = _get_linear_output_shape(input_shape, weight_shape, input_rank, weight_rank)

_matmul = _get_cache_prim(ms.ops.MatMul)(False, True)

input_ms = _expand(input_ms, 2)
weight = _expand(weight, 2)

if rank_op(input_ms) > 2:
input_ms = reshape_op(input_ms, (-1, input_shape[-1]))
output = _matmul(input_ms, weight)
if bias is not None:
bias = _expand(bias, 1)
# if output's rank bigger than 5, using output = ms.ops.add(output, bias)
output = bias_add_op(output, bias)
output = reshape_op(output, shape_out)
linear_ = _get_cache_prim(ops.Dense)()
output = linear_(input, weight, bias)
return cast_to_adapter_tensor(output)
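The hand-written reshape/MatMul/BiasAdd path is collapsed into a single ops.Dense call; the intended result is unchanged, output = input @ weight.T + bias with weight shaped (out_features, in_features), matching torch.nn.functional.linear. A NumPy reference for that contract (sketch):

```python
import numpy as np

x = np.random.randn(8, 3)   # (batch, in_features)
w = np.random.randn(5, 3)   # (out_features, in_features)
b = np.random.randn(5)
ref = x @ w.T + b           # what linear()/ops.Dense is expected to produce
```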

def bilinear(input1, input2, weight, bias=None):
input1 = cast_to_ms_tensor(input1)
input2 = cast_to_ms_tensor(input2)
weight = cast_to_ms_tensor(weight)
input1_shape = input1.shape
input2_shape = input2.shape
if len(input1_shape) != 2:
input1 = input1.reshape((-1, input1_shape[-1]))
_matmul = _get_cache_prim(ms.ops.MatMul)(False, False)
x = _matmul(input1, weight.permute(1, 0, 2).reshape(weight.shape[1], -1))
x = ops.matmul(input1, weight.permute(1, 0, 2).reshape(weight.shape[1], -1))
if len(input2_shape) != 2:
input2 = input2.reshape((-1, input2_shape[-1]))
x = ms.ops.mul(x, ms.ops.tile(input2, (1, weight.shape[0])))
x = x * ops.tile(input2, (1, weight.shape[0]))
x = x.reshape(x.shape[0], weight.shape[0], -1)
x = ms.ops.reduce_sum(x, -1)
x = ops.reduce_sum(x, -1)
if bias is not None:
bias = cast_to_ms_tensor(bias)
# not support float64
x = ms.ops.bias_add(x, bias)
x = ops.bias_add(x, bias)
output = x.reshape(*input1_shape[:-1], -1)
return cast_to_adapter_tensor(output)



mindtorch/torch/nn/modules/activation.py (+2 -4)

@@ -50,14 +50,12 @@ class ReLU(Module):

def __init__(self, inplace=False):
super(ReLU, self).__init__()
self.relu = P.ReLU()
self.inplace = inplace
_inplace_limit_pynative(inplace, "ReLU")

def forward(self, input):
input_ms = cast_to_ms_tensor(input)
output = self.relu(input_ms)
return _inplace_assign(input, self.inplace, output)
return ms_torch_nn_func.relu(input)
# return _inplace_assign(input, self.inplace, output)

def extra_repr(self):
inplace_str = 'inplace=True' if self.inplace else ''


mindtorch/torch/nn/modules/container.py (+434 -542)

@@ -1,168 +1,163 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from abc import abstractmethod
import operator
from itertools import chain
from typing import Dict
import warnings
from collections import OrderedDict, abc as container_abcs
from mindspore.nn.layer.container import _get_prefix_and_index, _valid_index, _valid_cell
from itertools import chain, islice
import operator

from mindtorch.torch.tensor import Tensor, cast_to_adapter_tensor
from mindtorch.torch.nn.parameter import Parameter
from mindtorch.torch._ref import typename
from mindtorch import torch
from .module import Module
from ..parameter import Parameter
from torch._jit_internal import _copy_to_script_wrapper

from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
from typing_extensions import Self

__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']

T = TypeVar('T', bound=Module)


# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s


class Container(Module):

def __init__(self, **kwargs: Any) -> None:
super().__init__()
# DeprecationWarning is ignored by default <sigh>
warnings.warn("nn.Container is deprecated. All of it's functionality "
"is now implemented in nn.Module. Subclass that instead.")
for key, value in kwargs.items():
self.add_module(key, value)


class Sequential(Module):
r"""A sequential container.

Modules will be added to it in the order they are passed in the
constructor. Alternatively, an ``OrderedDict`` of modules can be
passed in. The ``forward()`` method of ``Sequential`` accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.

The value a ``Sequential`` provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
``Sequential`` applies to each of the modules it stores (which are
each a registered submodule of the ``Sequential``).

What's the difference between a ``Sequential`` and a
:class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
sounds like--a list for storing ``Module`` s! On the other hand,
the layers in a ``Sequential`` are connected in a cascading way.

Example::

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
Sequential Module container. For more details about Module, please refer to

A list of Cells will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of cells can also be passed in.
_modules: Dict[str, Module] # type: ignore[assignment]

Note:
Sequential and nn.ModuleList are different, ModuleList is a list for storing modules. However,
the layers in a Sequential are connected in a cascading way.
@overload
def __init__(self, *args: Module) -> None:
...

@overload
def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
...

Args:
args (list, OrderedDict): List or OrderedDict of subclass of Module.

Inputs:
- **x** (Tensor) - Tensor with shape according to the first Module in the sequence.

Outputs:
Tensor, the output Tensor with shape depending on the input `x` and defined sequence of Cells.

Raises:
TypeError: If the type of the `args` is not list or OrderedDict.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
>>> relu = nn.ReLU()
>>> seq = nn.Sequential([conv, relu])
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[27. 27.]
[27. 27.]]
[[27. 27.]
[27. 27.]]]]
>>> from collections import OrderedDict
>>> d = OrderedDict()
>>> d["conv"] = conv
>>> d["relu"] = relu
>>> seq = nn.Sequential(d)
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[27. 27.]
[27. 27.]]
[[27. 27.]
[27. 27.]]]]
"""
def __init__(self, *args):
"""Initialize Sequential."""
super(Sequential, self).__init__()
self._is_dynamic_name = []
if len(args) == 1:
cells = args[0]
if isinstance(cells, list):
for index, cell in enumerate(cells):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
elif isinstance(cells, OrderedDict):
for name, cell in cells.items():
self.insert_child_to_cell(name, cell)
cell.update_parameters_name(name + ".")
self._is_dynamic_name.append(False)
elif isinstance(cells, Module):
for index, cell in enumerate(args):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
else:
raise TypeError(f"For '{self.__class__.__name__}', the 'args[0]' must be list or orderedDict, "
f"but got {type(cells).__name__}")
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for index, cell in enumerate(args):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
self.cell_list = list(self._cells.values())

def __getitem__(self, index):
if isinstance(index, slice):
return self.__class__(
OrderedDict(list(self._cells.items())[index]))
if isinstance(index, Tensor):
index = int(index)
index = _valid_index(len(self), index, self.__class__.__name__)
return list(self._cells.values())[index]

def __setitem__(self, index, module):
if isinstance(index, Tensor):
index = int(index)
cls_name = self.__class__.__name__
if _valid_cell(module, cls_name):
prefix, _ = _get_prefix_and_index(self._cells)
index = _valid_index(len(self), index, cls_name)
key = list(self._cells.keys())[index]
self._cells[key] = module
module.update_parameters_name(prefix + key + ".")
self.cell_list = list(self._cells.values())

def __delitem__(self, index):
cls_name = self.__class__.__name__
if isinstance(index, int):
index = _valid_index(len(self), index, cls_name)
key = list(self._cells.keys())[index]
del self._cells[key]
del self._is_dynamic_name[index]
elif isinstance(index, slice):
keys = list(self._cells.keys())[index]
for key in keys:
del self._cells[key]
del self._is_dynamic_name[index]
for idx, module in enumerate(args):
self.add_module(str(idx), module)

def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
"""Get the idx-th item of the iterator."""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError(f'index {idx} is out of range')
idx %= size
return next(islice(iterator, idx, None))

@_copy_to_script_wrapper
def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
raise TypeError(f"For '{cls_name}', the type of index must be int type or slice type, "
f"but got {type(index).__name__}")
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for idx, key in enumerate(self._cells.keys()):
cell = self._cells[key]
if self._is_dynamic_name[idx]:
for _, param in cell.parameters_and_names():
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(idx)] = cell
else:
temp_dict[key] = cell
self._cells = temp_dict
self.cell_list = list(self._cells.values())
return self._get_item_by_idx(self._modules.values(), idx)

def __setitem__(self, idx: int, module: Module) -> None:
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)

def __len__(self):
return len(self._cells)
def __delitem__(self, idx: Union[slice, int]) -> None:
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
# To preserve numbering
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

def __bool__(self):
return len(self._cells) != 0
@_copy_to_script_wrapper
def __len__(self) -> int:
return len(self._modules)

def __add__(self, other):
def __add__(self, other) -> 'Sequential':
if isinstance(other, Sequential):
ret = Sequential()
for layer in self:
self.append(ret, layer)
ret.append(layer)
for layer in other:
self.append(ret, layer)
ret.append(layer)
return ret
else:
raise ValueError('add operator supports only objects '
'of Sequential class, but {} is given.'.format(
str(type(other))))
f'of Sequential class, but {str(type(other))} is given.')

def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v

def __iadd__(self, other):
def __iadd__(self, other) -> Self:
if isinstance(other, Sequential):
offset = len(self)
for i, module in enumerate(other):
@@ -170,13 +165,12 @@ class Sequential(Module):
return self
else:
raise ValueError('add operator supports only objects '
'of Sequential class, but {} is given.'.format(
str(type(other))))
f'of Sequential class, but {str(type(other))} is given.')

def __mul__(self, other):
def __mul__(self, other: int) -> 'Sequential':
if not isinstance(other, int):
raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
elif other <= 0:
elif (other <= 0):
raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
else:
combined = Sequential()
@@ -187,164 +181,87 @@ class Sequential(Module):
offset += 1
return combined

def __rmul__(self, other):
def __rmul__(self, other: int) -> 'Sequential':
return self.__mul__(other)

def __imul__(self, other):
def __imul__(self, other: int) -> Self:
if not isinstance(other, int):
raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
elif other <= 0:
elif (other <= 0):
raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
else:
len_original = len(self)
offset = len(self)
for _ in range(other - 1):
for i in range(len_original):
self.add_module(str(i + offset), self._cells[str(i)])
self.add_module(str(i + offset), self._modules[str(i)])
offset += len_original
return self

@_copy_to_script_wrapper
def __dir__(self):
keys = Module.__dir__(self)
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys

def __iter__(self):
return iter(self._cells.values())

@property
def _modules(self):
return self._cells
@_copy_to_script_wrapper
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())

def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
# NB: We can't really type check this function as the type of input
# may change dynamically (as is tested in
# TestScript.test_sequential_intermediary_types). Cannot annotate
# with Any as TorchScript expects a more precise type
def forward(self, input):
for module in self:
input = module(input)
return input

def append(self, module):
"""
Appends a given Module to the end of the list.
def append(self, module: Module) -> 'Sequential':
r"""Append a given module to the end.

Args:
module(Module): The Module to be appended.

Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
>>> bn = nn.BatchNorm2d(2)
>>> relu = nn.ReLU()
>>> seq = nn.Sequential([conv, bn])
>>> seq.append(relu)
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[26.999863 26.999863]
[26.999863 26.999863]]
[[26.999863 26.999863]
[26.999863 26.999863]]]]
module (nn.Module): module to append
"""
if _valid_cell(module, self.__class__.__name__):
prefix, _ = _get_prefix_and_index(self._cells)
module.update_parameters_name(prefix + str(len(self)) + ".")
self._is_dynamic_name.append(True)
self._cells[str(len(self))] = module
self.cell_list = list(self._cells.values())
self.add_module(str(len(self)), module)
return self

def add_module(self, name, module):
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
module.__name__))
elif hasattr(self, name) and name not in self._cells:
raise KeyError("attribute '{}' already exists".format(name))
elif '.' in name:
raise KeyError("module name can't contain \".\", got: {}".format(name))
elif name == '':
raise KeyError("module name can't be empty string \"\"")

if _valid_cell(module, self.__class__.__name__):
module.update_parameters_name(name + ".")
self._is_dynamic_name.append(False)

self._cells[name] = module
self.cell_list = list(self._cells.values())

def forward(self, input):
for cell in self.cell_list:
input = cell(input)
return cast_to_adapter_tensor(input)

def pop(self, key):
v = self[key]
del self[key]
return v
def insert(self, index: int, module: Module) -> 'Sequential':
if not isinstance(module, Module):
raise AssertionError(
f'module should be of type: {Module}')
n = len(self._modules)
if not (-n <= index <= n):
raise IndexError(
f'Index out of range: {index}')
if index < 0:
index += n
for i in range(n, index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
return self

def extend(self, sequential):
def extend(self, sequential) -> 'Sequential':
for layer in sequential:
self.append(layer)
return self

def insert(self, index, module):
"""
Inserts a given Cell before a given index in the list.

Args:
index(int): The Insert index in the CellList.
cell(Cell): The Cell to be inserted.
"""
cls_name = self.__class__.__name__
idx = _valid_index(len(self), index, cls_name)
_valid_cell(module, cls_name)
length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = f'{prefix}{str(length)}{"."}{".".join(param.name.split(".")[key_index+1:])}'
self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1
self._cells[str(idx)] = module
if self._auto_prefix:
module.update_parameters_name(prefix + str(idx) + ".")
self.cell_list = list(self._cells.values())
self._is_dynamic_name.insert(index, True)

#_ModuleListBase is similar to ms.nn._CellListBase
class _ModuleListBase:
"""
An interface for base the Module as list.

The sequential Module may be iterated using the construct method using for-in statement.
But there are some scenarios that the construct method built-in does not fit.
For convenience, we provide an interface that indicates the sequential
Module may be interpreted as list of Cells, so it can be accessed using
iterator or subscript when a sequential Module instantiate is accessed
by iterator or subscript, it will be interpreted as a list of Cells.
"""
def __init__(self):
"""Initialize _ModuleListBase."""
self.__cell_as_list__ = True #for ms jit parse

@abstractmethod
def __len__(self):
pass

@abstractmethod
def __getitem__(self, index):
pass
class ModuleList(Module):
r"""Holds submodules in a list.

class ModuleList(_ModuleListBase, Module):
"""
Holds Cells in a list.
ModuleList can be used like a regular Python list, the Cells it contains have been initialized.
:class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
:class:`~torch.nn.Module` methods.

Args:
modules (iterable, optional): an iterable of modules to add
modules (iterable, optional): an iterable of modules to add

Example::

Examples:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
super().__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

def forward(self, x):
@@ -353,172 +270,158 @@ class ModuleList(_ModuleListBase, Module):
x = self.linears[i // 2](x) + l(x)
return x
"""
def __init__(self, modules=None):
"""Initialize ModuleList."""
_ModuleListBase.__init__(self)
Module.__init__(self)

_modules: Dict[str, Module] # type: ignore[assignment]

def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
super().__init__()
if modules is not None:
self.extend(modules)
self += modules

def __getitem__(self, idx):
if isinstance(idx, Tensor):
idx = int(idx)
cls_name = self.__class__.__name__
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
return str(idx)

@_copy_to_script_wrapper
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]

def __setitem__(self, idx: int, module: Module) -> None:
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), module)

def __delitem__(self, idx: Union[int, slice]) -> None:
if isinstance(idx, slice):
return self.__class__(list(self._cells.values())[idx])
if isinstance(idx, int):
idx = _valid_index(len(self), idx, cls_name)
return self._cells[str(idx)]
raise TypeError(f"For '{cls_name}', the type of 'idx' must be int or slice, "
f"but got {type(idx).__name__}.")

def __setitem__(self, idx, module):
if isinstance(idx, Tensor):
idx = int(idx)
cls_name = self.__class__.__name__
if not isinstance(idx, int) and _valid_cell(module, cls_name):
raise TypeError(f"For '{cls_name}', the type of 'idx' must be int, "
f"but got {type(idx).__name__}.")
idx = _valid_index(len(self), idx, cls_name)
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
module.update_parameters_name(prefix + str(idx) + ".")
self._cells[str(idx)] = module

def __delitem__(self, idx):
if isinstance(idx, Tensor):
idx = int(idx)
cls_name = self.__class__.__name__
if isinstance(idx, int):
idx = _valid_index(len(self), idx, cls_name)
del self._cells[str(idx)]
elif isinstance(idx, slice):
keys = list(self._cells.keys())[idx]
for key in keys:
del self._cells[key]
for k in range(len(self._modules))[idx]:
delattr(self, str(k))
else:
raise TypeError(f"For '{cls_name}', the type of 'index' must be int or slice, "
f"but got {type(idx).__name__}.")
# adjust orderedDict
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for id, cell in enumerate(self._cells.values()):
if self._auto_prefix:
for _, param in cell.parameters_and_names():
param.name = prefix + str(id) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(id)] = cell
self._cells = temp_dict

def __len__(self):
return len(self._cells)

def __iter__(self):
return iter(self._cells.values())

def __iadd__(self, modules):
delattr(self, self._get_abs_string_index(idx))
# To preserve numbering, self._modules is being reconstructed with modules after deletion
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

@_copy_to_script_wrapper
def __len__(self) -> int:
return len(self._modules)

@_copy_to_script_wrapper
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())

def __iadd__(self, modules: Iterable[Module]) -> Self:
return self.extend(modules)

def __add__(self, other):
def __add__(self, other: Iterable[Module]) -> 'ModuleList':
combined = ModuleList()
for _, module in enumerate(chain(self, other)):
combined.append(module)
for i, module in enumerate(chain(self, other)):
combined.add_module(str(i), module)
return combined

def __repr__(self):
"""Return a custom repr for ModuleList that compresses repeated module representations."""
list_of_reprs = [repr(item) for item in self]
if len(list_of_reprs) == 0:
return self._get_name() + '()'

start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue

start_end_indices.append([i, i])
repeated_blocks.append(r)

lines = []
main_str = self._get_name() + '('
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
local_repr = f"({start_id}): {b}" # default repr

if start_id != end_id:
n = end_id - start_id + 1
local_repr = f"({start_id}-{end_id}): {n} x {b}"

local_repr = _addindent(local_repr, 2)
lines.append(local_repr)

main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str

@_copy_to_script_wrapper
def __dir__(self):
keys = super(ModuleList, self).__dir__()
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys

def pop(self, key):
v = self[key]
del self[key]
return v
def insert(self, index: int, module: Module) -> None:
r"""Insert a given module before a given index in the list.

def insert(self, index, module):
Args:
index (int): index to insert.
module (nn.Module): module to insert
"""
Inserts a given Module before a given index in the list.
for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module

def append(self, module: Module) -> 'ModuleList':
r"""Append a given module to the end of the list.

Args:
index(int): The Insert index in the ModuleList.
module(Module): The Module to be inserted.
module (nn.Module): module to append
"""
cls_name = self.__class__.__name__
#TODO: after _valid_index fixed, below code can be remove
if len(self) == 0 and index == 0:
idx = index
else:
idx = _valid_index(len(self), index, cls_name)
_valid_cell(module, cls_name)
length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1
self._cells[str(idx)] = module
if self._auto_prefix:
module.update_parameters_name(prefix + str(idx) + ".")

def extend(self, modules):
"""
Appends Cells from a Python iterable to the end of the list.
self.add_module(str(len(self)), module)
return self

Args:
cells(list): The Cells to be extended.
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v

def extend(self, modules: Iterable[Module]) -> Self:
r"""Append modules from a Python iterable to the end of the list.

Raises:
TypeError: If the argument cells are not a list of Cells.
Args:
modules (iterable): iterable of modules to append
"""
cls_name = self.__class__.__name__
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleList.extend should be called with an "
"iterable, but got " + type(modules).__name__)
prefix, _ = _get_prefix_and_index(self._cells)
for module in modules:
if _valid_cell(module, cls_name):
if self._auto_prefix:
module.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = module
offset = len(self)
for i, module in enumerate(modules):
self.add_module(str(offset + i), module)
return self

def append(self, module):
"""
Appends a given Module to the end of the list.

Args:
module(Module): The subcell to be appended.
"""
if _valid_cell(module, self.__class__.__name__):
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
module.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = module
# remove forward alltogether to fallback on Module's _forward_unimplemented

def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)

class ModuleDict(Module):
r"""Holds submodules in a dictionary.

:class:`nn.ModuleDict` can be indexed like a regular Python dictionary,
:class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all
:class:`nn.Module` methods.
:class:`~torch.nn.Module` methods.

:class:`nn.ModuleDict` is an **ordered** dictionary that respects
:class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

* the order of insertion, and

* in :meth:`nn.ModuleDict.update`, the order of the merged
* in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
``OrderedDict``, ``dict`` (started from Python 3.6) or another
:class:`nn.ModuleDict` (the argument to
:meth:`nn.ModuleDict.update`).
:class:`~torch.nn.ModuleDict` (the argument to
:meth:`~torch.nn.ModuleDict.update`).

Note that :meth:`nn.ModuleDict.update` with other unordered mapping
Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
preserve the order of the merged mapping.

@@ -530,7 +433,7 @@ class ModuleDict(Module):

class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
super().__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
@@ -546,42 +449,40 @@ class ModuleDict(Module):
return x
"""

def __init__(self, modules=None):
super(ModuleDict, self).__init__()
self.__cell_as_dict__ = True
_modules: Dict[str, Module] # type: ignore[assignment]

def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
if modules is not None:
self.update(modules)

def __getitem__(self, key):
return self._cells[key]
@_copy_to_script_wrapper
def __getitem__(self, key: str) -> Module:
return self._modules[key]

def __setitem__(self, key, module):
self._update_cell_para_name(key, module)
def __setitem__(self, key: str, module: Module) -> None:
self.add_module(key, module)

def __delitem__(self, key):
del self._cells[key]
def __delitem__(self, key: str) -> None:
del self._modules[key]

def __len__(self):
return len(self._cells)
@_copy_to_script_wrapper
def __len__(self) -> int:
return len(self._modules)

def __iter__(self):
return iter(self._cells)
@_copy_to_script_wrapper
def __iter__(self) -> Iterator[str]:
return iter(self._modules)

def __contains__(self, key):
return key in self._cells
@_copy_to_script_wrapper
def __contains__(self, key: str) -> bool:
return key in self._modules

def _update_cell_para_name(self, key, cell):
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + key + ".")
def clear(self) -> None:
"""Remove all items from the ModuleDict."""
self._modules.clear()

def clear(self):
"""Remove all items from the ModuleDict.
"""
self._cells.clear()

def pop(self, key):
def pop(self, key: str) -> Module:
r"""Remove key from the ModuleDict and return its module.

Args:
@@ -591,32 +492,31 @@ class ModuleDict(Module):
del self[key]
return v

def keys(self):
r"""Return an iterable of the ModuleDict keys.
"""
return self._cells.keys()
@_copy_to_script_wrapper
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()

def items(self):
r"""Return an iterable of the ModuleDict key/value pairs.
"""
return self._cells.items()
@_copy_to_script_wrapper
def items(self) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()

def values(self):
r"""Return an iterable of the ModuleDict values.
"""
return self._cells.values()
@_copy_to_script_wrapper
def values(self) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()

def update(self, modules):
r"""Update the :class:`nn.ModuleDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
def update(self, modules: Mapping[str, Module]) -> None:
r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`nn.ModuleDict`, or
If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.

Args:
modules (iterable): a mapping (dictionary) from string to :class:`nn.Module`,
or an iterable of key-value pairs of type (string, :class:`nn.Module`)
modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleDict.update should be called with an "
@@ -645,15 +545,15 @@ class ModuleDict(Module):


class ParameterList(Module):
"""Holds parameters in a list.
r"""Holds parameters in a list.

:class:`nn.ParameterList` can be used like a regular Python
list, but Tensors that are :class:`nn.Parameter` are properly registered,
and will be visible by all :class:`nn.Module` methods.
:class:`~torch.nn.ParameterList` can be used like a regular Python
list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
and will be visible by all :class:`~torch.nn.Module` methods.

Note that the constructor, assigning an element of the list, the
:meth:`nn.ParameterDict.append` method and the :meth:`nn.ParameterDict.extend`
method will convert any :class:`Tensor` into :class:`nn.Parameter`.
:meth:`~torch.nn.ParameterDict.append` method and the :meth:`~torch.nn.ParameterDict.extend`
method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.

Args:
parameters (iterable, optional): an iterable of elements to add to the list.
@@ -662,8 +562,8 @@ class ParameterList(Module):

class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.params = nn.ParameterList([nn.Parameter(ms_torch.randn(10, 10)) for i in range(10)])
super().__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

def forward(self, x):
# ParameterList can act as an iterable, or be indexed using ints
@@ -672,21 +572,29 @@ class ParameterList(Module):
return x
"""

def __init__(self, values=None):
super(ParameterList, self).__init__()
def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
super().__init__()
self._size = 0
if values is not None:
self += values

def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules"""
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not -len(self) <= idx < len(self):
raise IndexError('index {} is out of range'.format(idx))
if not (-len(self) <= idx < len(self)):
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
return str(idx)

@overload
def __getitem__(self, idx: int) -> Any:
...

@overload
def __getitem__(self: T, idx: slice) -> T:
...

def __getitem__(self, idx):
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
@@ -698,33 +606,33 @@ class ParameterList(Module):
idx = self._get_abs_string_index(idx)
return getattr(self, str(idx))

def __setitem__(self, idx, param):
def __setitem__(self, idx: int, param: Any) -> None:
# Note that all other function that add an entry to the list part of
# the ParameterList end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
# Objects added via setattr() are not in the list part and thus won't
# call into this function.
idx = self._get_abs_string_index(idx)
if isinstance(param, Tensor) and not isinstance(param, Parameter):
if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
param = Parameter(param)
return setattr(self, str(idx), param)

def __len__(self):
def __len__(self) -> int:
return self._size

def __iter__(self):
def __iter__(self) -> Iterator[Any]:
return iter(self[i] for i in range(len(self)))

def __iadd__(self, parameters):
def __iadd__(self, parameters: Iterable[Any]) -> Self:
return self.extend(parameters)

def __dir__(self):
keys = super(ParameterList, self).__dir__()
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys

def append(self, value):
"""Appends a given value at the end of the list.
def append(self, value: Any) -> 'ParameterList':
"""Append a given value at the end of the list.

Args:
value (Any): value to append
@@ -734,26 +642,29 @@ class ParameterList(Module):
self[new_idx] = value
return self

def extend(self, values):
"""Appends values from a Python iterable to the end of the list.
def extend(self, values: Iterable[Any]) -> Self:
"""Append values from a Python iterable to the end of the list.

Args:
values (iterable): iterable of values to append
"""
# Tensor is an iterable but we never want to unpack it here
if not isinstance(values, container_abcs.Iterable) or isinstance(values, Tensor):
if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch.Tensor):
raise TypeError("ParameterList.extend should be called with an "
"iterable, but got " + type(values).__name__)
for value in values:
self.append(value)
return self

def extra_repr(self):
def extra_repr(self) -> str:
child_lines = []
for k, p in enumerate(self):
if isinstance(p, Tensor):
if isinstance(p, torch.Tensor):
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
device_str = f' ({p.device})'
else:
device_str = ''
parastr = '{} containing: [{} of size {}{}]'.format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
p.dtype, size_str, device_str)
@@ -767,31 +678,23 @@ class ParameterList(Module):
def __call__(self, *args, **kwargs):
raise RuntimeError('ParameterList should not be called.')

# adpater api, to convert ParameterList to list[Parameter]
def to_list(self):
list_params = []
for i, p in enumerate(self):
p.name = str(i) + "." + p.name
list_params.append(p)
return list_params


class ParameterDict(Module):
"""Holds parameters in a dictionary.
r"""Holds parameters in a dictionary.

ParameterDict can be indexed like a regular Python dictionary, but Parameters it
contains are properly registered, and will be visible by all Module methods.
Other objects are treated as would be done by a regular Python dictionary

:class:`nn.ParameterDict` is an **ordered** dictionary.
:meth:`nn.ParameterDict.update` with other unordered mapping
:class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
:meth:`~torch.nn.ParameterDict.update` with other unordered mapping
types (e.g., Python's plain ``dict``) does not preserve the order of the
merged mapping. On the other hand, ``OrderedDict`` or another :class:`nn.ParameterDict`
merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
will preserve their ordering.

Note that the constructor, assigning an element of the dictionary and the
:meth:`nn.ParameterDict.update` method will convert any :class:`Tensor` into
:class:`nn.Parameter`.
:meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
:class:`~torch.nn.Parameter`.

Args:
values (iterable, optional): a mapping (dictionary) of
@@ -802,10 +705,10 @@ class ParameterDict(Module):

class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
super().__init__()
self.params = nn.ParameterDict({
'left': nn.Parameter(ms_torch.randn(5, 10)),
'right': nn.Parameter(ms_torch.randn(5, 10))
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})

def forward(self, x, choice):
@@ -813,13 +716,13 @@ class ParameterDict(Module):
return x
"""

def __init__(self, parameters = None):
super(ParameterDict, self).__init__()
def __init__(self, parameters: Any = None) -> None:
super().__init__()
self._keys: Dict[str, None] = {}
if parameters is not None:
self.update(parameters)

def _key_to_attr(self, key):
def _key_to_attr(self, key: str) -> str:
if not isinstance(key, str):
raise TypeError("Index given to ParameterDict cannot be used as a key as it is "
f"not a string (type is '{type(key).__name__}'). Open an issue on "
@@ -828,11 +731,11 @@ class ParameterDict(Module):
# Use the key as-is so that `.named_parameters()` returns the right thing
return key

def __getitem__(self, key):
def __getitem__(self, key: str) -> Any:
attr = self._key_to_attr(key)
return getattr(self, attr)

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> None:
# Note that all other function that add an entry to the dictionary part of
# the ParameterDict end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
@@ -840,36 +743,37 @@ class ParameterDict(Module):
# call into this function.
self._keys[key] = None
attr = self._key_to_attr(key)
if isinstance(value, Tensor) and not isinstance(value, Parameter):
if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
value = Parameter(value)
setattr(self, attr, value)

def __delitem__(self, key):
def __delitem__(self, key: str) -> None:
del self._keys[key]
attr = self._key_to_attr(key)
delattr(self, attr)

def __len__(self):
def __len__(self) -> int:
return len(self._keys)

def __iter__(self):
def __iter__(self) -> Iterator[str]:
return iter(self._keys)

def __reversed__(self):
def __reversed__(self) -> Iterator[str]:
return reversed(list(self._keys))

def copy(self):
"""Returns a copy of this :class:`nn.ParameterDict` instance.
"""
def copy(self) -> 'ParameterDict':
"""Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
# We have to use an OrderedDict because the ParameterDict constructor
# behaves differently on plain dict vs OrderedDict
return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

def __contains__(self, key):
def __contains__(self, key: str) -> bool:
return key in self._keys

def setdefault(self, key, default = None):
"""If key is in the ParameterDict, return its value.
def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
"""Set the default for a key in the Parameterdict.

If key is in the ParameterDict, return its value.
If not, insert `key` with a parameter `default` and return `default`.
`default` defaults to `None`.

@@ -877,18 +781,16 @@ class ParameterDict(Module):
key (str): key to set default for
default (Any): the parameter set to the key
"""

if key not in self:
self[key] = default
return self[key]

def clear(self):
"""Remove all items from the ParameterDict.
"""
def clear(self) -> None:
"""Remove all items from the ParameterDict."""
for k in self._keys.copy():
del self[k]

def pop(self, key):
def pop(self, key: str) -> Any:
r"""Remove key from the ParameterDict and return its parameter.

Args:
@@ -898,10 +800,8 @@ class ParameterDict(Module):
del self[key]
return v

def popitem(self):
"""Remove and return the last inserted `(key, parameter)` pair
from the ParameterDict
"""
def popitem(self) -> Tuple[str, Any]:
"""Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
k, _ = self._keys.popitem()
# We need the key in the _keys to be able to access/del
self._keys[k] = None
@@ -909,9 +809,8 @@ class ParameterDict(Module):
del self[k]
return k, val

def get(self, key, default = None):
r"""Return the parameter associated with key if present.
Otherwise return default if provided, None if not.
def get(self, key: str, default: Optional[Any] = None) -> Any:
r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

Args:
key (str): key to get from the ParameterDict
@@ -919,42 +818,38 @@ class ParameterDict(Module):
"""
return self[key] if key in self else default

def fromkeys(self, keys, default = None):
r"""Return a new ParameterDict with the keys provided
def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict':
r"""Return a new ParameterDict with the keys provided.

Args:
keys (iterable, string): keys to make the new ParameterDict from
default (Parameter, optional): value to set for all keys
"""
return ParameterDict(((k, default) for k in keys))
return ParameterDict((k, default) for k in keys)

def keys(self):
r"""Return an iterable of the ParameterDict keys.
"""
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ParameterDict keys."""
return self._keys.keys()

def items(self):
r"""Return an iterable of the ParameterDict key/value pairs.
"""
def items(self) -> Iterable[Tuple[str, Any]]:
r"""Return an iterable of the ParameterDict key/value pairs."""
return ((k, self[k]) for k in self._keys)

def values(self):
r"""Return an iterable of the ParameterDict values.
"""
def values(self) -> Iterable[Any]:
r"""Return an iterable of the ParameterDict values."""
return (self[k] for k in self._keys)

def update(self, parameters):
r"""Update the :class:`~nn.ParameterDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None:
r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

.. note::
If :attr:`parameters` is an ``OrderedDict``, a :class:`~nn.ParameterDict`, or
If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.

Args:
parameters (iterable): a mapping (dictionary) from string to
:class:`~nn.Parameter`, or an iterable of
key-value pairs of type (string, :class:`~nn.Parameter`)
:class:`~torch.nn.Parameter`, or an iterable of
key-value pairs of type (string, :class:`~torch.nn.Parameter`)
"""
if not isinstance(parameters, container_abcs.Iterable):
raise TypeError("ParametersDict.update should be called with an "
@@ -980,15 +875,18 @@ class ParameterDict(Module):
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
self[p[0]] = p[1] # type: ignore[assignment]

def extra_repr(self):
def extra_repr(self) -> str:
child_lines = []
for k, p in self.items():
if isinstance(p, Tensor):
if isinstance(p, torch.Tensor):
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
device_str = f' ({p.device})'
else:
device_str = ''
parastr = '{} containing: [{} of size {}{}]'.format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
typename(p), size_str, device_str)
torch.typename(p), size_str, device_str)
child_lines.append(' (' + str(k) + '): ' + parastr)
else:
child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
@@ -998,22 +896,16 @@ class ParameterDict(Module):
def __call__(self, input):
raise RuntimeError('ParameterDict should not be called.')

def __or__(self, other):
def __or__(self, other: 'ParameterDict') -> 'ParameterDict':
copy = self.copy()
copy.update(other)
return copy

def __ror__(self, other):
def __ror__(self, other: 'ParameterDict') -> 'ParameterDict':
copy = other.copy()
copy.update(self)
return copy

def __ior__(self, other):
def __ior__(self, other : 'ParameterDict') -> Self:
self.update(other)
return self

def to_dict(self):
new_dict = {}
for key in self._keys:
new_dict[key] = self[key]
return new_dict
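As a quick sanity check of the dict-style API kept above (setdefault, update and the ``|`` merge operators), here is a minimal sketch written against stock torch.nn.ParameterDict, which this container is meant to mirror; the key names are arbitrary:

import torch
import torch.nn as nn

pd = nn.ParameterDict({"w": nn.Parameter(torch.randn(3, 3))})
pd.setdefault("b", nn.Parameter(torch.zeros(3)))     # inserted only because "b" is missing
pd.update({"w": nn.Parameter(torch.eye(3))})         # overwrites the existing key, order preserved
merged = pd | nn.ParameterDict({"c": nn.Parameter(torch.ones(1))})   # __or__ returns a copy
print(list(merged.keys()))                           # ['w', 'b', 'c']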

+ 4
- 26
mindtorch/torch/nn/modules/linear.py View File

@@ -4,7 +4,7 @@
import math
import mindspore.ops as P
from mindtorch.torch.nn import init
from mindtorch.torch.nn.functional import linear
from mindtorch.torch.nn.functional import linear, bilinear
from mindtorch.torch.functional import empty
from mindtorch.torch.nn.parameter import Parameter
from mindtorch.utils import unsupported_attr
@@ -96,43 +96,21 @@ class Bilinear(Module):
self.in1_features = in1_features
self.in2_features = in2_features
self.out_features = out_features
self.matmul = P.MatMul()
self.mul = P.Mul()
self.tile = P.Tile()
self.reducesum = P.ReduceSum()

self.has_bias = False
self.weight = Parameter(empty((self.out_features, self.in1_features, self.in2_features),
dtype=dtype, device=device), requires_grad=True)
self.bias = None
if bias:
self.bias_add = P.BiasAdd()
self.bias = Parameter(empty(self.out_features, dtype=dtype, device=device), requires_grad=True)
self.has_bias = True
self.reset_parameters()

def reset_parameters(self):
bound = 1 / math.sqrt(self.weight.shape[1])
init.uniform_(self.weight, -bound, bound)
if self.has_bias:
if self.bias is not None:
init.uniform_(self.bias, -bound, bound)

def forward(self, input1, input2):
input1 = cast_to_ms_tensor(input1)
input2 = cast_to_ms_tensor(input2)
input1_shape = input1.shape
input2_shape = input2.shape
if len(input1_shape) != 2:
input1 = input1.reshape((-1, input1_shape[-1]))
x = self.matmul(input1, self.weight.permute(1, 0, 2).reshape(self.weight.shape[1], -1))
if len(input2_shape) != 2:
input2 = input2.reshape((-1, input2_shape[-1]))
x = self.mul(x, self.tile(input2, (1, self.out_features)))
x = x.reshape(x.shape[0], self.out_features, -1)
x = self.reducesum(x, -1)
if self.has_bias:
x = self.bias_add(x, self.bias)
x = x.reshape(*input1_shape[:-1], -1)
return cast_to_adapter_tensor(x)
return bilinear(input1, input2, self.weight, self.bias)

def extra_repr(self):
return 'in1_features={}, in2_features={}, out_features={}, bias={}'.format(
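For reference, the functional call that replaces the hand-rolled matmul/tile/reduce pipeline computes y_o = x1^T W_o x2 + b_o per output feature. A small equivalence sketch against stock torch.nn.functional.bilinear (the adapter's `bilinear` is expected to match this behavior):

import torch
import torch.nn.functional as F

x1 = torch.randn(4, 5)        # (batch, in1_features)
x2 = torch.randn(4, 6)        # (batch, in2_features)
W = torch.randn(7, 5, 6)      # (out_features, in1_features, in2_features)
b = torch.randn(7)

out = F.bilinear(x1, x2, W, b)
ref = torch.einsum('bi,oij,bj->bo', x1, W, x2) + b
print(torch.allclose(out, ref, atol=1e-5))   # True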


+ 1976
- 780
mindtorch/torch/nn/modules/module.py
File diff suppressed because it is too large
View File


+ 29
- 199
mindtorch/torch/nn/parameter.py View File

@@ -11,10 +11,11 @@ from mindspore.common import dtype as mstype
from mindspore._c_expression import Tensor as Tensor_
from mindspore.parallel._ps_context import _is_role_worker, _clone_hash_table
from mindspore.parallel._ps_context import _insert_accumu_init_info
from mindtorch.torch.functional import empty
from mindtorch.torch.tensor import Tensor, cast_to_ms_tensor, cast_to_adapter_tensor
from mindtorch.torch.common.dtype import _msdtype2typeDict

__all__ = ['Parameter', 'ParameterTuple']
__all__ = ['Parameter']

def init_to_value(init):
"""
@@ -36,206 +37,35 @@ def init_to_value(init):
return float(init)
raise ValueError("The argument 'init' should be number or string, but got {}.".format(type(init)))

class Parameter(ms.Parameter):
class Parameter(Tensor):
_base_type = {}
def __new__(cls, data, *args, **kwargs):
init_data_flag = bool(isinstance(data, ms.Tensor) and data.has_init)
rc = sys.getrefcount(data)
input_class, *class_init_args = Parameter._get_parameter_new_args(data, rc)
new_type = Parameter._get_base_class(input_class)
obj = input_class.__new__(new_type)
input_class.__init__(obj, *class_init_args)
obj.init_mode = None
obj.is_default_input_init = init_data_flag
if obj.has_init:
obj.init_mode = data
return obj
def __reduce_ex__(self, _):
data = self
if self.init_mode is not None:
data = self.init_mode
is_leaf = True
retains_grad = False
# def __reduce_ex__(self, _):
# data = self
# if self.init_mode is not None:
# data = self.init_mode
# else:
# # cast to break deep infinite loop while deepcopy
# data = ms.Tensor(self)
# return (
# Parameter, (data, self.requires_grad, self.name, self.layerwise_parallel))
def __init__(self, data, requires_grad=True):
# self.adapter_flag = True
if isinstance(data, Tensor):
super().__init__(data, requires_grad=requires_grad, cast_tensor=True)
else:
# cast to break deep infinite loop while deepcopy
data = ms.Tensor(self)
return (
Parameter, (data, self.requires_grad, self.name, self.layerwise_parallel))

def __init__(self, data, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True):
self.adapter_flag = True
super().__init__(default_input=data, name=name, requires_grad=requires_grad,
layerwise_parallel=layerwise_parallel, parallel_optimizer=parallel_optimizer)

def __deepcopy__(self, memodict):
new_obj = Parameter(self)
new_obj.name = self.name
new_obj._inited_param = self._inited_param
return new_obj
raise ValueError(f'Unsupported data type: {type(data)}.')

def __str__(self):
if self.init_finished:
Tensor_.data_sync(self.data, True)
return f'Parameter containing: {Tensor_.__repr__(self.data)}, requires_grad={self.requires_grad})'

@staticmethod
def _get_base_class(input_class):
input_class_name = Parameter.__name__
if input_class_name in Parameter._base_type:
new_type = Parameter._base_type.get(input_class_name)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
new_type = type(input_class_name, (Parameter, input_class), {})
Parameter._base_type[input_class_name] = new_type
return new_type

@property
def dtype(self):
dtype = super(Parameter, self).dtype
return _msdtype2typeDict.get(str(dtype), dtype)

@property
def data(self):
"""Return the parameter object."""
return self

@data.setter
def data(self, data):
ms_data = cast_to_ms_tensor(data)
self.set_data(ms_data, True)

def _update_tensor_data(self, data):
"""Update the parameter by a Tensor."""
if isinstance(self, ms.Tensor):
self.init_flag = False
self.init = None
return self.assign_value(data)
new_param = Parameter(data, self.name, self.requires_grad)
new_param.param_info = self.param_info
return new_param

@staticmethod
def _from_tensor(tensor, *args, **kwargs):
"""Create a `Parameter` that data is shared from a `Tensor`."""
if not isinstance(tensor, Tensor_):
raise TypeError(f"The type of input must be Tensor, but got {type(tensor)}.")
param = Tensor_.__new__(Parameter)
Tensor_.__init__(param, tensor)
param.init = None
param.init_mode = None
param.is_default_input_init = False
Parameter.__init__(param, tensor, *args, **kwargs)
return param

def requires_grad_(self, requires_grad=True):
self.requires_grad = requires_grad
return self

def detach(self):
return cast_to_adapter_tensor(ms.Parameter.value(self))

def numel(self):
shape = self.shape
return reduce((lambda x, y: x * y), shape) if shape else 1

def nelement(self):
return self.numel()

def item(self):
if self.numel() > 1:
raise ValueError("only one element tensors can be converted to Python scalars")
output = self.asnumpy().reshape(-1).tolist()
return output[0]

def stride(self, dim=None):
bytelen = self.itemsize
output = list(self.strides)
for i in range(len(output)):
output[i] = output[i]//bytelen
output = tuple(output)
if dim is not None:
output = output[dim]
return output

def is_signed(self):
return self.dtype in mstype.signed_type

def is_complex(self):
return self.dtype in mstype.complex_type

def is_floating_point(self):
return self.dtype in [mstype.float32, mstype.float16, mstype.float64]


def _init_parameter_api():
param_func = dir(Parameter)
tensor_dict = Tensor.__dict__

for attr in tensor_dict:
if attr not in param_func:
func = inspect.getattr_static(Tensor, attr)
setattr(Parameter, attr, func)

_init_parameter_api()


class ParameterTuple(tuple):
"""
Inherited from tuple, ParameterTuple is used to save multiple parameter.

Note:
It is used to store the parameters of the network into the parameter tuple collection.
"""
def __new__(cls, iterable):
"""Create instance object of ParameterTuple."""
data = tuple(iterable)
ids = set()
names = set()
for x in data:
if not isinstance(x, Parameter):
raise TypeError(f"For ParameterTuple initialization, "
f"ParameterTuple input should be 'Parameter' collection, "
f"but got a {type(iterable)}. ")
if id(x) not in ids:
if x.name in names:
raise ValueError("The value {} , its name '{}' already exists. "
"Please set a unique name for the parameter.".format(x, x.name))
names.add(x.name)
ids.add(id(x))
return tuple.__new__(ParameterTuple, tuple(data))

def clone(self, prefix, init='same'):
"""
Clone the parameters in ParameterTuple element-wisely to generate a new ParameterTuple.

Args:
prefix (str): Namespace of parameter, the prefix string will be added to the names of parameters
in parametertuple.

init (Union[Tensor, str, numbers.Number]): Clone the shape and dtype of Parameters in ParameterTuple and
set data according to `init`. Default: 'same'.
If `init` is a `Tensor` , set the new Parameter data to the input Tensor.
If `init` is `numbers.Number` , set the new Parameter data to the input number.
If `init` is a `str`, data will be seted according to the initialization method of the same name in
the `Initializer`.
If `init` is 'same', the new Parameter has the same value with the original Parameter.


Returns:
Tuple, the new Parameter tuple.
"""
validator.check_str_by_regular(prefix)
new = []
for x in self:
x1 = x.clone(init)
x1.name = prefix + "." + x1.name
new.append(x1)

if not x1.cache_enable:
continue

if _is_role_worker():
_clone_hash_table(x.name, x.key, x1.name, x1.key)
_insert_accumu_init_info(x1.name, init_to_value(init))
return ParameterTuple(new)
result = type(self)(self.tensor.copy())

def __parameter_tuple__(self):
"""For parse check."""
def __repr__(self):
# if self.init_finished:
# Tensor_.data_sync(self.data, True)
return f'Parameter containing: {self.data}, requires_grad={self.requires_grad}'
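A minimal sketch of the slimmed-down Parameter in use. The import paths are the ones used elsewhere in this PR; it assumes the Tensor base class exposes requires_grad as in stock torch:

from mindtorch.torch.tensor import tensor
from mindtorch.torch.nn.parameter import Parameter

data = tensor([[1.0, 2.0], [3.0, 4.0]])
w = Parameter(data)                              # wraps an existing Tensor; requires_grad defaults to True
frozen = Parameter(data, requires_grad=False)
print(w.requires_grad, frozen.requires_grad)     # True False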

+ 2
- 2
mindtorch/torch/optim/__init__.py View File

@@ -2,8 +2,8 @@
# -*- coding: utf-8 -*-
from mindtorch.torch.optim.optimizer import Optimizer
from mindtorch.torch.optim.sgd import SGD
from mindtorch.torch.optim.adam import Adam
# from mindtorch.torch.optim.adam import Adam
from mindtorch.torch.optim.adamw import AdamW
from mindtorch.torch.optim import lr_scheduler

__all__ = ['Optimizer', 'SGD', 'Adam', 'AdamW']
__all__ = ['Optimizer', 'SGD']#, 'Adam', 'AdamW']

+ 677
- 19
mindtorch/torch/optim/adamw.py View File

@@ -1,26 +1,684 @@
from mindspore.experimental.optim import AdamW as AdamW_MS
from mindtorch.torch.optim.optimizer import _Optimizer, _is_tensor
from mindtorch.torch.tensor import tensor
from mindtorch import torch
from mindtorch.torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable,
_get_scalar_dtype, _capturable_doc, _differentiable_doc,
_foreach_doc, _fused_doc, _maximize_doc,
ParamsT, _view_as_real)
from typing import List, Optional, Tuple, Union

class AdamW(_Optimizer, AdamW_MS):
def __init__(self, *args, **kwargs):
AdamW_MS.__init__(self, *args, **kwargs)
_Optimizer.__init__(self)
__all__ = ["AdamW", "adamw"]


class AdamW(Optimizer):
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
*,
maximize: bool = False,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if isinstance(lr, Tensor) and foreach and not capturable:
raise ValueError("lr as a Tensor is not supported for capturable=False and foreach=True")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
foreach=foreach,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
fused=fused,
)
super().__init__(params, defaults)

# if fused:
# if differentiable:
# raise RuntimeError("`fused` does not support `differentiable`")
# self._step_supports_amp_scaling = True
# # TODO(crcrpar): [low prec params & their higher prec copy]
# # Support AMP with FP16/BF16 model params which would need
# # higher prec copy of params to do update math in higher prec to
# # alleviate the loss of information.
# fused_supported_devices = _get_fused_kernels_supported_devices()
# if not all(
# p.device.type in fused_supported_devices and
# torch.is_floating_point(p)
# for pg in self.param_groups for p in pg['params']
# ):
# raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
# f"supported devices: {fused_supported_devices}.")
# if foreach:
# raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

def __setstate__(self, state):
_Optimizer.__setstate__(self, state)
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
group.setdefault('maximize', False)
group.setdefault("amsgrad", False)
group.setdefault("maximize", False)
group.setdefault("foreach", None)
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
fused = group.setdefault("fused", None)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
p_state["step"] = torch.tensor(float(p_state["step"]), dtype=_get_scalar_dtype(is_fused=fused))

def _init_group(
self,
group,
params_with_grad,
grads,
amsgrad,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
):
has_complex = False
for p in group["params"]:
if p.grad is None:
continue
has_complex |= torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")
grads.append(p.grad)

state = self.state[p]

# State initialization
if len(state) == 0:
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = (
torch.zeros((), dtype=_get_scalar_dtype(is_fused=group["fused"]), device=p.device)
if group["capturable"] or group["fused"]
else torch.tensor(0.0, dtype=_get_scalar_dtype())
)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p#, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p#, memory_format=torch.preserve_format
)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(
p#, memory_format=torch.preserve_format
)

exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])

if group['amsgrad']:
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
if group['differentiable'] and state['step'].requires_grad:
raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')

# Foreach without capturable does not support a tensor lr
if group['foreach'] and isinstance(group['lr'], Tensor) and not group['capturable']:
raise RuntimeError('lr as a Tensor is not supported for capturable=False and foreach=True')

state_steps.append(state["step"])
return has_complex

@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.

Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
# self._cuda_graph_capture_health_check()

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group["amsgrad"]
beta1, beta2 = group["betas"]

has_complex = self._init_group(
group,
params_with_grad,
grads,
amsgrad,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
)

adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
foreach=group["foreach"],
capturable=group["capturable"],
differentiable=group["differentiable"],
fused=group["fused"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
has_complex=has_complex,
)

return loss
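The driver loop above follows the stock torch.optim.AdamW interface. A minimal usage sketch in plain PyTorch (the model and data are placeholders); the adapter is intended to expose the same API:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(3):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()     # gradients come from Tensor.backward, per this PR
    opt.step()          # per-parameter state is lazily initialized on the first step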


AdamW.__doc__ = r"""Implements AdamW algorithm.

.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
\text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
\: \epsilon \text{ (epsilon)} \\
&\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
\: \textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
\text{ (second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\

&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\textbf{if} \: amsgrad \\
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
\widehat{v_t}) \\
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}

For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
""" + fr"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
is not yet supported for all our implementations. Please use a float
LR if you are not also specifying fused=True or capturable=True.
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (bool, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
{_maximize_doc}
{_foreach_doc}
{_capturable_doc}
{_differentiable_doc}
{_fused_doc}
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ

"""


def adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
fused: Optional[bool] = None,
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
has_complex: bool = False,
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
):
r"""Functional API that performs AdamW algorithm computation.

See :class:`~torch.optim.AdamW` for details.
"""
# if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
# raise RuntimeError(
# "API has changed, `state_steps` argument must contain a list of singleton tensors"
# )

# Respect when the user inputs False/True for foreach or fused. We only want to change
# the default when neither have been user-specified. Note that we default to foreach
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
# bake-in time before making it the default, even if it is typically faster.
# if fused is None and foreach is None:
# _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
# # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
# if foreach and isinstance(lr, Tensor) and not capturable:
# foreach = False
# if fused is None:
# fused = False
# if foreach is None:
# foreach = False

# if foreach and torch.jit.is_scripting():
# raise RuntimeError("torch.jit.script not supported with foreach optimizers")
# if fused and torch.jit.is_scripting():
# raise RuntimeError("torch.jit.script not supported with fused optimizers")

# if fused and not torch.jit.is_scripting():
# func = _fused_adamw
# elif foreach and not torch.jit.is_scripting():
# func = _multi_tensor_adamw
# else:
func = _single_tensor_adamw

func(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,
grad_scale=grad_scale,
found_inf=found_inf,
has_complex=has_complex,
)


def _single_tensor_adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[Tensor, float],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
has_complex: bool,
):

assert grad_scale is None and found_inf is None

# if torch.jit.is_scripting():
# # this assert is due to JIT being dumb and not realizing that the ops below
# # have overloads to handle both float and Tensor lrs, so we just assert it's
# # a float since most people using JIT are using floats
# assert isinstance(lr, float)

for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]

# # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
# if not torch._utils.is_compiling() and capturable:
# assert (
# (param.is_cuda and step_t.is_cuda) or (param.is_xla and step_t.is_xla)
# ), "If capturable=True, params and state_steps must be CUDA or XLA tensors."

if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
exp_avg_sq = torch.view_as_real(exp_avg_sq)
if amsgrad:
max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
param = torch.view_as_real(param)

# update step
step_t += 1

# Perform stepweight decay
param.mul_(torch.tensor(1 - lr * weight_decay))

# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

if capturable or differentiable:
step = step_t

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

step_size = lr / bias_correction1
step_size_neg = step_size.neg()

bias_correction2_sqrt = bias_correction2.sqrt()

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
if differentiable:
max_exp_avg_sq = max_exp_avg_sqs[i].clone()
else:
max_exp_avg_sq = max_exp_avg_sqs[i]

max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))

# Uses the max. for normalizing running avg. of gradient
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
denom = (
max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)
else:
denom = (
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
).add_(eps / step_size_neg)

param.addcdiv_(exp_avg, denom)
# else:
# step = _get_value(step_t)

# bias_correction1 = 1 - beta1 ** step
# bias_correction2 = 1 - beta2 ** step

# step_size = lr / bias_correction1

# bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

# if amsgrad:
# # Maintains the maximum of all 2nd moment running avg. till now
# torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

# # Use the max. for normalizing running avg. of gradient
# denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
# else:
# denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

# param.addcdiv_(exp_avg, denom, value=-step_size)

# Lastly, switch back to complex view
if amsgrad and torch.is_complex(params[i]):
max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])


def _multi_tensor_adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[Tensor, float],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool,
differentiable: bool,
has_complex: bool,
):
if len(params) == 0:
return

if isinstance(lr, Tensor) and not capturable:
raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")

# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert all(
p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
), "If capturable=True, params and state_steps must be CUDA tensors."

assert not differentiable, "_foreach ops don't support autograd"

assert grad_scale is None and found_inf is None

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for ((
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs,
device_state_steps,
), _) in grouped_tensors.values():
if maximize:
device_grads = torch._foreach_neg(device_grads)

if has_complex:
if amsgrad:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs)
else:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to ensure we go to the right overload.
if device_state_steps[0].is_cpu:
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device='cpu'), alpha=1.0)
else:
torch._foreach_add_(device_state_steps, 1)

# Perform stepweight decay
if weight_decay != 0:
torch._foreach_mul_(device_params, 1 - lr * weight_decay)

# Decay the first and second moment running average coefficient
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)

torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

# Delete the local intermediate since it won't be used anymore to save on peak memory
del device_grads

if capturable:
bias_correction1 = torch._foreach_pow(beta1, device_state_steps)
bias_correction2 = torch._foreach_pow(beta2, device_state_steps)
# foreach_sub doesn't allow a scalar as the first arg
torch._foreach_sub_(bias_correction1, 1)
torch._foreach_sub_(bias_correction2, 1)
# we do not negate bias_correction1 as it'll need to be negated later anyway
torch._foreach_neg_(bias_correction2)

# foreach_div doesn't allow a scalar as the first arg
torch._foreach_div_(bias_correction1, lr)
torch._foreach_reciprocal_(bias_correction1)

torch._foreach_sqrt_(bias_correction2)

# Re-assign for clarity as we maintain minimal intermediates: we'll have
# step_size = - lr / (1 - beta1 ^ t) where t = num_steps
# bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
step_size = bias_correction1
bias_correction2_sqrt = bias_correction2

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

# Use the max. for normalizing running avg. of gradient
exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_div_(exp_avg_sq_sqrt, step_size)

# at this point, exp_avg_sq_sqrt = - (1 - beta1^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
else:
bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

# Use the max. for normalizing running avg. of gradient
exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)

torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size)


def _fused_adamw(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
max_exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
amsgrad: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
capturable: bool, # Needed for consistency.
differentiable: bool,
has_complex: bool,
) -> None:
if not params:
return
if differentiable:
raise RuntimeError("Adam with fused=True does not support differentiable=True")

state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and _is_tensor(state_values[0]['step'])
if not step_is_tensor:
for s in state_values:
s['step'] = tensor(float(s['step']))
grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

def state_dict(self):
return super()._ms_state_dict('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', 'state_step')
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
# treating it as a scalar.
lr_dict = {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None

def load_state_dict(self, state_dict):
return super()._ms_load_state_dict(state_dict, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', 'state_step')
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for (device, _), ((device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs,
device_state_steps,), _) in grouped_tensors.items():
device_grad_scale, device_found_inf = None, None
if grad_scale is not None:
if device not in grad_scale_dict:
grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
device_grad_scale = grad_scale_dict[device]
if found_inf is not None:
if device not in found_inf_dict:
found_inf_dict[device] = found_inf.to(device, non_blocking=True)
device_found_inf = found_inf_dict[device]
if lr_dict is not None and device not in lr_dict:
lr_dict[device] = lr.to(device=device, non_blocking=True)
lr = lr_dict[device]
torch._foreach_add_(device_state_steps, 1)
torch._fused_adamw_(
device_params,
device_grads,
device_exp_avgs,
device_exp_avg_sqs,
device_max_exp_avg_sqs,
device_state_steps,
amsgrad=amsgrad,
lr=lr,
beta1=beta1,
beta2=beta2,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
grad_scale=device_grad_scale,
found_inf=device_found_inf,
)
if device_found_inf is not None:
torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))

+ 701
- 154
mindtorch/torch/optim/optimizer.py View File

@@ -1,95 +1,495 @@
import abc
import math
import functools
import warnings
from collections import OrderedDict, defaultdict
from collections.abc import Iterable
from copy import deepcopy
from itertools import chain
import mindspore as ms
from mindspore.experimental.optim import Optimizer as Optimizer_MS
from mindtorch.torch.tensor import Tensor, tensor, cast_to_ms_tensor
from mindtorch.utils import unsupported_attr

class _RequiredParameter():
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import ParamSpec, Self, TypeAlias

from mindtorch import torch
import mindtorch.torch.utils.hooks as hooks
from mindtorch.torch.utils.hooks import RemovableHandle
# from mindtorch.torch.utils._foreach_utils import (
# Indices,
# TensorListList,
# _get_foreach_kernels_supported_devices,
# _get_fused_kernels_supported_devices,
# )
# from mindtorch.torch._utils import is_compiling
# from mindtorch.torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]

GlobalOptimizerPreHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]

class _RequiredParameter:
"""Singleton class representing a required parameter for an Optimizer."""
def __repr__(self):
def __repr__(self) -> str:
return "<required parameter>"

required = _RequiredParameter()

class _Optimizer:
def __init__(self):
self._optimizer_step_pre_hooks=OrderedDict()
self._optimizer_step_post_hooks=OrderedDict()

self._patch_step_function()
def _use_grad_for_differentiable(func):
def _use_grad(self, *args, **kwargs):
import torch._dynamo
prev_grad = torch.is_grad_enabled()
try:
# Note on graph break below:
# we need to graph break to ensure that aot respects the no_grad annotation.
# This is important for perf because without this, functionalization will generate an epilogue
# which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result,
# inductor will allocate for every parameter in the model, which is horrible.
# With this, aot correctly sees that this is an inference graph, and functionalization will generate
# an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that
# step is in place and is able to avoid the extra allocation.
# In the future, we will either 1) continue to graph break on backward, so this graph break does not matter
# or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
# graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
# see https://github.com/pytorch/pytorch/issues/104053
torch.set_grad_enabled(self.defaults['differentiable'])
torch._dynamo.graph_break()
ret = func(self, *args, **kwargs)
finally:
torch._dynamo.graph_break()
torch.set_grad_enabled(prev_grad)
return ret
functools.update_wrapper(_use_grad, func)
return _use_grad

def _view_as_real(params, *state_and_grads):
for i, p in enumerate(params):
if torch.is_complex(p):
params[i] = torch.view_as_real(params[i])
for s in state_and_grads:
s[i] = torch.view_as_real(s[i])

def _get_scalar_dtype(is_fused=None):
if is_fused:
return torch.float32
return torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32

# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
is used. If unspecified by the user (so foreach is None), we will try to use
foreach over the for-loop implementation on CUDA, since it is usually
significantly more performant. Note that the foreach implementation uses
~ sizeof(params) more peak memory than the for-loop version due to the intermediates
being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
parameters through the optimizer at a time or switch this flag to False (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
are supported. (default: None)

.. note:: The foreach and fused implementations are typically faster than the for-loop,
single-tensor implementation. Thus, if the user has not specified BOTH flags
(i.e., when foreach = fused = None), we will attempt defaulting to the foreach
implementation when the tensors are all on CUDA. For example, if the user specifies
True for fused but nothing for foreach, we will run the fused implementation. If
the user specifies False for foreach but nothing for fused (or False for fused but
nothing for foreach), we will run the for-loop implementation. If the user specifies
True for both foreach and fused, we will prioritize fused over foreach, as it is
typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
we want to give it sufficient bake-in time, so we default to foreach and NOT
fused when the user has not specified either flag."""

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
capture in a CUDA graph. Passing True can impair ungraphed performance,
so if you don't intend to graph capture this instance, leave it False
(default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
occur through the optimizer step in training. Otherwise, the step()
function runs in a torch.no_grad() context. Setting to True can impair
performance, so leave it False if you don't intend to run autograd
through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
params, instead of minimizing (default: False)"""


def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
r"""Register a pre hook common to all optimizers. The hook should have the following
signature::

hook(optimizer, args, kwargs) -> None or modified args and kwargs

Args:
hook (Callable): A user defined hook which is registered on all optimizers.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
_global_optimizer_pre_hooks[handle.id] = hook
return handle


def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
r"""Register a post hook common to all optimizers. The hook should have the following
signature::

hook(optimizer, args, kwargs) -> None

Args:
hook (Callable): A user defined hook which is registered on all optimizers.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
_global_optimizer_post_hooks[handle.id] = hook
return handle
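A hedged usage sketch of these module-level hooks, written against stock PyTorch, which ships the same register_optimizer_step_pre_hook / register_optimizer_step_post_hook entry points:

import torch
from torch.optim.optimizer import register_optimizer_step_pre_hook

def log_step(optimizer, args, kwargs):
    print(f"{type(optimizer).__name__}.step about to run")
    return None                              # or a (new_args, new_kwargs) tuple

handle = register_optimizer_step_pre_hook(log_step)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)
opt.step()                                   # the hook fires before every optimizer's step
handle.remove()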

ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]

_P = ParamSpec("_P")
R = TypeVar("R")
T = TypeVar("T")


class Optimizer:
r"""Base class for all optimizers.

.. warning::
Parameters need to be specified as collections that have a deterministic
ordering that is consistent between runs. Examples of objects that don't
satisfy those properties are sets and iterators over values of dictionaries.

def _is_inner_optimizer(self):
return True
Args:
params (iterable): an iterable of :class:`torch.Tensor` s or
:class:`dict` s. Specifies what Tensors should be optimized.
defaults: (dict): a dict containing default values of optimization
options (used when a parameter group doesn't specify them).
"""

OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]

_optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'

def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
self.defaults = defaults
self._optimizer_step_pre_hooks = OrderedDict()
self._optimizer_step_post_hooks = OrderedDict()
self._optimizer_state_dict_pre_hooks = OrderedDict()
self._optimizer_state_dict_post_hooks = OrderedDict()
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
self._optimizer_load_state_dict_post_hooks = OrderedDict()


if isinstance(params, torch.Tensor):
if self.__class__.__name__ == 'SparseAdam':
warnings.warn(("Passing in a raw Tensor as ``params`` to SparseAdam "
"is deprecated. In the future, this will raise an error. "
"Please wrap your Tensor in an iterable instead."),
FutureWarning)
params = [params]
else:
raise TypeError("params argument given to the optimizer should be "
"an iterable of Tensors or dicts, but got " +
torch.typename(params))

self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: List[Dict[str, Any]] = []

param_groups = list(params)
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}]

for param_group in param_groups:
self.add_param_group(cast(dict, param_group))

# Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
# which I don't think exists
# https://github.com/pytorch/pytorch/issues/72948
self._warned_capturable_if_run_uncaptured = True
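For context, the params handling above accepts either a flat iterable of tensors or a list of group dicts. A stock-PyTorch sketch of the per-group override pattern (the constructor contract is the same here):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
opt = torch.optim.SGD(
    [
        {"params": model[0].parameters()},               # falls back to the default lr below
        {"params": model[1].parameters(), "lr": 1e-3},   # per-group override
    ],
    lr=1e-2, momentum=0.9,
)
print([g["lr"] for g in opt.param_groups])               # [0.01, 0.001]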

def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return {
'defaults': self.defaults,
'state': self.state,
'param_groups': self.param_groups,
}
def __setstate__(self, state):

def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
if '_optimizer_step_pre_hooks' not in self.__dict__:
self._optimizer_step_pre_hooks = OrderedDict()
if '_optimizer_step_post_hooks' not in self.__dict__:
self._optimizer_step_post_hooks = OrderedDict()
self._patch_step_function()
if '_optimizer_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_state_dict_pre_hooks = OrderedDict()
if '_optimizer_state_dict_post_hooks' not in self.__dict__:
self._optimizer_state_dict_post_hooks = OrderedDict()
if '_optimizer_load_state_dict_pre_hooks' not in self.__dict__:
self._optimizer_load_state_dict_pre_hooks = OrderedDict()
if '_optimizer_load_state_dict_post_hooks' not in self.__dict__:
self._optimizer_load_state_dict_post_hooks = OrderedDict()
self.defaults.setdefault('differentiable', False)

def __repr__(self):
def __repr__(self) -> str:
format_string = self.__class__.__name__ + ' ('
for i, group in enumerate(self.param_groups):
format_string += '\n'
format_string += 'Parameter Group {0}\n'.format(i)
format_string += f'Parameter Group {i}\n'
for key in sorted(group.keys()):
if key != 'params':
format_string += ' {0}: {1}\n'.format(key, group[key])
format_string += f' {key}: {group[key]}\n'
format_string += ')'
return format_string

def _optimizer_step_code(self) -> None:
"""Entry point for `torch.profile.profiler`.

When python tracing is enabled the profiler will hook into this
function at the CPython level to inspect the optimizer's parameters and
param groups. It is called after `step()` since many optimizers
lazily initialize state.

This is a workaround for the lack of a proper step hook on the optimizer,
and will be removed once one exists.
"""
pass

@staticmethod
def profile_hook_step(func):
unsupported_attr(func)
raise NotImplementedError("For Optimizer, 'profile_hook_step' not support yet.")
def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:

@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
self, *_ = args
self = cast(Optimizer, self)
profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
with torch.autograd.profiler.record_function(profile_name):
# call optimizer step pre hooks
for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
result = pre_hook(self, args, kwargs)
if result is not None:
if isinstance(result, tuple) and len(result) == 2:
args, kwargs = result # type: ignore[assignment]
else:
raise RuntimeError(
f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
)

out = func(*args, **kwargs)
self._optimizer_step_code()

# call optimizer step post hooks
for post_hook in chain(self._optimizer_step_post_hooks.values(), _global_optimizer_post_hooks.values()):
post_hook(self, args, kwargs)

return out

return wrapper

def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
r"""Register an optimizer step pre hook which will be called before
optimizer step. It should have the following signature::

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The ``optimizer`` argument is the optimizer instance being used. If
args and kwargs are modified by the pre-hook, then the transformed
values are returned as a tuple containing the new_args and new_kwargs.

def _patch_step_function(self):
self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
# hook not support yet.
# hooked = getattr(self.__class__.step, "hooked", None)
# if not hooked:
# self.__class__.step = self.profile_hook_step(self.__class__.step)
# self.__class__.step.hooked = True
Args:
hook (Callable): The user defined hook to be registered.

def register_step_pre_hook(self):
raise NotImplementedError("For optimizer, 'register_step_pre_hook' is not supported yet.")
Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
self._optimizer_step_pre_hooks[handle.id] = hook
return handle

def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
r"""Register an optimizer step post hook which will be called after optimizer step.
It should have the following signature::

hook(optimizer, args, kwargs) -> None

The ``optimizer`` argument is the optimizer instance being used.

def register_step_post_hook(self):
raise NotImplementedError("For optimizer, 'register_step_post_hook' is not supported yet.")
Args:
hook (Callable): The user defined hook to be registered.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
self._optimizer_step_post_hooks[handle.id] = hook
return handle
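A short sketch of the per-instance step hooks registered above (stock PyTorch shown; the signatures match the ported docstrings):

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(3))], lr=0.1)

def before(optimizer, args, kwargs):
    print("about to step")

def after(optimizer, args, kwargs):
    print("stepped")

h1 = opt.register_step_pre_hook(before)
h2 = opt.register_step_post_hook(after)
opt.step()           # prints "about to step" then "stepped"
h1.remove()
h2.remove()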

def state_dict(self):

def register_state_dict_pre_hook(
self, hook: Callable[["Optimizer"], None], prepend: bool = False
) -> RemovableHandle:
r"""Register a state dict pre-hook which will be called before
:meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
following signature::

hook(optimizer) -> None

The ``optimizer`` argument is the optimizer instance being used.
The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
The registered hook can be used to perform pre-processing before the ``state_dict``
call is made.

Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided pre ``hook`` will be fired before
all the already registered pre-hooks on ``state_dict``. Otherwise,
the provided ``hook`` will be fired after all the already registered
pre-hooks. (default: False)

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
self._optimizer_state_dict_pre_hooks[handle.id] = hook
if prepend:
self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
return handle


def register_state_dict_post_hook(
self,
hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
prepend: bool = False,
) -> RemovableHandle:
r"""Register a state dict post-hook which will be called after
:meth:`~torch.optim.Optimizer.state_dict` is called. It should have the
following signature::

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments ``self`` and ``state_dict`` after generating
a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
return a new one. The registered hook can be used to perform post-processing
on the ``state_dict`` before it is returned.

Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided post ``hook`` will be fired before
all the already registered post-hooks on ``state_dict``. Otherwise,
the provided ``hook`` will be fired after all the already registered
post-hooks. (default: False)

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
self._optimizer_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
return handle
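A minimal sketch of a state-dict post-hook (requires a stock PyTorch recent enough to ship these hooks; the extra "note" key is purely illustrative):

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(2))], lr=0.1)

def tag_state_dict(optimizer, state_dict):
    state_dict["note"] = "exported by post-hook"   # may mutate in place or return a new dict
    return state_dict

opt.register_state_dict_post_hook(tag_state_dict)
print("note" in opt.state_dict())                  # True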

def state_dict(self) -> StateDict:
r"""Returns the state of the optimizer as a :class:`dict`.

It contains two entries:

* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups where each
parameter group is a dict
* ``state``: a Dict holding current optimization state. Its content
differs between optimizer classes, but some common characteristics
hold. For example, state is saved per parameter, and the parameter
itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
to a Dict with state corresponding to each parameter.
* ``param_groups``: a List containing all parameter groups where each
parameter group is a Dict. Each parameter group contains metadata
specific to the optimizer, such as learning rate and weight decay,
as well as a List of parameter IDs of the parameters in the group.

NOTE: The parameter IDs may look like indices but they are just IDs
associating state with param_group. When loading from a state_dict,
the optimizer will zip the param_group ``params`` (int IDs) and the
optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
match state WITHOUT additional verification.

A returned state dict might look something like:

.. code-block:: text

{
'state': {
0: {'momentum_buffer': tensor(...), ...},
1: {'momentum_buffer': tensor(...), ...},
2: {'momentum_buffer': tensor(...), ...},
3: {'momentum_buffer': tensor(...), ...}
},
'param_groups': [
{
'lr': 0.01,
'weight_decay': 0,
...
'params': [0]
},
{
'lr': 0.001,
'weight_decay': 0.5,
...
'params': [1, 2, 3]
}
]
}

"""

for pre_hook in self._optimizer_state_dict_pre_hooks.values():
pre_hook(self)

# Save order indices instead of Tensors
param_mappings = {}
param_mappings: Dict[int, int] = {}
start_index = 0

def pack_group(group):
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
if 'lr' in packed.keys():
if isinstance(packed['lr'], ms.Tensor):
packed['lr'] = packed['lr'].asnumpy().tolist()
param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
@@ -97,25 +497,149 @@ class _Optimizer:
return packed
param_groups = [pack_group(g) for g in self.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, Tensor) else k): v
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()}
return {

state_dict = {
'state': packed_state,
'param_groups': param_groups,
}

def load_state_dict(self, state_dict):
for post_hook in self._optimizer_state_dict_post_hooks.values():
hook_result = post_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result
return state_dict
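The packing above replaces parameter objects with integer ids, so a snapshot can be re-applied to a fresh optimizer over the same parameters. A stock-PyTorch round-trip sketch:

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(4, 2)).sum().backward()
opt.step()                                   # populates per-parameter momentum state

snapshot = opt.state_dict()                  # {'state': {0: ..., 1: ...}, 'param_groups': [...]}
opt2 = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt2.load_state_dict(snapshot)               # ids are re-zipped onto opt2's param_groups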

@staticmethod
def _process_value_according_to_param_policy(
param: torch.Tensor,
value: torch.Tensor,
param_id: int,
param_groups: List[Dict[Any, Any]],
key: Hashable = None,
) -> torch.Tensor:
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
# UNLESS fused or capturable, see note [special device hosting for step]
fused = False
capturable = False
assert param_groups is not None
for pg in param_groups:
if param_id in pg["params"]:
fused = pg["fused"] if "fused" in pg else False
capturable = pg["capturable"] if "capturable" in pg else False
break

if key == 'step':
if capturable or fused:
return value.to(dtype=torch.float32, device=param.device)
else:
return value
else:
if param.is_floating_point():
return value.to(dtype=param.dtype, device=param.device)
else:
return value.to(device=param.device)


def register_load_state_dict_pre_hook(
self,
hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
prepend: bool = False,
) -> RemovableHandle:
r"""Register a load_state_dict pre-hook which will be called before
:meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
following signature::

hook(optimizer, state_dict) -> state_dict or None

The ``optimizer`` argument is the optimizer instance being used and the
``state_dict`` argument is a shallow copy of the ``state_dict`` the user
passed in to ``load_state_dict``. The hook may modify the state_dict inplace
        or optionally return a new one. If a state_dict is returned, it will be
        loaded into the optimizer.

The hook will be called with argument ``self`` and ``state_dict`` before
calling ``load_state_dict`` on ``self``. The registered hook can be used to
perform pre-processing before the ``load_state_dict`` call is made.

Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided pre ``hook`` will be fired before
all the already registered pre-hooks on ``load_state_dict``. Otherwise,
the provided ``hook`` will be fired after all the already registered
pre-hooks. (default: False)

Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
return handle
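
    # Illustrative sketch (``opt`` and ``force_lr`` are placeholder names): a
    # pre-hook can rewrite the incoming dict before it is validated, for example
    # forcing a new learning rate onto every saved group.
    #
    #     def force_lr(optimizer, state_dict):
    #         for group in state_dict['param_groups']:
    #             group['lr'] = 0.01
    #         return state_dict
    #
    #     opt.register_load_state_dict_pre_hook(force_lr)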


def register_load_state_dict_post_hook(
self, hook: Callable[["Optimizer"], None], prepend: bool = False
) -> RemovableHandle:
r"""Register a load_state_dict post-hook which will be called after
:meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
following signature::

hook(optimizer) -> None

The ``optimizer`` argument is the optimizer instance being used.

The hook will be called with argument ``self`` after calling
``load_state_dict`` on ``self``. The registered hook can be used to
perform post-processing after ``load_state_dict`` has loaded the
``state_dict``.

Args:
hook (Callable): The user defined hook to be registered.
prepend (bool): If True, the provided post ``hook`` will be fired before
all the already registered post-hooks on ``load_state_dict``. Otherwise,
the provided ``hook`` will be fired after all the already registered
post-hooks. (default: False)

Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
return handle



def load_state_dict(self, state_dict: StateDict) -> None:
r"""Loads the optimizer state.

Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# shallow copy, to be consistent with module API
state_dict = state_dict.copy()

for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
hook_result = pre_hook(self, state_dict)
if hook_result is not None:
state_dict = hook_result

# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']

# Deepcopy as we write into saved_groups later to update state
saved_groups = deepcopy(state_dict['param_groups'])

if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
@@ -127,133 +651,156 @@ class _Optimizer:
"that doesn't match the size of optimizer's group")

# Update the state
id_map = dict(zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups))))
id_map = dict(zip(chain.from_iterable(g['params'] for g in saved_groups),
chain.from_iterable(g['params'] for g in groups)))

def cast(param, value, key=None):
def _cast(param, value, param_id=None, param_groups=None, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
return value
if isinstance(value, torch.Tensor):
return Optimizer._process_value_according_to_param_policy(param, value, param_id, param_groups, key)
elif isinstance(value, dict):
return {k: cast(param, v, key=k) for k, v in value.items()}
return {k: _cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(cast(param, v) for v in value)
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
else:
return value

# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
state[param] = _cast(param, v, param_id=k, param_groups=state_dict['param_groups'])
else:
state[k] = v

# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
def update_group(group: Dict[str, Any], new_group: Dict[str, Any]) -> Dict[str, Any]:
new_group['params'] = group['params']
if 'lr' in group.keys():
if isinstance(group['lr'], ms.Parameter):
new_group['lr'] = ms.Parameter(ms.Tensor(new_group['lr'], ms.float32), group['lr'].name)
return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})

def _ms_state_dict(self, *ms_params_name):
_state_dict = _Optimizer.state_dict(self)
def _save(ms_params):
if isinstance(ms_params, Iterable):
_state = []
for p in ms_params:
_state.append(_save(p))
else:
_state = tensor(ms_params.asnumpy())
return _state

for name in ms_params_name:
ms_params = getattr(self, name, None)
if ms_params is not None:
_state_dict[name] = _save(ms_params)
return _state_dict

def _ms_load_state_dict(self, state_dict, *ms_params_name):
_Optimizer.load_state_dict(self, state_dict)

def _load(ms_params, state_tensor, name):
if isinstance(ms_params, Iterable):
if not isinstance(state_tensor, Iterable):
raise ValueError(f"state_dict of ms_param '{name}' is not correct. please check. "
f"(ms_param '{name}' is Iterable, but state_dict['{name}'] is not.)")
if len(ms_params) != len(state_tensor):
raise ValueError(f"state_dict of ms_param '{name}' is not correct. please check. "
f"(length of ms_param '{name}' and state_dict['{name}'] are not equal, "
f"get {len(ms_params)} and {len(state_tensor)}")
for i, _ in enumerate(ms_params):
_load(ms_params[i], state_tensor[i], name)
else:
_data = cast_to_ms_tensor(state_tensor)
try:
ms_params.set_data(_data)
except Exception as e:
raise ValueError(f"state_dict of ms_param '{name}' is not correct. please check. "
f"({e})") from e

for name in ms_params_name:
ms_params = getattr(self, name, None)
if ms_params is None:
continue
_params = state_dict.get(name, None)
# If name in state_dict, use state_dict[name], because it was saved from MindTorch.
if _params is not None:
_load(ms_params, _params, name)
else:
_state = state_dict.get('state', None)
# If name in state_dict['state'], it was saved from PyTorch. Load that to MindTorch.
if _state is not None:
# _state is a dict like: {0:{name: Tensor}, 1:{name:Tensor}}
for k, state in _state.items():
_params = state.get(name, None)
# assert name in state.
if _params is not None:
_load(ms_params[k], _params, name)

def step(self, grads, closure=None):
loss = None
if closure is not None:
loss = closure()
self.construct(grads)
return loss

class _OptimizerMeta(abc.ABCMeta, type(Optimizer_MS)):
"""
Meta class for Optimizer. Used internally.
"""
for post_hook in self._optimizer_load_state_dict_post_hooks.values():
post_hook(self)


class Optimizer(_Optimizer, Optimizer_MS, metaclass=_OptimizerMeta):
def __init__(self, *args, **kwargs):
Optimizer_MS.__init__(self, *args, **kwargs)
_Optimizer.__init__(self)

@classmethod
def __subclasshook__(cls, sub):
def zero_grad(self, set_to_none: bool = True) -> None:
r"""Resets the gradients of all optimized :class:`torch.Tensor` s.

Args:
set_to_none (bool): instead of setting to zero, set the grads to None.
This will in general have lower memory footprint, and can modestly improve performance.
However, it changes certain behaviors. For example:
1. When the user tries to access a gradient and perform manual ops on it,
a None attribute or a Tensor full of 0s will behave differently.
2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
are guaranteed to be None for params that did not receive a gradient.
3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
(in one case it does the step with a gradient of 0 and in the other it skips
the step altogether).
"""
Subclass with _is_inner_optimizer attr will be instance of Optimizer
foreach = self.defaults.get('foreach', False) or self.defaults.get('fused', False)

per_device_and_dtype_grads: Optional[DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]]
if foreach:
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
else:
per_device_and_dtype_grads = None

for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
if set_to_none:
p.grad = None
else:
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
if (not foreach or p.grad.is_sparse):
p.grad.zero_()
else:
assert per_device_and_dtype_grads is not None
per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
if foreach:
assert per_device_and_dtype_grads is not None
for per_dtype_grads in per_device_and_dtype_grads.values():
for grads in per_dtype_grads.values():
torch._foreach_zero_(grads)
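
    # Typical use in a training step (sketch only; ``model``, ``loss_fn``,
    # ``optimizer``, ``x`` and ``y`` are placeholders), mirroring the loop in
    # testing/st/mindtorch/test_simple_linear.py:
    #
    #     loss = loss_fn(model(x), y)
    #     loss.backward()
    #     optimizer.step()
    #     optimizer.zero_grad(set_to_none=True)   # grads become None, not zeros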

@overload
def step(self, closure: None = ...) -> None:
...

@overload
def step(self, closure: Callable[[], float]) -> float:
...

def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
r"""Performs a single optimization step (parameter update).

Args:
closure (Callable): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.

.. note::
Unless otherwise specified, this function should not modify the
``.grad`` field of the parameters.
"""
raise NotImplementedError


def add_param_group(self, param_group: Dict[str, Any]) -> None:
r"""Add a param group to the :class:`Optimizer` s `param_groups`.

This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.

Args:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options.
"""
if cls is Optimizer:
if any("_is_inner_optimizer" in s.__dict__ for s in sub.__mro__):
return True
return NotImplemented
if not isinstance(param_group, dict):
raise TypeError(f"param_group must be a dict, but got {type(param_group)}")

params = param_group['params']
if isinstance(params, torch.Tensor):
param_group['params'] = [params]
elif isinstance(params, set):
raise TypeError('optimizer parameters need to be organized in ordered collections, but '
'the ordering of tensors in sets will change between runs. Please use a list instead.')
else:
param_group['params'] = list(params)

for param in param_group['params']:
if not isinstance(param, torch.Tensor):
raise TypeError("optimizer can only optimize Tensors, "
"but one of the params is " + torch.typename(param))
if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad):
raise ValueError("can't optimize a non-leaf Tensor")

for name, default in self.defaults.items():
if default is required and name not in param_group:
raise ValueError(f"parameter group didn't specify a value of required optimization parameter {name}")
else:
param_group.setdefault(name, default)

params = param_group['params']
if len(params) != len(set(params)):
warnings.warn("optimizer contains a parameter group with duplicate parameters; "
"in future, this will cause an error; "
"see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

param_set: Set[torch.Tensor] = set()
for group in self.param_groups:
param_set.update(set(group['params']))

if not param_set.isdisjoint(set(param_group['params'])):
raise ValueError("some parameters appear in more than one parameter group")

def _is_tensor(obj):
return isinstance(obj, Tensor)
self.param_groups.append(param_group)
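
    # Illustrative sketch (``optimizer`` and ``new_layer`` are placeholders):
    # unfreezing an extra layer mid-training and giving it its own hyperparameters.
    #
    #     optimizer.add_param_group({'params': new_layer.parameters(), 'lr': 1e-4})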

+ 257
- 25
mindtorch/torch/optim/sgd.py View File

@@ -1,32 +1,264 @@
from mindspore.experimental.optim import SGD as SGD_MS
from mindtorch.torch.optim.optimizer import _Optimizer
from mindtorch.utils import unsupported_attr

_default_lr = 0.01
class SGD(_Optimizer, SGD_MS):
def __init__(self, params, lr=None, momentum=0, dampening=0,
weight_decay=0, nesterov=False, *, maximize=False, foreach=None,
differentiable=False):
unsupported_attr(foreach)
unsupported_attr(differentiable)
if lr is None:
for p_dict in params:
if not isinstance(p_dict, dict) or 'lr' not in p_dict:
raise ValueError("parameter group didn't specify a value of required optimization parameter lr.")
# Fake lr. The above code guarantees that every param_group has its own 'lr' setting.
# So the following _default_lr won't take effect, just for the input args of mindspore SGD.
lr = _default_lr
SGD_MS.__init__(self, params, lr, momentum, dampening, weight_decay, nesterov, maximize=maximize)
_Optimizer.__init__(self)
from mindtorch import torch
from mindtorch.torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable,
_differentiable_doc, _foreach_doc, _maximize_doc)
from typing import List, Optional

__all__ = ['SGD', 'sgd']


class SGD(Optimizer):
def __init__(self, params, lr=1e-3, momentum=0, dampening=0,
weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
differentiable: bool = False):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov,
maximize=maximize, foreach=foreach,
differentiable=differentiable)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)

def __setstate__(self, state):
_Optimizer.__setstate__(self, state)
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
group.setdefault('maximize', False)
group.setdefault('foreach', None)
group.setdefault('differentiable', False)

def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
has_sparse_grad = False

for p in group['params']:
# print(p.grad, p.requires_grad)
if p.grad is not None:
params_with_grad.append(p)
d_p_list.append(p.grad)
if p.grad.is_sparse:
has_sparse_grad = True

state = self.state[p]
if 'momentum_buffer' not in state:
momentum_buffer_list.append(None)
else:
momentum_buffer_list.append(state['momentum_buffer'])

return has_sparse_grad

@_use_grad_for_differentiable
def step(self, closure=None):
"""Performs a single optimization step.

Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params_with_grad = []
d_p_list = []
momentum_buffer_list = []

has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

sgd(params_with_grad,
d_p_list,
momentum_buffer_list,
weight_decay=group['weight_decay'],
momentum=group['momentum'],
lr=group['lr'],
dampening=group['dampening'],
nesterov=group['nesterov'],
maximize=group['maximize'],
has_sparse_grad=has_sparse_grad,
foreach=group['foreach'])

# update momentum_buffers in state
for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
state = self.state[p]
state['momentum_buffer'] = momentum_buffer

return loss


SGD.__doc__ = r"""Implements stochastic gradient descent (optionally with momentum).

.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
\text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
&\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
\:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
&\hspace{10mm}\textbf{if} \: t > 1 \\
&\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
&\hspace{10mm}\textbf{else} \\
&\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
&\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
&\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
&\hspace{10mm}\textbf{else} \\[-1.ex]
&\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
&\hspace{5mm}\textbf{if} \: \textit{maximize} \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
&\hspace{5mm}\textbf{else} \\[-1.ex]
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}

Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
""" + fr"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
{_maximize_doc}
{_foreach_doc}
{_differentiable_doc}
""" + r"""

Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et. al. and implementations in some other frameworks.

Considering the specific case of Momentum, the update can be written as

.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}

where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
parameters, gradient, velocity, and momentum respectively.

This is in contrast to Sutskever et. al. and
other frameworks which employ an update of the form

.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
p_{t+1} & = p_{t} - v_{t+1}.
\end{aligned}

The Nesterov version is analogously modified.

Moreover, the initial value of the momentum buffer is set to the
gradient value at the first step. This is in contrast to some other
frameworks that initialize it to all zeros.

"""


def sgd(params: List[Tensor],
d_p_list: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
has_sparse_grad: bool = None,
foreach: Optional[bool] = None,
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool):
r"""Functional API that performs SGD algorithm computation.

See :class:`~torch.optim.SGD` for details.
"""

# if foreach is None:
# why must we be explicit about an if statement for torch.jit.is_scripting here?
# because JIT can't handle Optionals nor fancy conditionals when scripting
# if not torch.jit.is_scripting():
# _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
# else:
# foreach = False

# if foreach and torch.jit.is_scripting():
# raise RuntimeError('torch.jit.script not supported with foreach optimizers')

# if foreach and not torch.jit.is_scripting():
# func = _multi_tensor_sgd
# else:
func = _single_tensor_sgd

func(params,
d_p_list,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=nesterov,
has_sparse_grad=has_sparse_grad,
maximize=maximize)

def _single_tensor_sgd(params: List[Tensor],
d_p_list: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
*,
weight_decay: float,
momentum: float,
lr: float,
dampening: float,
nesterov: bool,
maximize: bool,
has_sparse_grad: bool):

for i, param in enumerate(params):
d_p = d_p_list[i] if not maximize else -d_p_list[i]

if weight_decay != 0:
d_p = d_p.add(param, alpha=weight_decay)

if momentum != 0:
buf = momentum_buffer_list[i]

if buf is None:
buf = torch.clone(d_p).detach()
momentum_buffer_list[i] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

def state_dict(self):
return super()._ms_state_dict('accum')
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf

def load_state_dict(self, state_dict):
return super()._ms_load_state_dict(state_dict, 'accum')
param.add_(d_p, alpha=-lr)

+ 148
- 176
mindtorch/torch/tensor.py View File

@@ -8,6 +8,7 @@ from copy import deepcopy
from functools import reduce
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore import Tensor as ms_Tensor
from mindspore.scipy.ops import SolveTriangular
from mindspore.common import dtype as mstype
@@ -18,6 +19,7 @@ from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.initializer import Zero
from mindspore._c_expression import Tensor as Tensor_
from mindspore.common._stub_tensor import StubTensor
from mindspore.ops.composite.multitype_ops._compile_utils import _tensor_getitem

from mindtorch.utils import unsupported_attr, is_under_gpu_context, get_backend, is_under_ascend_context, _infer_size, \
_ascend_tensor_general_cast, is_under_cpu_context, pynative_mode_condition, set_multiple_name_tuple, \
@@ -246,8 +248,8 @@ class _TensorMeta(type(ms_Tensor), abc.ABCMeta):
Meta class for Tensor. Used internally.
"""

class Tensor(StubTensor, metaclass=_TensorMeta):
def __init__(self, *data, dtype=None, inner=False, cast_tensor=False):
class Tensor(StubTensor):
def __init__(self, *data, dtype=None, requires_grad=False, inner=False, cast_tensor=False):
if cast_tensor:
if len(data) != 1:
raise RuntimeError("Tensor init data lenght is not 1 when cast_tensor=True")
@@ -255,9 +257,15 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
if isinstance(input_data, StubTensor):
self.stub = input_data.stub
self.tensor = input_data.tensor
self.requires_grad_ = input_data.requires_grad_ or requires_grad
self.grad_fn_ = input_data.grad_fn_
self.grad_ = input_data.grad_
elif isinstance(input_data, Tensor_):
self.stub = None
self.tensor = input_data
self.requires_grad = requires_grad
self.grad_fn_ = None
self.grad_ = None
else:
raise ValueError(f"Tensor init data type is invaild: {type(input_data)}")
self.adapter_flag = True
@@ -284,7 +292,40 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
init_tensor = ms_Tensor(input_data=_input_data, dtype=dtype)
super(Tensor, self).__init__(tensor=init_tensor)
self.adapter_flag = True
self.requires_grad = requires_grad

def backward(self, grad=None):
r"""
        Calculate the gradient.
"""
# assert self.shape == ()
if grad is None:
grad = self.new_ones(self.shape)

super().backward(grad)

@property
def grad(self):
r"""
        Get the gradient value.
"""
if self.grad_fn_ is not None:
self.grad_fn_.get_grad()
return self.grad_

@grad.setter
def grad(self, grad):
r"""
        Set the gradient value.
"""
if grad is None:
self.grad_ = grad
return

if self.grad_ is None:
self.grad_ = Tensor(grad)
else:
self.grad_ = self.grad_ + Tensor(grad, cast_tensor=True)
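
    # Usage sketch for the new autograd entry points, mirroring
    # testing/ut/pytorch/autograd/test_backward.py (``np`` and ``Parameter`` as
    # imported there; not part of this file):
    #
    #     a = Parameter(Tensor(np.ones((3, 3), np.float32)))
    #     out = (a @ a).sum()     # scalar, so backward() needs no explicit grad
    #     out.backward()
    #     print(a.grad)           # gradient accumulated through the setter above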

def _process_data(self, data):
_shape = None
@@ -369,10 +410,11 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(out)

def __add__(self, other):
tensor_ms = cast_to_ms_tensor(self)
other_ms = cast_to_ms_tensor(other)
# tensor_ms = cast_to_ms_tensor(self)
# other_ms = cast_to_ms_tensor(other)
# TODO: mindspore __add__ do not support logical_or with two bool dtype tensors.
out = tensor_ms.__add__(other_ms)
# out = tensor_ms.__add__(other_ms)
out = ops.add(self, other)
return cast_to_adapter_tensor(out)

def __and__(self, other):
@@ -414,9 +456,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return self.__add__(other)

def __sub__(self, other):
tensor_ms = cast_to_ms_tensor(self)
other_ms = cast_to_ms_tensor(other)
out = tensor_ms.__sub__(other_ms)
out = ops.sub(self, other)
return cast_to_adapter_tensor(out)

def __rsub__(self, other):
@@ -433,9 +473,10 @@ class Tensor(StubTensor, metaclass=_TensorMeta):

def __mul__(self, other):
# TODO: In mindspore tensor.__mul__, float tensor can not mul with complex tensor
tensor_ms = cast_to_ms_tensor(self)
other_ms = cast_to_ms_tensor(other)
out = tensor_ms.__mul__(other_ms)
# tensor_ms = cast_to_ms_tensor(self)
# other_ms = cast_to_ms_tensor(other)
# out = tensor_ms.__mul__(other_ms)
out = ops.mul(self, other)
return cast_to_adapter_tensor(out)

def __rmul__(self, other):
@@ -535,15 +576,14 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(out)

def __eq__(self, other):
tensor_ms = cast_to_ms_tensor(self)
other_ms = cast_to_ms_tensor(other)
out = tensor_ms.__eq__(other_ms)
out = ops.eq(self, other)
return cast_to_adapter_tensor(out)

def __matmul__(self, other):
tensor_ms = cast_to_ms_tensor(self)
other_ms = cast_to_ms_tensor(other)
out = tensor_ms.__matmul__(other_ms)
# tensor_ms = cast_to_ms_tensor(self)
# other_ms = cast_to_ms_tensor(other)
# out = tensor_ms.__matmul__(other_ms)
out = ops.matmul(self, other)
return cast_to_adapter_tensor(out)

def __rmatmul__(self, other):
@@ -564,29 +604,30 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
# __setitem__ no need to overload
def __getitem__(self, index):
# TODO: not support complex Tensor and False bool index getitem
def _getitem_handler(tensor_ms, index):
if isinstance(index, bool):
if index:
return tensor_ms.expand_dims(0)
else:
index = ms.Tensor(False)
out = ms.ops.masked_select(tensor_ms, index)
return out
if isinstance(index, tuple) and isinstance(index[0], bool):
if False in index:
index = ms.Tensor(False)
out = ms.ops.masked_select(tensor_ms, index)
return out
else:
return tensor_ms.expand_dims(0)
return tensor_ms.__getitem__(index)

tensor_ms = cast_to_ms_tensor(self)
out_ms = _getitem_handler(tensor_ms, index)
# def _getitem_handler(tensor_ms, index):
# if isinstance(index, bool):
# if index:
# return tensor_ms.expand_dims(0)
# else:
# index = ms.Tensor(False)
# out = ms.ops.masked_select(tensor_ms, index)
# return out
# if isinstance(index, tuple) and isinstance(index[0], bool):
# if False in index:
# index = ms.Tensor(False)
# out = ms.ops.masked_select(tensor_ms, index)
# return out
# else:
# return tensor_ms.expand_dims(0)
# return tensor_ms.__getitem__(index)

# tensor_ms = cast_to_ms_tensor(self)
# out_ms = _getitem_handler(tensor_ms, index)
out_ms = _tensor_getitem(self, index)
out = cast_to_adapter_tensor(out_ms)
if out_ms is not tensor_ms:
out.parent_tensor_ = tensor_ms
out.index_of_parent_ = index
# if out_ms is not tensor_ms:
# out.parent_tensor_ = tensor_ms
# out.index_of_parent_ = index
return out

def __getstate__(self):
@@ -612,10 +653,14 @@ class Tensor(StubTensor, metaclass=_TensorMeta):

@property
def dtype(self):
x = cast_to_ms_tensor(self)
dtype = x.dtype
dtype = super().dtype
return _msdtype2typeDict.get(str(dtype), dtype)

@property
def nbytes(self):
"""nbytes stub."""
return self.numel() * self.itemsize

def fill_adapter(self, val):
val = cast_to_ms_tensor(val)
output = ms.ops.fill(self.dtype, self.shape, val)
@@ -751,17 +796,18 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return _tensor_inplace_assign(self, output, "erfinv_", "erfinv")

def permute(self, *dims):
ms_input = cast_to_ms_tensor(self)
if isinstance(dims, list):
dims = tuple(dims)
output = ms_input.transpose(*dims)
if isinstance(dims[0], (tuple, list)):
dims = tuple(dims[0])
output = ops.transpose(self, dims)
return cast_to_adapter_tensor(output)

def contiguous(self, memory_format=None):
unsupported_attr(memory_format)
ms_input = cast_to_ms_tensor(self)
output = ms_input.contiguous()
return cast_to_adapter_tensor(output)
# stub = self.stub_sync()
# Tensor_.contiguous(stub)
return self

def new_tensor(self, data, *, dtype=None, device=None, requires_grad=False, layout=None, pin_memory=False):
unsupported_attr(layout)
@@ -1042,32 +1088,19 @@ class Tensor(StubTensor, metaclass=_TensorMeta):


def numel(self):
input_ms = cast_to_ms_tensor(self)
return P.size(input_ms)
return ops.size(self)

def detach(self):
input_ms = cast_to_ms_tensor(self)
output = ms.ops.stop_gradient(input_ms)
output = ops.stop_gradient(self)
return cast_to_adapter_tensor(output)

def detach_(self):
return _tensor_inplace_assign(self, self.detach(), "detach_", "detach")

def sum(self, dim=None, keepdim=False, dtype=None):
input_ms = cast_to_ms_tensor(self)
# TODO: mindspore tensor.sum can not automatically promote dtype yet, will cause overflow.
if dtype is not None:
input_ms = input_ms.astype(dtype) if dtype != mstype.bool_ else \
input_ms.astype(mstype.bool_).astype(mstype.int64)
elif input_ms.dtype in msdapter_dtype.all_int_type_with_bool:
dtype = mstype.int64
input_ms = input_ms.astype(dtype)

if isinstance(dim, list):
dim = tuple(dim)
res = input_ms.sum(dim, dtype, keepdim)
if dtype is not None and dtype == mstype.bool_:
res = res.astype(mstype.bool_)
if dim is None:
dim = ()
res = ops.sum(self, dim, keepdim, dtype=dtype)
return cast_to_adapter_tensor(res)

def sum_to_size(self, *size):
@@ -1079,11 +1112,11 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
if dim is None:
dim = axis

input_ms = cast_to_adapter_tensor(self)
if dtype:
input_ms = self.astype(dtype)
# input_ms = cast_to_adapter_tensor(self)
# if dtype:
# input_ms = self.astype(dtype)

output = ms.ops.mean(input_ms, dim, keepdim)
output = ops.mean(self, dim, keepdim)
return cast_to_adapter_tensor(output)

def prod(self, dim=None, keepdim=False, dtype=None):
@@ -1198,14 +1231,15 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
shape = tuple(shape)

input_size = self.shape
if not ms.ops.is_sequence_value_unknown(input_size) and len(input_size) > 0 and input_size[0] == 0:
# only support first element is 0
numel = ms.ops.size(self)
shape = _infer_size(shape, numel)
output = ms.ops.zeros(shape, self.dtype)
else:
input_ms = cast_to_ms_tensor(self)
output = ms.ops.reshape(input_ms, shape)
# if not ms.ops.is_sequence_value_unknown(input_size) and len(input_size) > 0 and input_size[0] == 0:
# # only support first element is 0
# numel = ms.ops.size(self)
# shape = _infer_size(shape, numel)
# output = ms.ops.zeros(shape, self.dtype)
# else:
output = ops.reshape(self, shape)

out = cast_to_adapter_tensor(output)
return cast_to_adapter_tensor(output)

def reshape_as(self, other):
@@ -1313,11 +1347,10 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(output)

def item(self):
input_ms = cast_to_ms_tensor(self)
if input_ms.size > 1:
if self.numel() > 1:
raise ValueError("only one element tensors can be converted to Python scalars")
output = input_ms.asnumpy().reshape(-1).tolist()
return output[0]
output = self.numpy().item()
return output

def log(self):
input_ms = cast_to_ms_tensor(self)
@@ -1408,8 +1441,9 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(input_ms.unbind(dim))

def unsqueeze(self, dim):
input_ms = cast_to_ms_tensor(self)
return cast_to_adapter_tensor(input_ms.unsqueeze(dim))
# input_ms = cast_to_ms_tensor(self)
# return cast_to_adapter_tensor(input_ms.unsqueeze(dim))
return ops.unsqueeze(self, dim)

def unsqueeze_(self, dim):
output = self.unsqueeze(dim)
@@ -1420,9 +1454,8 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return input_ms.is_signed()

def transpose(self, dim0, dim1):
input_ms = cast_to_ms_tensor(self)
# The functions of ms.ops.swapaxes are consistent with torch.transpose
output = ms.ops.swapaxes(input_ms, dim0, dim1)
output = ops.swapaxes(self, dim0, dim1)
return cast_to_adapter_tensor(output)

def transpose_(self, dim0, dim1):
@@ -1693,33 +1726,6 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
def is_quantized(self, flag):
raise AttributeError("attribute 'is_quantized' of 'torch.Tensor' objects is not writable.")

@property
def requires_grad(self):
warning("tensor.requires_grad only suppport set to True now. So It is always True.")
return True

@requires_grad.setter
def requires_grad(self, flag):
if not isinstance(flag, bool):
raise RuntimeError("requires_grad must be a bool")
if flag is False:
raise NotImplementedError("tensor.requires_grad can not set to False yet. "
"If tensor is not leaf Tensor, can try tensor.detach() instead. "
"If tensor is leaf Tensor, can replaces tensor with Parameter, because "
"Parameter.requires_grad work with mindspore autograd mechanism, "
"when it set to False, the gradient return by ms.grad"
"(https://www.mindspore.cn/docs/zh-CN/r2.0/"
"api_python/mindspore/mindspore.grad.html) "
"or ms.value_and_grad"
"(https://www.mindspore.cn/docs/zh-CN/r2.0/"
"api_python/mindspore/mindspore.value_and_grad.html)"
" is zero. ")

def requires_grad_(self, requires_grad=True):
if requires_grad is False:
warning("requires_grad is always True in Tensor.")
return self

def nonzero(self, *, out=None, as_tuple=False):
if out is not None:
warning("Do not support parameter 'out'.")
@@ -1744,8 +1750,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta):

def bool(self, memory_format=None):
unsupported_attr(memory_format)
input_ms = cast_to_ms_tensor(self)
output = input_ms.bool()
output = ops.Cast()(self, ms.bool_)
return cast_to_adapter_tensor(output)

def eq(self, other):
@@ -1776,8 +1781,8 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return _tensor_inplace_assign(self, output, "exp_", "exp")

def masked_fill(self, mask, value):
input_ms = cast_to_ms_tensor(self)
output = input_ms.masked_fill(mask.bool(), value)
# input_ms = cast_to_ms_tensor(self)
output = ops.masked_fill(self, mask.bool(), value)
return cast_to_adapter_tensor(output)

def masked_fill_(self, mask, value):
@@ -1889,8 +1894,8 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(output)

def flatten(self, start_dim=0, end_dim=-1):
input_ms = cast_to_ms_tensor(self)
output = ms.ops.flatten(input_ms, order='C', start_dim=start_dim, end_dim=end_dim)
# input_ms = cast_to_ms_tensor(self)
output = ops.flatten(self, order='C', start_dim=start_dim, end_dim=end_dim)
return cast_to_adapter_tensor(output)

def unflatten(self, dim, sizes):
@@ -2116,8 +2121,8 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
def argmax(self, dim=None, keepdim=False, axis=None):
if dim is None:
dim = axis
input_ms = cast_to_ms_tensor(self)
output = ms.ops.argmax(input_ms, dim, keepdim)
# input_ms = cast_to_ms_tensor(self)
output = ops.argmax(self, dim, keepdim)
return cast_to_adapter_tensor(output)

def type(self, dtype=None, non_blocking=False, **kwargs):
@@ -2129,13 +2134,13 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
_dtype = _get_dtype_from_type(dtype)
if _dtype == self.dtype:
return self
x = cast_to_ms_tensor(self)
output = x.astype(_dtype)
# x = cast_to_ms_tensor(self)
output = self.astype(_dtype)
return cast_to_adapter_tensor(output)

def type_as(self, tensor):
x = cast_to_ms_tensor(self)
output = x.astype(tensor.dtype)
# x = cast_to_ms_tensor(self)
output = self.astype(tensor.dtype)
return cast_to_adapter_tensor(output)

def get_device(self):
@@ -2626,9 +2631,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(output)

def nelement(self):
input_ms = cast_to_ms_tensor(self)
output = input_ms.nelement()
return output
return self.numel()

def aminmax(self, *, dim=None, keepdim=False):
_input = cast_to_ms_tensor(self)
@@ -2786,8 +2789,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return cast_to_adapter_tensor(output)

def is_complex(self):
input_ms = cast_to_ms_tensor(self)
return input_ms.is_complex()
return ops.is_complex(self)

def isinf(self):
input_ms = cast_to_ms_tensor(self)
@@ -4237,33 +4239,6 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
return rlt
return cast_to_adapter_tensor(value), cast_to_adapter_tensor(indices)

def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
unsupported_attr(gradient)
unsupported_attr(retain_graph)
unsupported_attr(create_graph)
unsupported_attr(inputs)
raise NotImplementedError(
"tensor.backward() not support yet. please use "
"mindspore.value_and_grad"
"(https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore/mindspore.value_and_grad.html) "
"or mindspore.grad"
"(https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore/mindspore.grad.html) "
"to compute gradient and send the gradient to the optimizer. "
"please refer to mobilenet_v2 example: "
"https://openi.pcl.ac.cn/OpenI/MindTorchModelZoo/src/branch/master/official/cv/"
"mobilenet_v2/mobilenet_v2_adapter.py")

@property
def grad(self):
raise NotImplementedError(
"tensor.grad not support yet. pleause use "
"mindspore.value_and_grad"
"(https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore/mindspore.value_and_grad.html) "
"or mindspore.grad"
"(https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore/mindspore.grad.html) "
"to get the gradient. And take out the corresponding element as grad."
)

def frexp(self):
# TODO: to use ms.ops.frexp
input_ms = cast_to_ms_tensor(self)
@@ -4406,34 +4381,30 @@ def _get_default_dtype_by_data(data):

def tensor(data, dtype=None, device=None, requires_grad=True):
unsupported_attr(device)
if requires_grad is False:
msg = ("In MindTorch, Tensor's `requires_grad` is always 'True', can not be set to 'False'. ")
warning(msg)

if dtype is None and _not_default_fp32_dtype():
dtype = _get_default_dtype_by_data(data)

return Tensor(data, dtype=dtype, inner=True)
return Tensor(data, dtype=dtype, requires_grad=requires_grad, inner=True)

def cast_to_ms_tensor(inputs):
"""
Cast MindTorch.Tensor to MindSpore.Tensor before call mindspore API.
"""
if isinstance(inputs, Tensor):
inputs = inner.convert_to_ms_tensor(inputs)
elif isinstance(inputs, tuple):
inputs_tuple = ()
for value in inputs:
inputs_tuple += (cast_to_ms_tensor(value), )
inputs = inputs_tuple
elif isinstance(inputs, list):
inputs_list = []
for value in inputs:
inputs_list.append(cast_to_ms_tensor(value))
inputs = inputs_list
elif isinstance(inputs, dict):
for key, value in inputs.items():
inputs[key] = cast_to_ms_tensor(value)
# if isinstance(inputs, Tensor):
# inputs = inner.convert_to_ms_tensor(inputs)
# elif isinstance(inputs, tuple):
# inputs_tuple = ()
# for value in inputs:
# inputs_tuple += (cast_to_ms_tensor(value), )
# inputs = inputs_tuple
# elif isinstance(inputs, list):
# inputs_list = []
# for value in inputs:
# inputs_list.append(cast_to_ms_tensor(value))
# inputs = inputs_list
# elif isinstance(inputs, dict):
# for key, value in inputs.items():
# inputs[key] = cast_to_ms_tensor(value)
return inputs


@@ -4442,7 +4413,8 @@ def cast_to_adapter_tensor(outputs):
Cast MindSpore.Tensor to MindTorch.Tensor after call mindspore API.
"""
if isinstance(outputs, (StubTensor, ms.Tensor)):
outputs = inner.convert_to_adapter_tensor(outputs)
# outputs = inner.convert_to_adapter_tensor(outputs)
outputs = Tensor(outputs, cast_tensor=True)
elif isinstance(outputs, tuple):
outputs_tuple = ()
for value in outputs:


+ 0
- 1
mindtorch/torch/utils/data/_utils/collate.py View File

@@ -333,7 +333,6 @@ def default_collate(batch):
def collate_nbytes(data):
all_nbytes = 0
if isinstance(data, torch.Tensor):
data = torch.cast_to_ms_tensor(data)
all_nbytes += data.nbytes
elif isinstance(data, np.ndarray):
all_nbytes += data.nbytes


+ 0
- 0
testing/__init__.py View File


+ 0
- 0
testing/st/__init__.py View File


+ 6
- 0
testing/st/mindtorch/__init__.py View File

@@ -0,0 +1,6 @@
from mindtorch import torch as mt
import torch as pt

class mtLinear(mt.nn.Module):
def __init__(self, auto_prefix=True, flags=None):
super().__init__(auto_prefix, flags)

+ 98
- 0
testing/st/mindtorch/test_simple_linear.py View File

@@ -0,0 +1,98 @@
from mindtorch import torch
from mindtorch.torch import nn
from mindtorch.torch.utils.data import DataLoader
from mindtorch.torchvision import datasets
from mindtorch.torchvision.transforms import ToTensor
from mindspore._c_expression import jit_mode_pi_enable, jit_mode_pi_disable

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
break

class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)

def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits

model = NeuralNetwork()
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
def run_train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation

loss.backward()
optimizer.step()
optimizer.zero_grad()
# break
# break
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def run_test(dataloader, model, loss_fn):
print('run test')
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

def test_train_lenet():
epochs = 5
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
run_train(train_dataloader, model, loss_fn, optimizer)
jit_mode_pi_disable()
run_test(test_dataloader, model, loss_fn)
jit_mode_pi_enable()
print("Done!")

+ 56
- 0
testing/ut/pytorch/autograd/test_backward.py View File

@@ -0,0 +1,56 @@
import numpy as np
import torch
from typing import Optional, Union
from mindtorch.torch import Tensor as mtTensor
from mindtorch.torch.nn import Parameter
from torch import Tensor as ptTensor

def run_backward(tensor: Union[mtTensor, ptTensor], grad_input=None):
if grad_input is None:
assert tensor.shape == ()
tensor.backward()
else:
assert tensor.shape == grad_input.shape
tensor.backward(grad_input)

def run_simple_op(a: Union[mtTensor, ptTensor], b: Union[mtTensor, ptTensor], op: str):
if op == '+':
return a + b
    if op == '-':
        return a - b
    if op == '*':
        return a * b
    if op == '/':
        return a / b
if op == '@':
return a @ b
    raise ValueError(f'op {op} is not supported yet')

def test_simple_op_backward_test():
a = np.random.randn(3, 3).astype(np.float32)
b = np.random.randn(3, 3).astype(np.float32)

pt_a, pt_b = torch.tensor(a, requires_grad=True), torch.tensor(b, requires_grad=True)
mt_a, mt_b = Parameter(mtTensor(a)), Parameter(mtTensor(b))

print(mt_a.requires_grad)

op_list = ['+', '-', '*', '/', '@']

for op in op_list:
pt_out = run_simple_op(pt_a, pt_b, op)
mt_out = run_simple_op(mt_a, mt_b, op)
print(mt_out.requires_grad)
print('pt_out.requires_grad', pt_out.requires_grad)

assert np.allclose(pt_out.detach().numpy(), mt_out.numpy(), 1e-4, 1e-4)

run_backward(pt_out, torch.tensor(np.ones((3, 3), np.float32)))
run_backward(mt_out, mtTensor(np.ones((3, 3), np.float32)))

# assert has grad
assert mt_a.grad is not None and mt_b.grad is not None
# allclose
assert np.allclose(pt_a.grad.detach().numpy(), mt_a.grad.numpy(), 1e-4, 1e-4)
assert np.allclose(pt_b.grad.detach().numpy(), mt_b.grad.numpy(), 1e-4, 1e-4)

+ 3
- 1
testing/ut/pytorch/nn/test_linear.py View File

@@ -9,6 +9,8 @@ import torch
from ...utils import set_mode_by_env_config, is_test_under_graph_context
set_mode_by_env_config()

import mindspore
mindspore.set_context(pynative_synchronize=True)

def test_linear_model():
class LinearModel(Module):
@@ -99,7 +101,7 @@ def test_bilinear_model():
def weight_init(m):
if isinstance(m, Bilinear):
m.weight.data = m.weight.data.normal_adapter(0, 0.01)
if m.has_bias:
if m.bias is not None:
m.bias.data = m.bias.data.zero_adapter()

model.apply(weight_init)

