#196 [众智活动任务1]MaxUnpool

Merged
laich merged 1 commits from mirror_yun/MSAdapter:maxunpool into master 1 year ago
  1. +17
    -0
      ms_adapter/pytorch/nn/functional.py
  2. +4
    -0
      ms_adapter/pytorch/nn/modules/__init__.py
  3. +34
    -0
      ms_adapter/pytorch/nn/modules/unpooling.py
  4. +270
    -0
      testing/ut/pytorch/nn/functional/test_maxunpool.py
  5. +125
    -0
      testing/ut/pytorch/nn/test_unpooling.py

+ 17
- 0
ms_adapter/pytorch/nn/functional.py View File

@@ -1433,6 +1433,23 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
out = _max_pool(input)
return cast_to_adapter_tensor(out)

def max_unpool1d(input, indices, kernel_size, stride, padding, output_size = None):
input = cast_to_ms_tensor(input)
indices = cast_to_ms_tensor(indices)
out = ms.ops.max_unpool1d(input, indices, kernel_size, stride, padding, output_size)
return out

def max_unpool2d(input, indices, kernel_size, stride, padding, output_size = None):
input = cast_to_ms_tensor(input)
indices = cast_to_ms_tensor(indices)
out = ms.ops.max_unpool2d(input, indices, kernel_size, stride, padding, output_size)
return out

def max_unpool3d(input, indices, kernel_size, stride, padding, output_size = None):
input = cast_to_ms_tensor(input)
indices = cast_to_ms_tensor(indices)
out = ms.ops.max_unpool3d(input, indices, kernel_size, stride, padding, output_size)
return cast_to_adapter_tensor(out)

def linear(input, weight, bias=None):
@constexpr


+ 4
- 0
ms_adapter/pytorch/nn/modules/__init__.py View File

