From d486f8716e23e43b30ea5f668ff5e411d46b7e61 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Fri, 25 Nov 2022 15:58:01 +0800
Subject: [PATCH] fix nn.functional.linear

---
 ms_adapter/pytorch/nn/functional.py         | 22 ++++-
 ms_adapter/pytorch/tensor.py                |  5 ++
 .../ut/pytorch/nn/functional/test_linear.py | 65 +++++++++++++++++++
 testing/ut/pytorch/nn/test_linear.py        | 14 +++-
 testing/ut/pytorch/tensor/test_tensor.py    | 32 +++++----
 5 files changed, 121 insertions(+), 17 deletions(-)
 create mode 100644 testing/ut/pytorch/nn/functional/test_linear.py

diff --git a/ms_adapter/pytorch/nn/functional.py b/ms_adapter/pytorch/nn/functional.py
index f4bdb668..3016fd0b 100644
--- a/ms_adapter/pytorch/nn/functional.py
+++ b/ms_adapter/pytorch/nn/functional.py
@@ -1509,9 +1509,27 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
 
 
 def linear(input, weight, bias=None):
+    @constexpr
+    def get_transpose_perm(shape):
+        _rank = len(shape)
+        perm = list(i for i in range(_rank))
+        _tmp = perm[-1]
+        perm[-1] = perm[-2]
+        perm[-2] = _tmp
+        return tuple(perm)
+
+    weight_shape = weight.shape
+    weight_rank = len(weight_shape)
+    if weight_rank not in (1, 2):
+        raise ValueError("For nn.functional.linear, weight only supports 2D or 1D input, "
+                         f"but got {weight_rank}D input")
+
+    if weight_rank == 2:
+        weight = ms.ops.transpose(weight, get_transpose_perm(weight_shape))
+
     input = cast_to_ms_tensor(input)
-    output = ms.ops.MatMul(transpose_b=True)(input, weight)
+    output = ms.ops.matmul(input, weight)
     if bias is not None:
-        output = ms.ops.bias_add(output, bias)
+        output = ms.ops.add(output, bias)
     output = cast_to_adapter_tensor(output)
     return output
diff --git a/ms_adapter/pytorch/tensor.py b/ms_adapter/pytorch/tensor.py
index 9b4e625a..355e37ef 100644
--- a/ms_adapter/pytorch/tensor.py
+++ b/ms_adapter/pytorch/tensor.py
@@ -985,6 +985,11 @@ class CharTensor(_TypeTensor):
         super(CharTensor, self).__init__(*input_data, dtype_name='int8')
 
 
+class ShortTensor(_TypeTensor):
+    def __init__(self, *input_data):
+        super(ShortTensor, self).__init__(*input_data, dtype_name='int16')
+
+
 class IntTensor(_TypeTensor):
     def __init__(self, *input_data):
         super(IntTensor, self).__init__(*input_data, dtype_name='int32')
diff --git a/testing/ut/pytorch/nn/functional/test_linear.py b/testing/ut/pytorch/nn/functional/test_linear.py
new file mode 100644
index 00000000..8a655a5c
--- /dev/null
+++ b/testing/ut/pytorch/nn/functional/test_linear.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import mindspore as ms
+import ms_adapter.pytorch as ms_torch
+import torch
+import numpy as np
+from mindspore import context
+
+context.set_context(mode=ms.PYNATIVE_MODE)
+
+def test_linear1():
+    data = np.ones((1, 2, 5))
+    weight = np.ones((3, 5))
+
+    ms_data = ms_torch.tensor(data)
+    ms_weight = ms_torch.tensor(weight)
+    ms_out = ms_torch.nn.functional.linear(ms_data, ms_weight)
+
+    torch_data = torch.tensor(data)
+    torch_weight = torch.tensor(weight)
+    torch_out = torch.nn.functional.linear(torch_data, torch_weight)
+
+    assert ms_out.shape == torch_out.shape
+    assert ms_out.numpy().dtype == torch_out.numpy().dtype
+
+def test_linear2():
+    data = np.ones((1, 2, 5))
+    weight = np.ones((3, 5))
+    bias = np.ones((3))
+
+    ms_data = ms_torch.tensor(data)
+    ms_weight = ms_torch.tensor(weight)
+    ms_bias = ms_torch.tensor(bias)
+    ms_out = ms_torch.nn.functional.linear(ms_data, ms_weight, bias=ms_bias)
+
+    torch_data = torch.tensor(data)
+    torch_weight = torch.tensor(weight)
+    torch_bias = torch.tensor(bias)
+    torch_out = torch.nn.functional.linear(torch_data, torch_weight, bias=torch_bias)
+
+    assert ms_out.shape == torch_out.shape
+    assert ms_out.numpy().dtype == torch_out.numpy().dtype
+
+
+def test_linear3():
+    data = np.ones((1, 2, 5))
+    weight = np.ones((5))
+
+    ms_data = ms_torch.tensor(data)
+    ms_weight = ms_torch.tensor(weight)
+    ms_out = ms_torch.nn.functional.linear(ms_data, ms_weight)
+
+    torch_data = torch.tensor(data)
+    torch_weight = torch.tensor(weight)
+    torch_out = torch.nn.functional.linear(torch_data, torch_weight)
+
+    assert ms_out.shape == torch_out.shape
+    assert ms_out.numpy().dtype == torch_out.numpy().dtype
+
+
+if __name__ == '__main__':
+    test_linear1()
+    test_linear2()
+    test_linear3()
\ No newline at end of file
diff --git a/testing/ut/pytorch/nn/test_linear.py b/testing/ut/pytorch/nn/test_linear.py
index f2bc409f..789c7e41 100644
--- a/testing/ut/pytorch/nn/test_linear.py
+++ b/testing/ut/pytorch/nn/test_linear.py
@@ -88,6 +88,14 @@ def test_bilinear_model():
     assert output.shape == (10, 7)
 
 
-test_linear_model()
-test_identity()
-test_bilinear_model()
+def test_linear_model2():
+    linear = Linear(64, 3)
+    x = tensor(np.ones((1, 2, 64)))
+    assert linear(x).shape == (1, 2, 3)
+
+
+if __name__ == '__main__':
+    test_linear_model()
+    test_identity()
+    test_bilinear_model()
+    test_linear_model2()
diff --git a/testing/ut/pytorch/tensor/test_tensor.py b/testing/ut/pytorch/tensor/test_tensor.py
index 13b25d5b..02435d9f 100644
--- a/testing/ut/pytorch/tensor/test_tensor.py
+++ b/testing/ut/pytorch/tensor/test_tensor.py
@@ -600,50 +600,58 @@ def test_others_tensor():
     tensor = pytorch.ByteTensor()
     tensor = pytorch.ByteTensor(3, 5)
     assert tensor.shape == (3, 5)
-    assert tensor.dtype == ms.uint8
+    assert tensor.dtype == pytorch.uint8
     tensor = pytorch.ByteTensor([1, 2, 3])
     assert tensor.shape == (3,)
-    assert tensor.dtype == ms.uint8
+    assert tensor.dtype == pytorch.uint8
 
     tensor = pytorch.CharTensor()
     tensor = pytorch.CharTensor(3, 5)
     assert tensor.shape == (3, 5)
-    assert tensor.dtype == ms.int8
+    assert tensor.dtype == pytorch.int8
     tensor = pytorch.CharTensor([1, 2, 3])
     assert tensor.shape == (3,)
-    assert tensor.dtype == ms.int8
+    assert tensor.dtype == pytorch.int8
+
+    tensor = pytorch.ShortTensor()
+    tensor = pytorch.ShortTensor(3, 5)
+    assert tensor.shape == (3, 5)
+    assert tensor.dtype == pytorch.int16
+    tensor = pytorch.ShortTensor([1, 2, 3])
+    assert tensor.shape == (3,)
+    assert tensor.dtype == pytorch.int16
 
     tensor = pytorch.IntTensor()
     tensor = pytorch.IntTensor(3, 5)
     assert tensor.shape == (3, 5)
-    assert tensor.dtype == ms.int32
+    assert tensor.dtype == pytorch.int32
     tensor = pytorch.IntTensor([1, 2, 3])
     assert tensor.shape == (3,)
-    assert tensor.dtype == ms.int32
+    assert tensor.dtype == pytorch.int32
 
     tensor = pytorch.HalfTensor()
     tensor = pytorch.HalfTensor(3, 5)
     assert tensor.shape == (3, 5)
-    assert tensor.dtype == ms.float16
+    assert tensor.dtype == pytorch.float16
     tensor = pytorch.HalfTensor([1, 2, 3])
     assert tensor.shape == (3,)
-    assert tensor.dtype == ms.float16
+    assert tensor.dtype == pytorch.float16
 
     tensor = pytorch.FloatTensor()
     tensor = pytorch.FloatTensor(3, 5)
     assert tensor.shape == (3, 5)
-    assert tensor.dtype == ms.float32
+    assert tensor.dtype == pytorch.float32
     tensor = pytorch.FloatTensor([1, 2, 3])
     assert tensor.shape == (3,)
-    assert tensor.dtype == ms.float32
+    assert tensor.dtype == pytorch.float32
 
     tensor = pytorch.DoubleTensor()
    tensor = pytorch.DoubleTensor(3, 5)
     assert tensor.shape == (3, 5)
-    assert tensor.dtype == ms.float64
+    assert tensor.dtype == pytorch.float64
     tensor = pytorch.DoubleTensor([1, 2, 3])
     assert tensor.shape == (3,)
-    assert tensor.dtype == ms.float64
+    assert tensor.dtype == pytorch.float64
 
 def test_is_floating_point():
     x = [1, 2, -1, 2, 0, -3.5]
-- 
2.34.1
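
Reviewer note: the core change replaces ms.ops.MatMul(transpose_b=True), which does
not broadcast over leading batch dimensions, with an explicit transpose of a 2D
weight followed by ms.ops.matmul, so linear now accepts batched (>2D) inputs and 1D
weights, as the new unit tests exercise. As a quick cross-check, a minimal NumPy
sketch of the torch.nn.functional.linear semantics the patch targets is below;
linear_ref is an illustrative name, not code from this repository:

    import numpy as np

    def linear_ref(x, w, b=None):
        # torch.nn.functional.linear computes y = x @ w.T + b, where w is
        # 2D (out_features, in_features) or 1D (in_features,), and x may
        # carry arbitrary leading batch dimensions.
        y = x @ (w.T if w.ndim == 2 else w)  # 1D weight: dot over the last axis
        return y if b is None else y + b

    x = np.ones((1, 2, 5))
    assert linear_ref(x, np.ones((3, 5)), np.ones(3)).shape == (1, 2, 3)  # cf. test_linear2
    assert linear_ref(x, np.ones(5)).shape == (1, 2)                      # cf. test_linear3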