#426 changes to Tensor.sum

Merged
Erpim merged 4 commits from lzh_sum into master 1 year ago
  1. msadapter/pytorch/tensor.py (+2 -5)
  2. testing/ut/pytorch/tensor/test_tensor.py (+24 -0)

msadapter/pytorch/tensor.py (+2 -5)

@@ -867,8 +867,8 @@ class Tensor(ms.Tensor):
     def sum(self, dim=None, keepdim=False, dtype=None):
         input = cast_to_ms_tensor(self)
         if dtype is not None:
-            input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.int32)
-        elif self.dtype in msdapter_dtype.all_int_type:
+            input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.bool_).astype(mstype.int64)
+        elif input.dtype in (msdapter_dtype.all_int_type, mstype.bool_):
             dtype = mstype.int64
             input = input.astype(dtype)

lzh commented 1 year ago (review comment on the dtype=bool branch):

    For the tensor([-1, 1]).sum(dtype=bool) case, torch treats -1 and 1 as bool values, so the result is True. ms.ops.sum cannot handle bool inputs, and if the tensor is cast to an int type first, the sum of -1 and 1 is 0, so the result would come back as False. The input therefore has to be cast to bool first and only then to int.

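The difference the comment describes is easy to reproduce outside the adapter. The following is a minimal numpy sketch (numpy stands in for the MindSpore/torch ops here; it is an illustration, not the adapter code):

    import numpy as np

    x = np.array([-1, 1])

    # Cast to int first: -1 + 1 == 0, which would map to False.
    int_first = x.astype(np.int64).sum()

    # Cast to bool first (any non-zero value is True), then to int: 1 + 1 == 2, which is truthy.
    bool_first = x.astype(np.bool_).astype(np.int64).sum()

    print(bool(int_first), bool(bool_first))   # False True

This matches what the review comment and the new test_sum3 case expect: tensor([-1, 1]).sum(dtype=bool) should come out as True, not False.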
@@ -880,14 +880,11 @@ class Tensor(ms.Tensor):
         else:
             dim = validator.check_and_canonicalize_axes(dim, input.ndim)

-        if not validator.check_type_support(input.dtype, 'GPU', (mstype.float64, mstype.float32, mstype.float16)):
-            input = input.astype(mstype.float32)
         if 0 in self.shape:
             input = tensor_operator_registry.get('make_tensor')([0], input.dtype)
         res = tensor_operator_registry.get('sum')(bool(keepdim))(input, dim)
         if dtype == mstype.bool_:
             res = res.astype(mstype.bool_)

         return cast_to_adapter_tensor(res)

     def sum_to_size(self, *size):


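For reference, the overall dtype handling that the patched Tensor.sum aims for can be sketched as a standalone function. This is only an illustration under the assumption that numpy dtypes can stand in for the mstype ones; the helper name sum_like_patched_sum is made up and is not part of MSAdapter:

    import numpy as np

    def sum_like_patched_sum(x, dtype=None):
        # Standalone numpy sketch of the dtype rules the patch targets
        # (hypothetical helper, not MSAdapter code).
        if dtype is not None:
            # Explicit dtype: for bool, threshold to bool first, then accumulate in int64.
            x = x.astype(dtype) if dtype != np.bool_ else x.astype(np.bool_).astype(np.int64)
        elif x.dtype == np.bool_ or np.issubdtype(x.dtype, np.integer):
            # Integer and bool inputs accumulate in int64, as torch does.
            dtype = np.int64
            x = x.astype(dtype)
        res = x.sum()
        if dtype == np.bool_:
            # Cast the accumulated result back to bool for sum(dtype=bool).
            res = res.astype(np.bool_)
        return res

    # Mirrors the three cases in the new test_sum3:
    assert sum_like_patched_sum(np.array([True, True, False])) == 2            # int64 result
    assert sum_like_patched_sum(np.array([-1, 1], dtype=np.bool_)) == 2        # bool input
    assert sum_like_patched_sum(np.array([-1, 1]), dtype=np.bool_) == True     # True, not False

The final cast back to bool is what lets sum(dtype=bool) report True for [-1, 1], while a plain sum() over bool or integer input still returns an int64 result, as the new tests check.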
testing/ut/pytorch/tensor/test_tensor.py (+24 -0)

@@ -414,6 +414,29 @@ def test_sum2():
     assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
     assert ms_output.asnumpy().dtype == torch_output.numpy().dtype

+def test_sum3():
+
+    torch_tensor = torch.tensor([True, True, False])
+    ms_tensor = pytorch.tensor([True, True, False])
+    torch_output = torch_tensor.sum()
+    ms_output = ms_tensor.sum()
+    assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
+    assert ms_output.asnumpy().dtype == torch_output.numpy().dtype
+
+    torch_tensor = torch.tensor([-1, 1], dtype=torch.bool)
+    ms_tensor = pytorch.tensor([-1, 1], dtype=pytorch.bool)
+    torch_output = torch_tensor.sum()
+    ms_output = ms_tensor.sum()
+    assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
+    assert ms_output.asnumpy().dtype == torch_output.numpy().dtype
+
+    torch_tensor = torch.tensor([-1, 1])
+    ms_tensor = pytorch.tensor([-1, 1])
+    torch_output = torch_tensor.sum(dtype=torch.bool)
+    ms_output = ms_tensor.sum(dtype=pytorch.bool)
+    assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
+    assert ms_output.asnumpy().dtype == torch_output.numpy().dtype
+
 def test_split():

     tensor = np.random.random((3, 3)).astype(np.float32)

@@ -4907,6 +4930,7 @@ if __name__ == '__main__':
     test_numel()
     test_sum()
     test_sum2()
+    test_sum3()
     test_split()
     test_numpy()
     test_ndimension()

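Assuming the suite is run with pytest (the file also keeps a plain __main__ runner, which the second hunk extends), the new case can be exercised on its own with something like:

    pytest testing/ut/pytorch/tensor/test_tensor.py::test_sum3 -q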
