diff --git a/msadapter/pytorch/tensor.py b/msadapter/pytorch/tensor.py index 3bd0611f..cee3754c 100644 --- a/msadapter/pytorch/tensor.py +++ b/msadapter/pytorch/tensor.py @@ -867,8 +867,8 @@ class Tensor(ms.Tensor): def sum(self, dim=None, keepdim=False, dtype=None): input = cast_to_ms_tensor(self) if dtype is not None: - input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.int32) - elif self.dtype in msdapter_dtype.all_int_type: + input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.bool_).astype(mstype.int64) + elif input.dtype in (msdapter_dtype.all_int_type, mstype.bool_): dtype = mstype.int64 input = input.astype(dtype) @@ -880,14 +880,11 @@ class Tensor(ms.Tensor): else: dim = validator.check_and_canonicalize_axes(dim, input.ndim) - if not validator.check_type_support(input.dtype, 'GPU', (mstype.float64, mstype.float32, mstype.float16)): - input = input.astype(mstype.float32) if 0 in self.shape: input = tensor_operator_registry.get('make_tensor')([0], input.dtype) res = tensor_operator_registry.get('sum')(bool(keepdim))(input, dim) if dtype == mstype.bool_: res = res.astype(mstype.bool_) - return cast_to_adapter_tensor(res) def sum_to_size(self, *size): diff --git a/testing/ut/pytorch/tensor/test_tensor.py b/testing/ut/pytorch/tensor/test_tensor.py index 451ae5ed..c5fb91e2 100644 --- a/testing/ut/pytorch/tensor/test_tensor.py +++ b/testing/ut/pytorch/tensor/test_tensor.py @@ -414,6 +414,29 @@ def test_sum2(): assert np.allclose(ms_output.asnumpy(), torch_output.numpy()) assert ms_output.asnumpy().dtype == torch_output.numpy().dtype +def test_sum3(): + + torch_tensor = torch.tensor([True, True, False]) + ms_tensor = pytorch.tensor([True, True, False]) + torch_output = torch_tensor.sum() + ms_output = ms_tensor.sum() + assert np.allclose(ms_output.asnumpy(), torch_output.numpy()) + assert ms_output.asnumpy().dtype == torch_output.numpy().dtype + + torch_tensor = torch.tensor([-1, 1], dtype=torch.bool) + ms_tensor = pytorch.tensor([-1, 1], dtype=pytorch.bool) + torch_output = torch_tensor.sum() + ms_output = ms_tensor.sum() + assert np.allclose(ms_output.asnumpy(), torch_output.numpy()) + assert ms_output.asnumpy().dtype == torch_output.numpy().dtype + + torch_tensor = torch.tensor([-1, 1]) + ms_tensor = pytorch.tensor([-1, 1]) + torch_output = torch_tensor.sum(dtype=torch.bool) + ms_output = ms_tensor.sum(dtype=pytorch.bool) + assert np.allclose(ms_output.asnumpy(), torch_output.numpy()) + assert ms_output.asnumpy().dtype == torch_output.numpy().dtype + def test_split(): tensor = np.random.random((3, 3)).astype(np.float32) @@ -4907,6 +4930,7 @@ if __name__ == '__main__': test_numel() test_sum() test_sum2() + test_sum3() test_split() test_numpy() test_ndimension()