#203 add fractional_maxpool2d & fractional_maxpool3d

Merged
laich merged 1 commits from cqu_lxy/MSAdapter:master into master 1 year ago
  1. +24
    -0
      ms_adapter/pytorch/nn/functional.py
  2. +2
    -0
      ms_adapter/pytorch/nn/modules/__init__.py
  3. +53
    -1
      ms_adapter/pytorch/nn/modules/pooling.py
  4. +85
    -1
      testing/ut/pytorch/nn/test_pooling.py

+ 24
- 0
ms_adapter/pytorch/nn/functional.py View File

@@ -1471,3 +1471,27 @@ def lp_pool2d(input, norm_type, kernel_size, stride = None, ceil_mode = False):
input = cast_to_ms_tensor(input)
output = ms.ops.lp_pool2d(input, norm_type, kernel_size, stride, ceil_mode)
return cast_to_adapter_tensor(output)

def fractional_max_pool2d(input_x, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None):
input_ms = cast_to_ms_tensor(input_x)
_kernel_size = kernel_size
_output_size = output_size
_output_ratio = output_ratio
_return_indices = return_indices
__random_samples = _random_samples
out = ms.ops.fractional_max_pool2d(input_ms, _kernel_size, _output_size, _output_ratio, _return_indices,
__random_samples)
return cast_to_adapter_tensor(out)

def fractional_max_pool3d(input_x, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None):
input_ms = cast_to_ms_tensor(input_x)
_kernel_size = kernel_size
_output_size = output_size
_output_ratio = output_ratio
_return_indices = return_indices
__random_samples = _random_samples
out = ms.ops.fractional_max_pool3d(input_ms, _kernel_size, _output_size, _output_ratio, _return_indices,
__random_samples)
return cast_to_adapter_tensor(out)

+ 2
- 0
ms_adapter/pytorch/nn/modules/__init__.py View File

@@ -52,6 +52,8 @@ __all__ = [
'LazyInstanceNorm2d',
'LazyInstanceNorm3d',

'FractionalMaxPool2d',
'FractionalMaxPool3d',
'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d',
'AdaptiveAvgPool3d',


+ 53
- 1
ms_adapter/pytorch/nn/modules/pooling.py View File

@@ -14,7 +14,7 @@ __all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d',
'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
'LPPool1d', 'LPPool2d']
'LPPool1d', 'LPPool2d', 'FractionalMaxPool2d', 'FractionalMaxPool3d']

class _MaxPoolNd(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
@@ -418,3 +418,55 @@ class LPPool2d(_LPPoolNd):
def forward(self, input):
return Adapter_F.lp_pool2d(input, float(self.norm_type), self.kernel_size,
self.stride, self.ceil_mode)

class FractionalMaxPool2d(Module):
def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None):
super(FractionalMaxPool2d, self).__init__()
self.kernel_size = kernel_size
self.return_indices = return_indices
self.output_size = output_size
self.output_ratio = output_ratio
self._random_samples = _random_samples
if output_size is None and output_ratio is None:
raise ValueError("FractionalMaxPool2d requires specifying either "
"an output size, or a pooling ratio")
if output_size is not None and output_ratio is not None:
raise ValueError("only one of output_size and output_ratio may be specified")
if self.output_ratio is not None:
if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
raise ValueError("output_ratio must be between 0 and 1 (got {})"
.format(output_ratio))

def forward(self, input):
@cast_tensor
def _call_ms_api(input):
return Adapter_F.fractional_max_pool2d(input, self.kernel_size, self.output_size, self.output_ratio,
self.return_indices, self._random_samples)
return _call_ms_api(input)

class FractionalMaxPool3d(Module):
def __init__(self, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None):
super(FractionalMaxPool3d, self).__init__()
self.kernel_size = kernel_size
self.return_indices = return_indices
self.output_size = output_size
self.output_ratio = output_ratio
self._random_samples = _random_samples
if output_size is None and output_ratio is None:
raise ValueError("FractionalMaxPool3d requires specifying either "
"an output size, or a pooling ratio")
if output_size is not None and output_ratio is not None:
raise ValueError("only one of output_size and output_ratio may be specified")
if self.output_ratio is not None:
if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
raise ValueError("output_ratio must be between 0 and 1 (got {})"
.format(output_ratio))

def forward(self, input):
@cast_tensor
def _call_ms_api(input):
return Adapter_F.fractional_max_pool3d(input, self.kernel_size, self.output_size, self.output_ratio,
self.return_indices, self._random_samples)
return _call_ms_api(input)

+ 85
- 1
testing/ut/pytorch/nn/test_pooling.py View File