@@ -8,6 +8,7 @@ from .conv import *
from .distance import *
from .batchnorm import *
from .pooling import *
from .unpooling import *
from .loss import *
from .padding import *
from .rnn import *
@@ -67,6 +68,9 @@ __all__ = [
'LPPool1d',
'LPPool2d',
'Identity',
'MaxUnpool1d',
'MaxUnpool2d',
'MaxUnpool3d',

'ReLU',
'ReLU6',


+ 34
- 0
ms_adapter/pytorch/nn/modules/unpooling.py View File

@@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import ms_adapter.pytorch.nn.functional as Adapter_F
from .module import Module

__all__ = ['MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d']


class _MaxUnpoolNd(Module):
def __init__(self, kernel_size, stride=None, padding=0):
super(_MaxUnpoolNd, self).__init__()
self.kernel_size = kernel_size
self.stride = stride if (stride is not None) else kernel_size
self.padding = padding
def extra_repr(self) -> str:
return 'kernel_size={}, stride={}, padding={}'.format(
self.kernel_size, self.stride, self.padding
)

class MaxUnpool1d(_MaxUnpoolNd):
def forward(self, input, indices, output_size = None):
return Adapter_F.max_unpool1d(input, indices,
self.kernel_size, self.stride, self.padding, output_size)


class MaxUnpool2d(_MaxUnpoolNd):
def forward(self, input, indices, output_size = None):
return Adapter_F.max_unpool2d(input, indices,
self.kernel_size, self.stride, self.padding, output_size)

class MaxUnpool3d(_MaxUnpoolNd):
def forward(self, input, indices, output_size = None):
return Adapter_F.max_unpool3d(input, indices,
self.kernel_size, self.stride, self.padding, output_size)

+ 270
- 0
testing/ut/pytorch/nn/functional/test_maxunpool.py View File

@@ -0,0 +1,270 @@
import numpy as np
import torch
import mindspore as ms
import ms_adapter.pytorch as ms_torch

ms.context.set_context(mode=ms.PYNATIVE_MODE)

def test_maxunpool1d_with_2dim():
N = np.random.randint(1, 33)
C = np.random.randint(1, 257)
tensor = np.random.randn(N, C).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, C + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool1d_with_3dim():
B = np.random.randint(1, 33)
N = np.random.randint(1, 33)
C = np.random.randint(1, 257)
tensor = np.random.randn(B, N, C).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, C + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

# def test_maxunpool1d_with_3dim_shape():
# B = np.random.randint(1, 33)
# N = np.random.randint(1, 33)
# C = np.random.randint(1, 257)
# tensor = np.random.randn(B, N, C).astype(np.float32)

# torch_tensor = torch.tensor(tensor)
# kernel_size = np.random.randint(1, C + 1)
# padding = np.random.randint(0, kernel_size/2 + 1)
# stride = 1

# torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True)
# torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding, output_size=torch_tensor.size())

# ms_pooling = ms_torch.tensor(torch_pooling.numpy())
# ms_indices = ms_torch.tensor(torch_indices.numpy())

# ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding, tensor.shape)

# assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool1d_with_2dim_shape():
N = np.random.randint(1, 33)
C = np.random.randint(1, 257)
tensor = np.random.randn(N, C).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, C + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding, output_size=[C])

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding, tensor.shape)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool2d_with_3dim():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 257)
tensor = np.random.randn(C, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, H + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool2d_with_3dim_shape():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 257)
tensor = np.random.randn(C, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, H + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding, [H, W])

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding,tensor.shape)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool2d_with_4dim():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 129)
N = np.random.randint(1, 33)
tensor = np.random.randn(N, C, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, H + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool2d_with_4dim_shape():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 129)
N = np.random.randint(1, 33)
tensor = np.random.randn(N, C, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, H + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding, torch_tensor.size())

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding, tensor.shape)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool3d_with_4dim():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 129)
D = np.random.randint(1, H + 1)
tensor = np.random.randn(C, D, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, D + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool3d_with_4dim_shape():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 129)
D = np.random.randint(1, H + 1)
tensor = np.random.randn(C, D, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, D + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding, [D, H, W])

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding, tensor.shape)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool3d_with_5dim():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 129)
D = np.random.randint(1, H + 1)
N = np.random.randint(1, 17)
tensor = np.random.randn(N, C, D, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, D + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool3d_with_5dim_shape():
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 129)
D = np.random.randint(1, H + 1)
N = np.random.randint(1, 17)
tensor = np.random.randn(N, C, D, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, D + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding, torch_tensor.size())

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding, tensor.shape)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
if __name__ == '__main__':
test_maxunpool1d_with_2dim()
test_maxunpool1d_with_3dim()
test_maxunpool2d_with_3dim()
test_maxunpool2d_with_4dim()
test_maxunpool3d_with_4dim()
test_maxunpool3d_with_5dim()
test_maxunpool1d_with_2dim_shape()
# test_maxunpool1d_with_3dim_shape()
test_maxunpool2d_with_3dim_shape()
test_maxunpool2d_with_4dim_shape()
test_maxunpool3d_with_4dim_shape()
test_maxunpool3d_with_5dim_shape()


+ 125
- 0
testing/ut/pytorch/nn/test_unpooling.py View File

@@ -0,0 +1,125 @@
import numpy as np
import torch

from mindspore import Tensor
from ms_adapter.pytorch.nn import MaxUnpool1d, MaxUnpool2d, MaxUnpool3d

import mindspore as ms
ms.context.set_context(mode=ms.GRAPH_MODE)

def test_maxunpool1d_compare1():
kernel_size, stride, padding = 4, 2, 2
ms_net = MaxUnpool1d(kernel_size, stride, padding)
torch_net = torch.nn.MaxUnpool1d(kernel_size, stride, padding)

B, N, C = 4, 5, 6
data = np.random.random([B, N, C])
indices_range = (C - 1) * stride + kernel_size - 2 * padding
indices = np.random.choice(indices_range - 1, size=(1, 1, C), replace=False)
indices = indices.repeat(B, 0).repeat(N, 1)

ms_input = Tensor(data)
ms_indices = Tensor(indices)
torch_input = torch.Tensor(data)
torch_indices = torch.Tensor(indices).type(torch.int64)

torch_output = torch_net(torch_input, torch_indices)
ms_output = ms_net(ms_input, ms_indices)

assert np.allclose(ms_output.numpy(), torch_output.numpy())

def test_maxunpool2d_compare1():
kernel_size, stride = 3, 2
ms_net = MaxUnpool2d(kernel_size, stride)
torch_net = torch.nn.MaxUnpool2d(kernel_size, stride)

