#785 [0.2] searchsorted supports scalar value; syntax support for Meitu SD

Merged
zoulq merged 1 commit from 543145646twc/MSAdapter:develop-twc-release_0.2 into release_0.2 1 year ago
  1. mindtorch/torch/__init__.py  +1 -0
  2. mindtorch/torch/_utils.py  +14 -0
  3. mindtorch/torch/functional.py  +3 -1
  4. mindtorch/torch/jit/__init__.py  +8 -0
  5. mindtorch/torch/nn/modules/conv.py  +7 -0
  6. mindtorch/torch/nn/parallel/__init__.py  +3 -1
  7. mindtorch/torch/nn/parallel/_functions.py  +12 -0
  8. mindtorch/torch/nn/parallel/data_parallel.py  +12 -0
  9. mindtorch/torch/nn/parallel/distributed.py  +4 -0
  10. mindtorch/torch/tensor_type.py  +8 -0
  11. mindtorch/torch/utils/__init__.py  +1 -0
  12. mindtorch/torch/utils/cpp_extension.py  +22 -0
  13. testing/ut/pytorch/functional/test_math.py  +10 -0

mindtorch/torch/__init__.py  +1 -0

@@ -7,6 +7,7 @@ from mindtorch.torch.types import *
 from mindtorch.torch._C import *
 from mindtorch.torch.common import *
 from mindtorch.torch.tensor import *
+from mindtorch.torch.tensor_type import *
 from mindtorch.torch import nn
 from mindtorch.torch import optim
 from mindtorch.torch.functional import *
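
With this import, the symbols from the new tensor_type module are re-exported at the package root. A minimal check of the new surface (assuming the star import exposes TensorType):

    import mindtorch.torch as ms_torch

    # TensorType is now reachable from the package root, but it is still a stub:
    try:
        ms_torch.TensorType(dim=4)
    except NotImplementedError as e:
        print(e)  # `TensorType` is not implemented yet.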


mindtorch/torch/_utils.py  +14 -0

@@ -113,3 +113,17 @@ def classproperty(func):
     if not isinstance(func, (classmethod, staticmethod)):
         func = classmethod(func)
     return _ClassPropertyDescriptor(func)
+
+def _flatten_dense_tensors(tensors):
+    unsupported_attr(tensors)
+    raise NotImplementedError("`_flatten_dense_tensors` is not implemented yet.")
+
+def _take_tensors(tensors, size_limit):
+    unsupported_attr(tensors)
+    unsupported_attr(size_limit)
+    raise NotImplementedError("`_take_tensors` is not implemented yet.")
+
+def _unflatten_dense_tensors(flat, tensors):
+    unsupported_attr(flat)
+    unsupported_attr(tensors)
+    raise NotImplementedError("`_unflatten_dense_tensors` is not implemented yet.")
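
These placeholders mirror the torch._utils helpers that PyTorch's distributed code uses for gradient bucketing; here they only fail fast. A rough sketch of the intended round-trip semantics (a NumPy illustration, not MSAdapter's eventual implementation):

    import numpy as np

    def flatten_dense(tensors):
        # Concatenate each array's flattened view into one contiguous 1-D buffer.
        return np.concatenate([t.reshape(-1) for t in tensors])

    def unflatten_dense(flat, tensors):
        # Slice the flat buffer back into arrays shaped like the originals.
        outputs, offset = [], 0
        for t in tensors:
            outputs.append(flat[offset:offset + t.size].reshape(t.shape))
            offset += t.size
        return outputs

    parts = [np.ones((2, 2)), np.arange(3.0)]
    restored = unflatten_dense(flatten_dense(parts), parts)
    assert all((a == b).all() for a, b in zip(parts, restored))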

mindtorch/torch/functional.py  +3 -1

@@ -1520,7 +1520,9 @@ def searchsorted(sorted_sequence, value, *, out_int32=False, right=False, side='
     value = cast_to_ms_tensor(value)
     if sorted_sequence.dtype == ms.float16:
         sorted_sequence = sorted_sequence.astype(ms.float32)
-    if value.dtype == ms.float16:
+    if isinstance(value, float):
+        value = ms.Tensor([value, ]).astype(ms.float32)
+    elif value.dtype == ms.float16:
         value = value.astype(ms.float32)
     output = ms.ops.searchsorted(sorted_sequence, value, out_int32=out_int32, right=right)
     return _out_inplace_assign(out, output, "searchsorted")
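
With the new branch, a bare Python float is accepted as value and is wrapped into a float32 tensor before dispatch. A minimal usage sketch (mirroring the unit test added below):

    import numpy as np
    import mindtorch.torch as ms_torch

    seq = ms_torch.tensor(np.array([1, 3, 5, 7, 9], dtype=np.float16))
    idx = ms_torch.searchsorted(seq, 3.0)  # scalar value, now supported
    print(idx)  # 1: leftmost insertion point that keeps the sequence sorted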


mindtorch/torch/jit/__init__.py  +8 -0

@@ -32,3 +32,11 @@ def ignore(drop=False, **kwargs):
         return fn
 
     return decorator
+
+def _overload_method(func):
+    unsupported_attr(func)
+    warning("`jit._overload_method` is an empty function that has not been implemented yet.")
+
+def interface(obj):
+    unsupported_attr(obj)
+    warning("`jit.interface` is an empty function that has not been implemented yet.")

mindtorch/torch/nn/modules/conv.py  +7 -0

@@ -613,3 +613,10 @@ class ConvTranspose3d(_ConvTransposeNd):
         output = conv_transpose3d(input, self.weight, self.bias, self.stride,
                                   self.padding, _out_padding, self.groups, self.dilation)
         return cast_to_adapter_tensor(output)
+
+
+class _ConvTransposeMixin(_ConvTransposeNd):
+    def __init__(self, *args, **kwargs):
+        unsupported_attr(args)
+        unsupported_attr(kwargs)
+        raise NotImplementedError("`_ConvTransposeMixin` is not implemented yet.")

mindtorch/torch/nn/parallel/__init__.py  +3 -1

@@ -1,5 +1,7 @@
 from .distributed import DistributedDataParallel
+from .data_parallel import DataParallel
 
 __all__ = [
-    'DistributedDataParallel'
+    'DistributedDataParallel',
+    'DataParallel'
 ]
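
The trailing comma added here matters: without it, Python's implicit string concatenation would silently fuse adjacent entries instead of listing two names:

    names = [
        'DistributedDataParallel'  # missing comma...
        'DataParallel'
    ]
    assert names == ['DistributedDataParallelDataParallel']  # one fused entry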

mindtorch/torch/nn/parallel/_functions.py  +12 -0

@@ -0,0 +1,12 @@
+from mindtorch.utils import unsupported_attr
+from mindtorch.torch.autograd import Function
+
+
+class Scatter(Function):
+    def __init__(self):
+        super().__init__()
+        raise NotImplementedError("`Scatter` is not implemented yet.")
+
+def _get_stream(device: int):
+    unsupported_attr(device)
+    raise NotImplementedError("`_get_stream` is not implemented yet.")

mindtorch/torch/nn/parallel/data_parallel.py  +12 -0

@@ -0,0 +1,12 @@
+from mindtorch.utils import unsupported_attr
+from mindtorch.torch.nn.modules.module import Module
+
+
+class DataParallel(Module):
+    def __init__(self, module, device_ids=None, output_device=None, dim=0):
+        super().__init__()
+        unsupported_attr(module)
+        unsupported_attr(device_ids)
+        unsupported_attr(output_device)
+        unsupported_attr(dim)
+        raise NotImplementedError("`DataParallel` is not implemented yet.")
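
Like the other placeholders in this PR, DataParallel is import-compatible but fails fast at construction time; a minimal sketch of the expected behavior (nn.Linear assumed available, as in PyTorch):

    import mindtorch.torch.nn as nn
    from mindtorch.torch.nn.parallel import DataParallel

    model = nn.Linear(4, 2)
    try:
        model = DataParallel(model)
    except NotImplementedError:
        pass  # fall back to single-device execution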

mindtorch/torch/nn/parallel/distributed.py  +4 -0

@@ -77,3 +77,7 @@ class DistributedDataParallel(Module):
         unsupported_attr(inputs)
         unsupported_attr(kwargs)
         unsupported_attr(device_ids)
+
+def _find_tensors(obj):
+    unsupported_attr(obj)
+    raise NotImplementedError("`_find_tensors` is not implemented yet.")

mindtorch/torch/tensor_type.py  +8 -0

@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from mindtorch.utils import unsupported_attr
+
+class TensorType:
+    def __init__(self, dim):
+        unsupported_attr(dim)
+        raise NotImplementedError("`TensorType` is not implemented yet.")

mindtorch/torch/utils/__init__.py  +1 -0

@@ -1,3 +1,4 @@
 from mindtorch.torch.utils import _pytree
 from mindtorch.torch.utils import data
 from mindtorch.torch.utils import checkpoint
+from mindtorch.torch.utils import cpp_extension

mindtorch/torch/utils/cpp_extension.py  +22 -0

@@ -0,0 +1,22 @@
+from mindtorch.utils import unsupported_attr
+
+
+class BuildExtension():
+    def __init__(self):
+        raise NotImplementedError("`BuildExtension` is not implemented yet.")
+
+
+def CppExtension(name, sources, *args, **kwargs):
+    unsupported_attr(name)
+    unsupported_attr(sources)
+    unsupported_attr(args)
+    unsupported_attr(kwargs)
+    raise NotImplementedError("`CppExtension` is not implemented yet.")
+
+
+def CUDAExtension(name, sources, *args, **kwargs):
+    unsupported_attr(name)
+    unsupported_attr(sources)
+    unsupported_attr(args)
+    unsupported_attr(kwargs)
+    raise NotImplementedError("`CUDAExtension` is not implemented yet.")
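
These stubs let setup scripts written against torch.utils.cpp_extension fail at extension-build time rather than at import time; a hypothetical setup.py fragment:

    from mindtorch.torch.utils.cpp_extension import BuildExtension, CppExtension

    # Importing succeeds; constructing the extension is what raises NotImplementedError.
    ext = CppExtension('my_ext', sources=['my_ext.cpp'])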

testing/ut/pytorch/functional/test_math.py  +10 -0

@@ -1787,6 +1787,16 @@ def test_searchsorted_fp16():
     ms_out = ms_torch.searchsorted(ms_seq, ms_val)
     param_compare(torch_out, ms_out)
 
+def test_searchsorted_scalar():
+    np_seq = np.array([1, 3, 5, 7, 9]).astype(np.float16)
+    val = 3.0
+    torch_seq = torch.tensor(np_seq)
+    ms_seq = ms_torch.tensor(np_seq)
+
+    torch_out = torch.searchsorted(torch_seq, val)
+    ms_out = ms_torch.searchsorted(ms_seq, val)
+    assert (torch_out.numpy() == ms_out.numpy()).all()
+
 def test_sgn():
     np_array1 = np.array([[-3, -2, -0.0, 0.0, 2, 3]]).astype(np.float16)
     np_array2 = np.array([[-3, -2, -0, 0, 2, 3]]).astype(np.int16)

