#426 changes to Tensor.sum

Merged
Erpim merged 4 commits from lzh_sum into master 1 year ago
lzh commented 1 year ago
lzh reviewed 1 year ago
lzh left a comment
Erpim reviewed 1 year ago
msadapter/pytorch/tensor.py
@@ -854,2 +854,2 @@
input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.int32)
elif self.dtype in msdapter_dtype.all_int_type:
input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.int64)
elif input.dtype in (msdapter_dtype.all_int_type, mstype.bool_):
Erpim commented 1 year ago
原始类型为bool,输出类型为int;dtype指定为bool,输出类型为bool >>> a = torch.tensor([True, True, True, False, False]).to(torch.bool) >>> a.sum() tensor(3) >>> a.sum(dtype=torch.bool) tensor(True)
Erpim commented 1 year ago
dtype为bool类型最后还是转为bool输出
lzh reviewed 1 year ago
lzh commented 1 year ago
已修改
lzh reviewed 1 year ago
@@ -869,3 +869,2 @@
if dtype is not None:
input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.int32)
elif self.dtype in msdapter_dtype.all_int_type:
input = input.astype(dtype) if dtype != mstype.bool_ else input.astype(mstype.bool_).astype(mstype.int64)
lzh commented 1 year ago
tensor([-1,1]),sum(dtype=bool)的情况,torch会把-1和1当作bool型处理,结果为True。ms.ops.sum不能处理bool型,如果先转换为int型,-1和1的sum为0,结果会返回False,所以需要先转换成bool再转int。
Erpim merged commit c54396b64e into master 1 year ago
lzh deleted branch lzh_sum 1 year ago
The pull request has been merged as c54396b64e.
Sign in to join this conversation.
No reviewers
No Label
No Milestone
No Assignees
2 Participants
Notifications
Due Date

No due date set.

Dependencies

This pull request currently doesn't have any dependencies.

Loading…
There is no content yet.