#48 add conv2d test

Merged
Manson merged 1 commits from Manson/MSAdapter:conv_test into master 1 year ago
  1. +91
    -48
      testing/layers/test_conv.py

+ 91
- 48
testing/layers/test_conv.py View File

@@ -11,51 +11,94 @@ import mindspore.nn as nn
import torch
context.set_context(mode=ms.PYNATIVE_MODE)

class ConvModel(torch.nn.Module):
def __init__(self):
super(ConvModel, self).__init__()
super(ConvModel, self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, padding=2, bias=False)
self.conv3 = torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1, padding=6)
self.conv4 = torch.nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding='same')

def forward(self, inputs):
x = self.conv1(inputs)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return x


net = ConvModel()
inputs = torch.tensor(np.ones(shape=(1, 3, 32, 32)), dtype=torch.float32)
output = net(inputs)
print(output.shape)

class ConvModelms(Module):
def __init__(self):
super(ConvModelms, self).__init__()
self.conv1 = Conv2d(in_channels=3, out_channels=16, kernel_size=3)
self.conv2 = Conv2d(in_channels=16, out_channels=32, kernel_size=1, padding=2, bias=False)
self.conv3 = Conv2d(in_channels=32, out_channels=16, kernel_size=1, padding=6)
self.conv4 = Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding='same')

def forward(self, inputs):
x = self.conv1(inputs)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return x

model = ConvModelms()
model.train()

for n, v in model.named_parameters():
print(n, v.shape)



inputs = tensor(np.ones(shape=(1, 3, 32, 32)), ms.float32)
output = model(inputs)
print(output.shape)

def test_torch_ms_module_compare():
"""Test torch and ms cell output shape."""
class ConvModel(torch.nn.Module):
def __init__(self):
super(ConvModel, self).__init__()
super(ConvModel, self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, padding=2, bias=False)
self.conv3 = torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=1, padding=6)
self.conv4 = torch.nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding='same')

def forward(self, inputs):
x = self.conv1(inputs)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return x

class ConvModelms(Module):
def __init__(self):
super(ConvModelms, self).__init__()
self.conv1 = Conv2d(in_channels=3, out_channels=16, kernel_size=3)
self.conv2 = Conv2d(in_channels=16, out_channels=32, kernel_size=1, padding=2, bias=False)
self.conv3 = Conv2d(in_channels=32, out_channels=16, kernel_size=1, padding=6)
self.conv4 = Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding='same')

def forward(self, inputs):
x = self.conv1(inputs)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return x

py_net = ConvModel()
py_input = torch.tensor(np.ones(shape=(1, 3, 32, 32)), dtype=torch.float32)
py_output = py_net(py_input)

ms_net = ConvModelms()
ms_net.train()
ms_input = tensor(np.ones(shape=(1, 3, 32, 32)), ms.float32)
ms_output = ms_net(ms_input)
assert (py_output.shape == ms_output.shape)


def test_torch_ms_compare():
# padding = 'same'
py_input1 = torch.ones(1, 1, 9, 20)
ms_input1 = tensor(np.ones(shape=(1, 1, 9, 20)), ms.float32)
py_net1 = torch.nn.Conv2d(1, 1, kernel_size=(5, 10), padding='same')
ms_net1 = Conv2d(1, 1, kernel_size=(5, 10), padding='same')
py_output1 = py_net1(py_input1)
ms_output1 = ms_net1(ms_input1)
# ouput shape [1, 1, 9, 20]
assert(py_output1.shape == ms_output1.shape)

# padding = 'same', dilation = (1, 2)
py_input2 = torch.ones(1, 3, 16, 50)
ms_input2 = tensor(np.ones(shape=(1, 3, 16, 50)), ms.float32)
py_net2 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 5),
padding='same', dilation=(1, 2))
ms_net2 = Conv2d(3, 64, kernel_size=(3, 5), padding='same', dilation=(1, 2))
py_output2 = py_net2(py_input2)
ms_output2 = ms_net2(ms_input2)
# ouput shape [1, 64, 16, 50]
assert(py_output2.shape == ms_output2.shape)

# padding = 'valid', dilation = (1, 2)
py_input3 = torch.ones(1, 3, 16, 50)
ms_input3 = tensor(np.ones(shape=(1, 3, 16, 50)), ms.float32)
py_net3 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 5),
padding='valid', dilation=(1, 2))
ms_net3 = Conv2d(3, 64, kernel_size=(3, 5), padding='valid', dilation=(1, 2))
py_output3 = py_net3(py_input3)
ms_output3 = ms_net3(ms_input3)
# ouput shape [1, 64, 14, 42]
assert(py_output3.shape == ms_output3.shape)

py_input4 = torch.ones(1, 3, 16, 50)
ms_input4 = tensor(np.ones(shape=(1, 3, 16, 50)), ms.float32)
py_net4 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 5),
padding='valid')
ms_net4 = Conv2d(3, 64, kernel_size=(3, 5), padding='valid')
py_output4 = py_net4(py_input4)
ms_output4 = ms_net4(ms_input4)
print(py_output4.shape)
# ouput shape [1, 64, 14, 46]
assert(py_output4.shape == ms_output4.shape)

test_torch_ms_module_compare()
test_torch_ms_compare()

Loading…
Cancel
Save