B, N, H, W = 4, 5, 6, 7
data = np.random.random([B, N, H, W])
indices_range = (H - 1) * stride + kernel_size
indices_range = ((W - 1) * stride - 1 + kernel_size) * indices_range
indices = np.random.choice(indices_range - 1, size=(1, 1, H * W), replace=False)
indices = indices.repeat(B, 0).repeat(N, 1).reshape(B, N, H, W)
ms_input = Tensor(data)
ms_indices = Tensor(indices)
torch_input = torch.Tensor(data)
torch_indices = torch.Tensor(indices).type(torch.int64)

torch_output = torch_net(torch_input, torch_indices)
ms_output = ms_net(ms_input, ms_indices)
assert np.allclose(ms_output.asnumpy(), torch_output.numpy())


def test_maxunpool2d_compare2():
kernel_size, stride, padding = 4, 2, 2
ms_net = MaxUnpool2d(kernel_size, stride, padding)
torch_net = torch.nn.MaxUnpool2d(kernel_size, stride, padding)

B, N, H, W = 6, 7, 8, 9
data = np.random.random([B, N, H, W])
indices_range = (H - 1) * stride + kernel_size - 2 * padding
indices_range = ((W - 1) * stride - 1 + kernel_size - 2 * padding) * indices_range
indices = np.random.choice(indices_range - 1, size=(1, 1, H * W), replace=False)
indices = indices.repeat(B, 0).repeat(N, 1).reshape(B, N, H, W)
ms_input = Tensor(data)
ms_indices = Tensor(indices)
torch_input = torch.Tensor(data)
torch_indices = torch.Tensor(indices).type(torch.int64)

torch_output = torch_net(torch_input, torch_indices)
ms_output = ms_net(ms_input, ms_indices)
assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_maxunpool2d_compare3():
kernel_size, stride, padding = (3, 5), (3, 1), 0
ms_net = MaxUnpool2d(kernel_size, stride, padding)
torch_net = torch.nn.MaxUnpool2d(kernel_size, stride, padding)

B, N, H, W = 1, 32, 9, 9
data = np.random.random([B, N, H, W])
indices_range = (H - 1) * stride[0] + kernel_size[0]
indices_range = ((W - 1) * stride[1] - 1 + kernel_size[1]) * indices_range

indices = np.random.choice(indices_range - 1, size=(1, 1, H * W), replace=False)
indices = indices.repeat(B, 0).repeat(N, 1).reshape(B, N, H, W)
ms_input = Tensor(data)
ms_indices = Tensor(indices)
torch_input = torch.Tensor(data)
torch_indices = torch.Tensor(indices).type(torch.int64)

torch_output = torch_net(torch_input, torch_indices)
ms_output = ms_net(ms_input, ms_indices)
assert np.allclose(ms_output.asnumpy(), torch_output.numpy())


def test_maxunpool3d_compare1():
kernel_size, stride, padding = 4, 2, 2
ms_net = MaxUnpool3d(kernel_size, stride, padding)
torch_net = torch.nn.MaxUnpool3d(kernel_size, stride, padding)

B, C, D, H, W = 4, 5, 6, 6, 6
data = np.random.random([B, C, D, H, W])
indices_range = (D - 1) * stride + kernel_size - 2 * padding
indices_range = ((H - 1) * stride - 1 + kernel_size - 2 * padding) * indices_range
indices_range = ((W - 1) * stride - 1 + kernel_size - 2 * padding) * indices_range

indices = np.random.choice(indices_range - 1, size=(1, 1,D * H * W), replace=False)
indices = indices.repeat(B, 0).repeat(C, 1).reshape(B, C, D, H, W)
ms_input = Tensor(data)
ms_indices = Tensor(indices)
torch_input = torch.Tensor(data)
torch_indices = torch.Tensor(indices).type(torch.int64)

ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

torch_output = torch_net(torch_input, torch_indices)
ms_output = ms_net(ms_input, ms_indices)
assert np.allclose(ms_output.asnumpy(), torch_output.numpy())


if __name__ == '__main__':
test_maxunpool1d_compare1()
test_maxunpool2d_compare1()
test_maxunpool2d_compare2()
test_maxunpool2d_compare3()
test_maxunpool3d_compare1()

Loading…
Cancel
Save