@@ -4,7 +4,7 @@ import torch
from mindspore import Tensor
from ms_adapter.pytorch.nn import MaxPool1d, MaxPool2d, MaxPool3d, AvgPool1d, AvgPool2d, AvgPool3d, \
AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d, \
LPPool1d, LPPool2d
LPPool1d, LPPool2d, FractionalMaxPool2d, FractionalMaxPool3d

import mindspore as ms
ms.context.set_context(mode=ms.GRAPH_MODE)
@@ -308,6 +308,85 @@ def test_adaptive_avgpool3d_compare1():
assert (torch_output.shape == ms_output.shape)
# TODO: assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_fractional_maxpool2d_compare1():
ms_net = FractionalMaxPool2d(1, output_ratio=(0.5, 0.5))
torch_net = torch.nn.FractionalMaxPool2d(1, output_ratio=(0.5, 0.5))

data = np.array([0.3220, 0.9545, 0.7879, 0.0975, 0.3698,
0.5135, 0.5740, 0.3435, 0.1895, 0.8764,
0.9581, 0.4760, 0.9014, 0.8522, 0.3664,
0.4980, 0.9673, 0.9879, 0.6988, 0.9022,
0.9304, 0.1558, 0.0153, 0.1559, 0.9852]).reshape([1, 1, 5, 5])
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())

def test_fractional_maxpool2d_compare2():
_random_samples = np.array([[[0.8, 0.8]]])
ms_random_samples = Tensor(_random_samples.astype(np.float32))
torch_random_samples = torch.Tensor(_random_samples)
ms_net = FractionalMaxPool2d(kernel_size=1, output_size=(2, 2), _random_samples=ms_random_samples,
return_indices=True)
torch_net = torch.nn.FractionalMaxPool2d(kernel_size=1, output_size=(2, 2), _random_samples=torch_random_samples,
return_indices=True)

data = np.array([0.3220, 0.9545, 0.7879, 0.0975, 0.3698,
0.5135, 0.5740, 0.3435, 0.1895, 0.8764,
0.9581, 0.4760, 0.9014, 0.8522, 0.3664,
0.4980, 0.9673, 0.9879, 0.6988, 0.9022,
0.9304, 0.1558, 0.0153, 0.1559, 0.9852]).reshape([1, 1, 5, 5])

ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output, ms_indices= ms_net(ms_input)
torch_output, torch_indices = torch_net(torch_input)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
assert np.allclose(ms_indices.asnumpy(), torch_indices.numpy())

def test_fractional_maxpool3d_compare1():
_random_samples = np.array([0.7, 0.7, 0.7]).reshape([1, 1, 3])
ms_random_samples = Tensor(_random_samples.astype(np.float32))
torch_random_samples = torch.Tensor(_random_samples)
ms_net = FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), output_size=(1, 1, 3), _random_samples=ms_random_samples,
return_indices=True)
torch_net = torch.nn.FractionalMaxPool3d(kernel_size=(1, 1, 1), output_size=(1, 1, 3), _random_samples=torch_random_samples,
return_indices=True)

data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]).reshape([1, 1, 2, 2, 4])

ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output, ms_indices = ms_net(ms_input)
torch_output, torch_indices = torch_net(torch_input)

assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
assert np.allclose(ms_indices.asnumpy(), torch_indices.numpy())

def test_fractional_maxpool3d_compare2():
_random_samples = np.array([0.7, 0.7, 0.7]).reshape([1, 1, 3])
ms_random_samples = Tensor(_random_samples.astype(np.float32))
torch_random_samples = torch.Tensor(_random_samples)
ms_net = FractionalMaxPool3d(kernel_size=(1.0, 1.0, 1.0), output_ratio=(0.5, 0.5, 0.5), _random_samples=ms_random_samples,
return_indices=True)
torch_net = torch.nn.FractionalMaxPool3d(kernel_size=(1, 1, 1), output_ratio=(0.5, 0.5, 0.5), _random_samples=torch_random_samples,
return_indices=True)

data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]).reshape([1, 1, 2, 2, 4])

ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output, ms_indices = ms_net(ms_input)
torch_output, torch_indices = torch_net(torch_input)
assert np.allclose(ms_output.asnumpy(), torch_output.numpy())
assert np.allclose(ms_indices.asnumpy(), torch_indices.numpy())

def test_lppool1d_compare1():
ms_net = LPPool1d(norm_type=1, kernel_size=3, stride=1)
@@ -368,3 +447,8 @@ if __name__ == '__main__':

test_lppool1d_compare1()
test_lppool2d_compare1()

test_fractional_maxpool2d_compare1()
test_fractional_maxpool2d_compare2()
test_fractional_maxpool3d_compare1()
test_fractional_maxpool3d_compare2()

Loading…
Cancel
Save