#178 fix nn.functional.linear and add ShortTensor

Merged
zoulq merged 1 commits from frelam/MSAdapter:master20221122-2 into master 1 year ago
  1. +20
    -2
      ms_adapter/pytorch/nn/functional.py
  2. +5
    -0
      ms_adapter/pytorch/tensor.py
  3. +65
    -0
      testing/ut/pytorch/nn/functional/test_linear.py
  4. +11
    -3
      testing/ut/pytorch/nn/test_linear.py
  5. +20
    -12
      testing/ut/pytorch/tensor/test_tensor.py

+ 20
- 2
ms_adapter/pytorch/nn/functional.py View File

@@ -1509,9 +1509,27 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,


def linear(input, weight, bias=None):
@constexpr
def get_transpose_perm(shape):
_rank = len(shape)
perm = list(i for i in range(_rank))
_tmp = perm[-1]
perm[-1] = perm[-2]
perm[-2] = _tmp
return tuple(perm)

weight_shape = weight.shape
weight_rank = len(weight_shape)
if weight_rank not in (1, 2):
raise ValueError("For nn.functional.linear, weight only support 2D or 1D input"
f"but got {weight_rank}D input")

if weight_rank == 2:
weight = ms.ops.transpose(weight, get_transpose_perm(weight_shape))

input = cast_to_ms_tensor(input)
output = ms.ops.MatMul(transpose_b=True)(input, weight)
output = ms.ops.matmul(input, weight)
if bias is not None:
output = ms.ops.bias_add(output, bias)
output = ms.ops.add(output, bias)
output = cast_to_adapter_tensor(output)
return output

+ 5
- 0
ms_adapter/pytorch/tensor.py View File

@@ -985,6 +985,11 @@ class CharTensor(_TypeTensor):
super(CharTensor, self).__init__(*input_data, dtype_name='int8')


class ShortTensor(_TypeTensor):
def __init__(self, *input_data):
super(ShortTensor, self).__init__(*input_data, dtype_name='int16')


class IntTensor(_TypeTensor):
def __init__(self, *input_data):
super(IntTensor, self).__init__(*input_data, dtype_name='int32')


+ 65
- 0
testing/ut/pytorch/nn/functional/test_linear.py View File

@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import mindspore as ms
import ms_adapter.pytorch as ms_torch
import torch
import numpy as np
from mindspore import context

context.set_context(mode=ms.PYNATIVE_MODE)

def test_linear1():
data = np.ones((1, 2, 5))
weight = np.ones((3, 5))

ms_data = ms_torch.tensor(data)
ms_weight = ms_torch.tensor(weight)
ms_out = ms_torch.nn.functional.linear(ms_data, ms_weight)

torch_data = torch.tensor(data)
torch_weight = torch.tensor(weight)
torch_out = torch.nn.functional.linear(torch_data, torch_weight)

assert ms_out.shape == torch_out.shape
assert ms_out.numpy().dtype == torch_out.numpy().dtype

def test_linear2():
data = np.ones((1, 2, 5))
weight = np.ones((3, 5))
bias = np.ones((3))

ms_data = ms_torch.tensor(data)
ms_weight = ms_torch.tensor(weight)
ms_bias = ms_torch.tensor(bias)
ms_out = ms_torch.nn.functional.linear(ms_data, ms_weight, bias=ms_bias)

torch_data = torch.tensor(data)
torch_weight = torch.tensor(weight)
torch_bias = torch.tensor(bias)
torch_out = torch.nn.functional.linear(torch_data, torch_weight, bias=torch_bias)

assert ms_out.shape == torch_out.shape
assert ms_out.numpy().dtype == torch_out.numpy().dtype


def test_linear3():
data = np.ones((1, 2, 5))
weight = np.ones((5))

ms_data = ms_torch.tensor(data)
ms_weight = ms_torch.tensor(weight)
ms_out = ms_torch.nn.functional.linear(ms_data, ms_weight)

torch_data = torch.tensor(data)
torch_weight = torch.tensor(weight)
torch_out = torch.nn.functional.linear(torch_data, torch_weight)

assert ms_out.shape == torch_out.shape
assert ms_out.numpy().dtype == torch_out.numpy().dtype


if __name__ == '__main__':
test_linear1()
test_linear2()
test_linear3()

+ 11
- 3
testing/ut/pytorch/nn/test_linear.py View File

@@ -88,6 +88,14 @@ def test_bilinear_model():
assert output.shape == (10, 7)


test_linear_model()
test_identity()
test_bilinear_model()
def test_linear_model2():
linear = Linear(64, 3)
x = tensor(np.ones((1, 2, 64)))
assert linear(x).shape == (1, 2, 3)


if __name__ == '__main__':
test_linear_model()
test_identity()
test_bilinear_model()
test_linear_model2()

+ 20
- 12
testing/ut/pytorch/tensor/test_tensor.py View File

@@ -600,50 +600,58 @@ def test_others_tensor():
tensor = pytorch.ByteTensor()
tensor = pytorch.ByteTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == ms.uint8
assert tensor.dtype == pytorch.uint8
tensor = pytorch.ByteTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == ms.uint8
assert tensor.dtype == pytorch.uint8

tensor = pytorch.CharTensor()
tensor = pytorch.CharTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == ms.int8
assert tensor.dtype == pytorch.int8
tensor = pytorch.CharTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == ms.int8
assert tensor.dtype == pytorch.int8

tensor = pytorch.ShortTensor()
tensor = pytorch.ShortTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == pytorch.int16
tensor = pytorch.ShortTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == pytorch.int16
frelam commented 1 year ago
Review
done

tensor = pytorch.IntTensor()
tensor = pytorch.IntTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == ms.int32
assert tensor.dtype == pytorch.int32
tensor = pytorch.IntTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == ms.int32
assert tensor.dtype == pytorch.int32

tensor = pytorch.HalfTensor()
tensor = pytorch.HalfTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == ms.float16
assert tensor.dtype == pytorch.float16
tensor = pytorch.HalfTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == ms.float16
assert tensor.dtype == pytorch.float16

tensor = pytorch.FloatTensor()
tensor = pytorch.FloatTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == ms.float32
assert tensor.dtype == pytorch.float32
tensor = pytorch.FloatTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == ms.float32
assert tensor.dtype == pytorch.float32

tensor = pytorch.DoubleTensor()
tensor = pytorch.DoubleTensor(3, 5)
assert tensor.shape == (3, 5)
assert tensor.dtype == ms.float64
assert tensor.dtype == pytorch.float64
tensor = pytorch.DoubleTensor([1, 2, 3])
assert tensor.shape == (3,)
assert tensor.dtype == ms.float64
assert tensor.dtype == pytorch.float64

def test_is_floating_point():
x = [1, 2, -1, 2, 0, -3.5]


Loading…
Cancel
Save