#169 add support for einops third-party package

Merged
zoulq merged 3 commits from lhy into master 1 year ago
  1. ms_adapter/pytorch/common/__init__.py (+1, -1)
  2. ms_adapter/pytorch/common/dtype.py (+2, -0)
  3. ms_adapter/pytorch/conflict_functional.py (+0, -10)
  4. ms_adapter/pytorch/cuda/__init__.py (+1, -0)
  5. ms_adapter/pytorch/functional.py (+6, -1)
  6. ms_adapter/pytorch/tensor.py (+53, -4)
  7. testing/ut/pytorch/functional/test_cat.py (+31, -0)
  8. testing/ut/pytorch/tensor/test_tensor.py (+50, -1)
  9. third_party/einops/patch/0001-adapt-to-ms_adapter.pytorch.patch (+72, -0)
  10. third_party/einops/setup.sh (+10, -0)

ms_adapter/pytorch/common/__init__.py (+1, -1)

@@ -11,5 +11,5 @@ __all__ = ["float16", "float32",
            "uint16", "uint32",
            "uint64", "bool_",
            "complex64", "complex128",
-           "long",
+           "long", "bfloat16"
            ]

ms_adapter/pytorch/common/dtype.py (+2, -0)

@@ -4,6 +4,8 @@
 from mindspore import dtype as mstype
 
 float16 = mstype.float16
+# TODO: mindspore to support mstype.bfloat16
+bfloat16 = mstype.float32
 float32 = mstype.float32
 float64 = mstype.float64
 int8 = mstype.int8
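
Until MindSpore exposes a native mstype.bfloat16, the new alias means bfloat16 silently falls back to float32. A minimal sketch of what callers can expect (assuming the dtype aliases are re-exported at package level, as the __all__ change above suggests):

    import ms_adapter.pytorch as ms_torch

    x = ms_torch.tensor([1.0, 2.0], dtype=ms_torch.bfloat16)
    print(x.dtype)  # reports Float32 for now, because bfloat16 is only an alias of mstype.float32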


ms_adapter/pytorch/conflict_functional.py (+0, -10)

@@ -32,16 +32,6 @@ def arange(start, end, step=1, *, out=None, dtype=None,
     unsupported_attr(device)
     unsupported_attr(requires_grad)
 
-    #TODO
-    # In mindspore 2.0, the code below can be delete
-    if dtype in (mstype.int64, mstype.float64):
-        if dtype == mstype.int64:
-            dtype_name = 'torch.int64'
-        else:
-            dtype_name = 'torch.float64'
-        raise ValueError('For now, `arange` only support `torch.float32` and `torch.int32`,'
-                         'but got dtype=`{}`'.format(dtype_name))
-
     if dtype is None:
         if isinstance(start, float) or isinstance(end, float) or isinstance(step, float):
            dtype = mstype.float32
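
With the guard removed, arange is expected to accept the previously rejected dtypes on a MindSpore 2.0 backend. A hedged sketch (older backends may still reject these dtypes in the underlying op; it also assumes arange is re-exported at package level as in torch):

    import ms_adapter.pytorch as ms_torch

    a = ms_torch.arange(0, 5, dtype=ms_torch.long)              # int64, rejected before this change
    b = ms_torch.arange(0.0, 1.0, 0.25, dtype=ms_torch.float64)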


ms_adapter/pytorch/cuda/__init__.py (+1, -0)

@@ -3,6 +3,7 @@
 import mindspore as ms
 from mindspore.ops import constexpr
 from ms_adapter.utils import get_backend
+from ms_adapter.pytorch.tensor import FloatTensor, LongTensor
 
 def is_available():
     backend = get_backend()
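
The new import makes the tensor aliases reachable from the cuda namespace. A small usage sketch, assuming FloatTensor and LongTensor accept Python data the way their torch counterparts do:

    from ms_adapter.pytorch import cuda

    if cuda.is_available():              # resolves the MindSpore backend via get_backend()
        x = cuda.FloatTensor([1.0, 2.0])
        idx = cuda.LongTensor([0, 1])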


ms_adapter/pytorch/functional.py (+6, -1)

@@ -66,6 +66,8 @@ def cat(tensors, dim=0, *, out=None):
         return out
     return cast_to_adapter_tensor(output)
 
+def concat(tensors, dim=0, *, out=None):
+    return cat(tensors, dim, out=out)
 
 def ones(*size, out=None, dtype=None, layout=None,
          device=None, requires_grad=False):
@@ -175,7 +177,7 @@ def zeros(*size, out=None, dtype=None, device=None, requires_grad=False):
     unsupported_attr(requires_grad)
 
     if isinstance(size[0], (tuple, list)):
-        size = size[0]
+        size = tuple(size[0])
 
     if len(size) < 2:
         raise ValueError("Until now, For 'ms_adapter.pytorch.zeros', the size of `size` sholud bigger than 1, "
@@ -1367,3 +1369,6 @@ def bitwise_right_shift(input, other, *, out=None):
     other = cast_to_ms_tensor(other).asnumpy()
     output = ms.Tensor(np.right_shift(input, other))
     return cast_to_adapter_tensor(output)
+
+def from_numpy(np_data):
+    return cast_to_adapter_tensor(ms.Tensor.from_numpy(np_data))
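
A short sketch of how the new helpers fit together (package import name as used in the unit tests; illustrative rather than exhaustive):

    import numpy as np
    import ms_adapter.pytorch as ms_torch

    x = ms_torch.from_numpy(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
    y = ms_torch.concat((x, x), dim=1)   # thin alias that forwards to cat()
    z = ms_torch.zeros([2, 4])           # list sizes are now normalized via tuple(size[0])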

ms_adapter/pytorch/tensor.py (+53, -4)

@@ -7,6 +7,7 @@ import mindspore as ms
 from mindspore.common import dtype as mstype
 from mindspore.common._register_for_tensor import tensor_operator_registry
 import mindspore.ops as P
+from mindspore.ops import constexpr
 from mindspore.common.initializer import _assignment, _init_random_normal, _init_random_uniform
 from mindspore._checkparam import Validator as validator
 from ms_adapter.utils import unsupported_attr, pynative_mode_condition, is_under_gpu_context, get_backend
@@ -63,6 +64,10 @@ class Tensor(ms.Tensor):
         output = ms.Tensor(_init_random_uniform(from_alias, to, self.shape), ms.float32)
         _assignment(self, output)
 
+    def uniform(self, from_alias=0, to=1):
+        output = ms.Tensor(_init_random_uniform(from_alias, to, self.shape), ms.float32)
+        return cast_to_adapter_tensor(output)
+
     def zero_(self):
         output = tensor_operator_registry.get("fill")(self.dtype, self.shape, 0.0)
         _assignment(self, output)
@@ -127,10 +132,19 @@ class Tensor(ms.Tensor):
         return self
 
     def expand(self, *size):
-        self._init_check()
+        @constexpr
+        def size_to_ms_tensor(size):
+            if isinstance(size[0], (list, tuple)):
+                size = ms.Tensor(size[0])
+            else:
+                size = ms.Tensor(size)
+            return size
         input_ms = cast_to_ms_tensor(self)
-        size = ms.Tensor(size)
-        return cast_to_adapter_tensor(input_ms.expand(size))
+        # TODO: to support int64
+        if input_ms.dtype == mstype.int64:
+            input_ms = ms.ops.cast(input_ms, mstype.int32)
+        _size = size_to_ms_tensor(size)
+        return cast_to_adapter_tensor(input_ms.expand(_size))
 
     def sigmoid(self):
         input = cast_to_ms_tensor(self)
@@ -241,6 +255,32 @@ class Tensor(ms.Tensor):
             res += initial
         return cast_to_adapter_tensor(res.astype(dtype))
 
+    def mean(self, dim=None, keepdim=False, dtype=None):
+        if dim is None:
+            axis = ()
+        else:
+            axis = dim
+
frelam marked this conversation as resolved
zoulq commented 1 year ago:
Is this check necessary? If the corresponding MindSpore interface already performs the check, there is no need to check it again in the Adapter.
+        input = cast_to_adapter_tensor(self)
+        if dtype:
+            input = self.astype(dtype)
+
+        output = ms.ops.mean(input, axis, keepdim)
+        return cast_to_adapter_tensor(output)
+
+    def prod(self, dim=None, keepdim=False, dtype=None):
+        if dim is None:
+            axis = ()
+        else:
+            axis = dim
+
frelam marked this conversation as resolved
zoulq commented 1 year ago:
Same as above.
+        input = cast_to_adapter_tensor(self)
+        if dtype:
+            input = self.astype(dtype)
+
+        output = ms.ops.prod(input, axis, keepdim)
+        return cast_to_adapter_tensor(output)
+
     def split(self, split_size, dim=0):
         if isinstance(split_size, tuple):
             raise TypeError("For 'Tensor.split', the type of `split_size` should be int, "
@@ -573,6 +613,13 @@ class Tensor(ms.Tensor):
         # TODO: warning for not support inplace index_fill_
         return self.index_fill(dim, index, value)
 
+    def index_select(self, dim, index):
+        _input_params = cast_to_ms_tensor(self)
+        _input_indices = cast_to_ms_tensor(index)
+
frelam marked this conversation as resolved
zoulq commented 1 year ago:
This assignment can be removed.
+        output = ms.ops.gather(_input_params, _input_indices, dim)
+        return cast_to_adapter_tensor(output)
+
     @property
     def data(self):
         return self.detach()
@@ -599,6 +646,9 @@
         warnings.warn(warning)
         return self
 
+    def is_cuda(self):
+        return is_under_gpu_context()
+
     def le(self, other):
         input = cast_to_ms_tensor(self)
         if isinstance(other, Tensor):
@@ -680,7 +730,6 @@ def tensor(data, dtype=None, device=None, requires_grad=False):
 
     return Tensor(input_data=data, dtype=dtype)
 
-
 def cast_to_ms_tensor(inputs):
     """
     Cast MSAdapter.Tensor to MindSpore.Tensor before call mindspore API.
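
Taken together, the new Tensor methods cover what the patched einops needs. A hedged sketch of the additions, mirroring the unit tests below:

    import ms_adapter.pytorch as ms_torch

    t = ms_torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    m = t.mean(dim=0)                                # reduce over dim 0
    p = t.prod(dim=1, keepdim=True)
    s = t.index_select(1, ms_torch.tensor([0, 2]))   # gathers columns 0 and 2 via ms.ops.gather
    w = ms_torch.zeros([2, 3]).uniform(-0.1, 0.1)    # non-inplace uniform used by the patched EinMix layer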


testing/ut/pytorch/functional/test_cat.py (+31, -0)

@@ -36,7 +36,38 @@ def test_cat3():
 
     assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
 
+
+def test_concat1():
+    ms_tensor = ms_torch.tensor([1, 2, 3])
+    ms_result = ms_torch.concat((ms_tensor, ms_tensor), dim=0)
+
+    torch_tensor = torch.tensor([1, 2, 3])
+    torch_result = torch.concat((torch_tensor, torch_tensor), dim=0)
+
+    assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
+
+def test_concat2():
+    ms_tensor = ms_torch.tensor([[1, 2, 3], [1, 2, 3]])
+    ms_result = ms_torch.concat((ms_tensor, ms_tensor), dim=1)
+
+    torch_tensor = torch.tensor([[1, 2, 3], [1, 2, 3]])
+    torch_result = torch.concat((torch_tensor, torch_tensor), dim=1)
+
+    assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
+
+def test_concat3():
+    ms_tensor = ms_torch.tensor([[1, 2, 3], [1, 2, 3]])
+    ms_result = ms_torch.concat([ms_tensor, ms_tensor], dim=0)
+
+    torch_tensor = torch.tensor([[1, 2, 3], [1, 2, 3]])
+    torch_result = torch.concat([torch_tensor, torch_tensor], dim=0)
+
+    assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
+
 if __name__ == '__main__':
     test_cat1()
     test_cat2()
     test_cat3()
+    test_concat1()
+    test_concat2()
+    test_concat3()

testing/ut/pytorch/tensor/test_tensor.py (+50, -1)

@@ -993,6 +993,51 @@ def test_T():
 
     assert np.allclose(ms_out.numpy(), torch_out.numpy())
 
+def test_mean1():
+    ms_tensor = pytorch.tensor([[1., 2, 3], [1, 2, 3]])
+    ms_result = ms_tensor.mean(dim=0)
+
+    torch_tensor = torch.tensor([[1., 2, 3], [1, 2, 3]])
+    torch_result = torch_tensor.mean(dim=0)
+
+    assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
+    assert ms_result.asnumpy().dtype == torch_result.numpy().dtype
+
+def test_mean2():
+    ms_tensor = pytorch.tensor([[1., 2, 3], [1, 2, 3]])
+    ms_result = ms_tensor.mean(dim=(0, -1))
+
+    torch_tensor = torch.tensor([[1., 2, 3], [1, 2, 3]])
+    torch_result = torch_tensor.mean(dim=(0, -1))
+
+    assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
+    assert ms_result.asnumpy().dtype == torch_result.numpy().dtype
+
+def test_prod1():
+    ms_tensor = pytorch.tensor([[1., 2, 3], [1, 2, 3]])
+    ms_result = ms_tensor.prod(dim=0)
+
+    torch_tensor = torch.tensor([[1., 2, 3], [1, 2, 3]])
+    torch_result = torch_tensor.prod(dim=0)
+
+    assert np.allclose(ms_result.asnumpy(), torch_result.numpy())
+    assert ms_result.asnumpy().dtype == torch_result.numpy().dtype
+
+def test_index_select():
+    data = np.random.randn(3, 4 ,5)
+
+    x_torch = torch.tensor(data)
+    indices = torch.tensor([0, 2])
+    torch_out = x_torch.index_select(1, indices)
+
+    x_ms = pytorch.tensor(data)
+    indices = pytorch.tensor([0, 2])
+    ms_out = x_ms.index_select(1, indices)
+
+    assert np.allclose(ms_out.asnumpy(), torch_out.numpy())
+    assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
+
+
 if __name__ == '__main__':
     test_others_tensor()
     test_add_()
@@ -1073,4 +1118,8 @@ if __name__ == '__main__':
     test_le1()
     test_le2()
     test_t1()
-    test_t2()
+    test_t2()
+    test_mean1()
+    test_mean2()
+    test_prod1()
+    test_index_select()

third_party/einops/patch/0001-adapt-to-ms_adapter.pytorch.patch (+72, -0)

@@ -0,0 +1,72 @@
From ed543c49ec315e1e8cb9c6f8cb0fbc9a08deaad1 Mon Sep 17 00:00:00 2001
From: lvhaoyu <lvhaoyu@huawei.com>
Date: Tue, 22 Nov 2022 15:03:16 +0800
Subject: [PATCH] adapt to ms_adapter.pytorch

---
einops/_backends.py | 4 ++--
einops/_torch_specific.py | 2 +-
einops/layers/torch.py | 6 +++---
3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/einops/_backends.py b/einops/_backends.py
index ba6fca8..80cce17 100644
--- a/einops/_backends.py
+++ b/einops/_backends.py
@@ -316,7 +316,7 @@ class TorchBackend(AbstractBackend):
framework_name = 'torch'
def __init__(self):
- import torch
+ import ms_adapter.pytorch as torch
self.torch = torch
def is_appropriate_type(self, tensor):
@@ -677,4 +677,4 @@ class OneFlowBackend(AbstractBackend):
return oneflow
def einsum(self, pattern, *x):
- return self.flow.einsum(pattern, *x)
\ No newline at end of file
+ return self.flow.einsum(pattern, *x)
diff --git a/einops/_torch_specific.py b/einops/_torch_specific.py
index 204d935..5ea9253 100644
--- a/einops/_torch_specific.py
+++ b/einops/_torch_specific.py
@@ -10,7 +10,7 @@ Importantly, whole lib is designed so that you can't use it
from typing import Dict, List
-import torch
+import ms_adapter.pytorch as torch
from einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
diff --git a/einops/layers/torch.py b/einops/layers/torch.py
index 3199241..cb9af63 100644
--- a/einops/layers/torch.py
+++ b/einops/layers/torch.py
@@ -1,6 +1,6 @@
from typing import Optional, Dict, cast
-import torch
+import ms_adapter.pytorch as torch
from . import RearrangeMixin, ReduceMixin
from ._einmix import _EinmixMixin
@@ -29,10 +29,10 @@ class Reduce(ReduceMixin, torch.nn.Module):
class EinMix(_EinmixMixin, torch.nn.Module):
def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
- self.weight = torch.nn.Parameter(torch.zeros(weight_shape).uniform_(-weight_bound, weight_bound),
+ self.weight = torch.nn.Parameter(torch.zeros(weight_shape).uniform(-weight_bound, weight_bound),
requires_grad=True)
if bias_shape is not None:
- self.bias = torch.nn.Parameter(torch.zeros(bias_shape).uniform_(-bias_bound, bias_bound),
+ self.bias = torch.nn.Parameter(torch.zeros(bias_shape).uniform(-bias_bound, bias_bound),
requires_grad=True)
else:
self.bias = None
--
2.25.1
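
Once the patch is applied and the adapter is installed, einops routes its torch backend through ms_adapter. A minimal smoke test, assuming both packages are importable:

    import ms_adapter.pytorch as torch
    from einops import rearrange, reduce

    x = torch.ones((2, 3, 4))
    y = rearrange(x, 'b c h -> b (c h)')   # handled by the patched TorchBackend
    z = reduce(x, 'b c h -> b c', 'mean')  # relies on the new Tensor.mean(dim=...) method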


third_party/einops/setup.sh (+10, -0)

@@ -0,0 +1,10 @@
+#!/bin/bash
+if [ -d "einops" ];then
+echo "[einops setup]: 'einops' dir already exists, please check and remove it before installing einops"
+exit 1
+fi
+git clone https://github.com/arogozhnikov/einops.git || exit 1
+cd einops
+git checkout v0.6.0 -b v0.6.0 || exit 1
+patch -p1 < ../patch/0001-adapt-to-ms_adapter.pytorch.patch || exit 1
+pip install -e ./ || exit 1
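
Judging by the directory check and the relative patch path, the script is meant to be run from inside third_party/einops (for example, `bash setup.sh`): it clones einops, checks out v0.6.0, applies the adapter patch above, and installs the patched package in editable mode.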
