#70 Add nn.ReflectionPad1d/2d, nn.ConstantPad1d/2d/3d and nn.ZeroPad2d

Merged
Manson merged 4 commits from zoulq/MSAdapter-zlq:master into master 1 year ago
  1. +7 -1    ms_adapter/pytorch/nn/modules/__init__.py
  2. +262 -0  ms_adapter/pytorch/nn/modules/padding.py
  3. +154 -0  testing/ut/pytorch/nn/test_padding.py

+7 -1  ms_adapter/pytorch/nn/modules/__init__.py

@@ -8,7 +8,7 @@ from .conv import *
from .batchnorm import *
from .pooling import *
from .loss import *
from .padding import *
from .module import Module
from .container import Sequential, ModuleList
from .dropout import Dropout, Dropout2d, Dropout3d
@@ -55,4 +55,10 @@ __all__ = [
    'SmoothL1Loss',
    'LogSigmoid',
    'ELU',
    'ConstantPad1d',
    'ConstantPad2d',
    'ConstantPad3d',
    'ReflectionPad1d',
    'ReflectionPad2d',
    'ZeroPad2d',
]
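
With the wildcard import and the new `__all__` entries, the padding layers become reachable through the adapter's torch-like `nn` namespace. A minimal usage sketch (module paths mirror the test file below; the shapes and padding values are only illustrative):

import ms_adapter.pytorch as ms_pytorch

pad = ms_pytorch.nn.ZeroPad2d((1, 1, 2, 0))   # (left, right, top, bottom)
x = ms_pytorch.ones(1, 1, 3, 3)
y = pad(x)                                    # expected shape: (1, 1, 5, 5)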

+262 -0  ms_adapter/pytorch/nn/modules/padding.py

@@ -0,0 +1,262 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Sequence, Tuple

from mindspore import nn

from .module import Module
import ms_adapter.pytorch.nn.functional as F

__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d',
           'ZeroPad2d']

class _ConstantPadNd(Module):
    __constants__ = ['padding', 'value']

    def __init__(self, padding, value: float):
        super(_ConstantPadNd, self).__init__()
        self.padding = padding
        self.value = value
        self.pad_fun = None

    def forward(self, input):
        output = self.pad_fun(input)
        return output

    def extra_repr(self) -> str:
        return 'padding={}, value={}'.format(self.padding, self.value)

class ConstantPad1d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in both boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad1d(2, 3.5)
        >>> input = ms_adapter.pytorch.ones(1, 2, 4)
        >>> m(input)
    """

    def __init__(self, padding, value: float):
        super(ConstantPad1d, self).__init__(padding, value)
        self.pad_fun = nn.ConstantPad1d(self.padding, self.value)


class ConstantPad2d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad2d(2, 3.5)
        >>> input = ms_adapter.pytorch.ones(1, 2, 2)
        >>> m(input)
    """

    def __init__(self, padding, value: float):
        super(ConstantPad2d, self).__init__(padding, value)
        self.pad_fun = nn.ConstantPad2d(self.padding, self.value)


class ConstantPad3d(_ConstantPadNd):
    r"""Pads the input tensor boundaries with a constant value.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 6-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
          :math:`(C, D_{out}, H_{out}, W_{out})`, where

          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ConstantPad3d(3, 3.5)
        >>> input = ms_adapter.pytorch.ones(16, 3, 10, 20, 30)
        >>> output = m(input)
    """

    def __init__(self, padding, value: float):
        super(ConstantPad3d, self).__init__(padding, value)
        self.pad_fun = nn.ConstantPad3d(self.padding, self.value)


class _ReflectionPadNd(Module):
    __constants__ = ['padding']

    def __init__(self, padding):
        super(_ReflectionPadNd, self).__init__()
        self.padding = padding
        self.pad_fun = None

    def forward(self, input):
        output = self.pad_fun(input)
        return output

    def extra_repr(self) -> str:
        return '{}'.format(self.padding)


class ReflectionPad1d(_ReflectionPadNd):
    r"""Pads the input tensor using the reflection of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 2-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)

    Shape:
        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReflectionPad1d(2)
        >>> input = ms_adapter.pytorch.ones(1, 2, 4)
        >>> m(input)
    """

    def __init__(self, padding):
        super(ReflectionPad1d, self).__init__(padding)
        self.pad_fun = nn.ReflectionPad1d(self.padding)


class ReflectionPad2d(_ReflectionPadNd):
    r"""Pads the input tensor using the reflection of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReflectionPad2d(2)
        >>> input = ms_adapter.pytorch.ones(1, 1, 3, 3)
        >>> m(input)
    """

    def __init__(self, padding):
        super(ReflectionPad2d, self).__init__(padding)
        self.pad_fun = nn.ReflectionPad2d(self.padding)


class ReflectionPad3d(_ReflectionPadNd):
    r"""Pads the input tensor using the reflection of the input boundary.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 6-`tuple`, uses
            (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`,
            :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`,
            :math:`\text{padding\_front}`, :math:`\text{padding\_back}`)

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
          where

          :math:`D_{out} = D_{in} + \text{padding\_front} + \text{padding\_back}`

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ReflectionPad3d(1)
        >>> input = ms_adapter.pytorch.ones(1, 1, 2, 2, 2)
        >>> m(input)
    """

    def __init__(self, padding):
        super(ReflectionPad3d, self).__init__(padding)
        # TODO: MindSpore does not provide an nn.ReflectionPad3d API yet.
        # self.pad_fun = nn.ReflectionPad3d(self.padding)


class ZeroPad2d(ConstantPad2d):
    r"""Pads the input tensor boundaries with zero.

    For `N`-dimensional padding, use :func:`torch.nn.functional.pad()`.

    Args:
        padding (int, tuple): the size of the padding. If is `int`, uses the same
            padding in all boundaries. If a 4-`tuple`, uses (:math:`\text{padding\_left}`,
            :math:`\text{padding\_right}`, :math:`\text{padding\_top}`, :math:`\text{padding\_bottom}`)

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          :math:`H_{out} = H_{in} + \text{padding\_top} + \text{padding\_bottom}`

          :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`

    Examples::

        >>> m = nn.ZeroPad2d(2)
        >>> input = ms_adapter.pytorch.ones(1, 1, 3, 3)
        >>> m(input)
    """

    def __init__(self, padding) -> None:
        super(ZeroPad2d, self).__init__(padding, 0.)
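
Each wrapper resolves its backing MindSpore cell once in `__init__`, and `forward` simply applies it; `ReflectionPad3d` leaves `pad_fun` as `None` until a matching MindSpore API exists, so calling it would fail. A short delegation sketch using `ConstantPad2d`, including the negative-padding (cropping) case exercised by the tests below (values are illustrative):

import ms_adapter.pytorch as ms_pytorch

pad = ms_pytorch.nn.ConstantPad2d((-1, 1, 0, 1), 2.0)   # a negative entry crops instead of pads
x = ms_pytorch.ones(1, 2, 3, 4)
y = pad(x)   # width: 4 - 1 + 1 = 4, height: 3 + 0 + 1 = 4 -> expected shape (1, 2, 4, 4)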

+154 -0  testing/ut/pytorch/nn/test_padding.py

@@ -0,0 +1,154 @@
import numpy as np
import torch
import mindspore as ms
from mindspore import context

import ms_adapter.pytorch as ms_pytorch

context.set_context(mode=ms.PYNATIVE_MODE)

def test_constant_pad_1d():
    padding = 3
    value = 3.5
    pt_input_2d = torch.ones(2, 3)
    pt_pad_fun1 = torch.nn.ConstantPad1d(padding, value)
    pt_pad_out1 = pt_pad_fun1(pt_input_2d)
    ms_input_2d = ms_pytorch.ones(2, 3)
    ms_pad_fun1 = ms_pytorch.nn.ConstantPad1d(padding, value)
    ms_pad_out1 = ms_pad_fun1(ms_input_2d)
    assert (pt_pad_out1.shape == ms_pad_out1.shape)
    assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

    padding = (3, 1)
    value = 2.0
    pt_input_3d = torch.ones(2, 3, 4)
    pt_pad_fun2 = torch.nn.ConstantPad1d(padding, value)
    pt_pad_out2 = pt_pad_fun2(pt_input_3d)
    ms_input_3d = ms_pytorch.ones(2, 3, 4)
    ms_pad_fun2 = ms_pytorch.nn.ConstantPad1d(padding, value)
    ms_pad_out2 = ms_pad_fun2(ms_input_3d)
    assert (pt_pad_out2.shape == ms_pad_out2.shape)
    assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_constant_pad_2d():
    padding = 2
    value = 2.5
    pt_input_3d = torch.ones(2, 3, 4)
    pt_pad_fun1 = torch.nn.ConstantPad2d(padding, value)
    pt_pad_out1 = pt_pad_fun1(pt_input_3d)
    ms_input_3d = ms_pytorch.ones(2, 3, 4)
    ms_pad_fun1 = ms_pytorch.nn.ConstantPad2d(padding, value)
    ms_pad_out1 = ms_pad_fun1(ms_input_3d)
    assert (pt_pad_out1.shape == ms_pad_out1.shape)
    assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

    padding = (-1, 1, 0, 1)
    value = 2.0
    pt_input_4d = torch.ones(1, 2, 3, 4)
    pt_pad_fun2 = torch.nn.ConstantPad2d(padding, value)
    pt_pad_out2 = pt_pad_fun2(pt_input_4d)
    ms_input_4d = ms_pytorch.ones(1, 2, 3, 4)
    ms_pad_fun2 = ms_pytorch.nn.ConstantPad2d(padding, value)
    ms_pad_out2 = ms_pad_fun2(ms_input_4d)
    assert (pt_pad_out2.shape == ms_pad_out2.shape)
    assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_constant_pad_3d():
    padding = 1
    value = 2.5
    pt_input_4d = torch.ones(2, 1, 3, 4)
    pt_pad_fun1 = torch.nn.ConstantPad3d(padding, value)
    pt_pad_out1 = pt_pad_fun1(pt_input_4d)
    ms_input_4d = ms_pytorch.ones(2, 1, 3, 4)
    ms_pad_fun1 = ms_pytorch.nn.ConstantPad3d(padding, value)
    ms_pad_out1 = ms_pad_fun1(ms_input_4d)
    assert (pt_pad_out1.shape == ms_pad_out1.shape)
    assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

    padding = (3, 3, 6, 6, 0, 1)
    value = 2.0
    pt_input_5d = torch.ones(1, 2, 1, 3, 4)
    pt_pad_fun2 = torch.nn.ConstantPad3d(padding, value)
    pt_pad_out2 = pt_pad_fun2(pt_input_5d)
    ms_input_5d = ms_pytorch.ones(1, 2, 1, 3, 4)
    ms_pad_fun2 = ms_pytorch.nn.ConstantPad3d(padding, value)
    ms_pad_out2 = ms_pad_fun2(ms_input_5d)
    assert (pt_pad_out2.shape == ms_pad_out2.shape)
    assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_reflection_pad_1d():
    padding = 3
    pt_input_2d = torch.ones(2, 6)
    pt_pad_fun1 = torch.nn.ReflectionPad1d(padding)
    pt_pad_out1 = pt_pad_fun1(pt_input_2d)
    ms_input_2d = ms_pytorch.ones(2, 6)
    ms_pad_fun1 = ms_pytorch.nn.ReflectionPad1d(padding)
    ms_pad_out1 = ms_pad_fun1(ms_input_2d)
    assert (pt_pad_out1.shape == ms_pad_out1.shape)
    assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

    padding = (2, 1)
    pt_input_3d = torch.ones(2, 3, 4)
    pt_pad_fun2 = torch.nn.ReflectionPad1d(padding)
    pt_pad_out2 = pt_pad_fun2(pt_input_3d)
    ms_input_3d = ms_pytorch.ones(2, 3, 4)
    ms_pad_fun2 = ms_pytorch.nn.ReflectionPad1d(padding)
    ms_pad_out2 = ms_pad_fun2(ms_input_3d)
    assert (pt_pad_out2.shape == ms_pad_out2.shape)
    assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_reflection_pad_2d():
    padding = 2
    pt_input_3d = torch.ones(2, 3, 3)
    pt_pad_fun1 = torch.nn.ReflectionPad2d(padding)
    pt_pad_out1 = pt_pad_fun1(pt_input_3d)
    ms_input_3d = ms_pytorch.ones(2, 3, 3)
    ms_pad_fun1 = ms_pytorch.nn.ReflectionPad2d(padding)
    ms_pad_out1 = ms_pad_fun1(ms_input_3d)
    assert (pt_pad_out1.shape == ms_pad_out1.shape)
    assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

    padding = (1, 1, 2, 0)
    pt_input_4d = torch.ones(1, 2, 3, 4)
    pt_pad_fun2 = torch.nn.ReflectionPad2d(padding)
    pt_pad_out2 = pt_pad_fun2(pt_input_4d)
    ms_input_4d = ms_pytorch.ones(1, 2, 3, 4)
    ms_pad_fun2 = ms_pytorch.nn.ReflectionPad2d(padding)
    ms_pad_out2 = ms_pad_fun2(ms_input_4d)
    assert (pt_pad_out2.shape == ms_pad_out2.shape)
    assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_zero_pad_2d():
    padding = 2
    pt_input_3d = torch.ones(1, 3, 3)
    pt_pad_fun1 = torch.nn.ZeroPad2d(padding)
    pt_pad_out1 = pt_pad_fun1(pt_input_3d)
    ms_input_3d = ms_pytorch.ones(1, 3, 3)
    ms_pad_fun1 = ms_pytorch.nn.ZeroPad2d(padding)
    ms_pad_out1 = ms_pad_fun1(ms_input_3d)
    assert (pt_pad_out1.shape == ms_pad_out1.shape)
    assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

    padding = (1, 1, 2, 0)
    pt_input_4d = torch.ones(1, 1, 3, 3)
    pt_pad_fun2 = torch.nn.ZeroPad2d(padding)
    pt_pad_out2 = pt_pad_fun2(pt_input_4d)
    ms_input_4d = ms_pytorch.ones(1, 1, 3, 3)
    ms_pad_fun2 = ms_pytorch.nn.ZeroPad2d(padding)
    ms_pad_out2 = ms_pad_fun2(ms_input_4d)
    assert (pt_pad_out2.shape == ms_pad_out2.shape)
    assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())

if __name__ == '__main__':
    test_constant_pad_1d()
    test_constant_pad_2d()
    test_constant_pad_3d()
    test_reflection_pad_1d()
    test_reflection_pad_2d()
    test_zero_pad_2d()
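
The test module sets PyNative mode at import time and compares every MSAdapter output against the corresponding `torch.nn` layer, so both PyTorch and MindSpore must be installed. Run from the repository root, either of the following should work (the `__main__` block calls each test directly; the pytest invocation assumes pytest is available):

python testing/ut/pytorch/nn/test_padding.py
pytest testing/ut/pytorch/nn/test_padding.py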

