#33 master

Merged
jiayu_neu merged 4 commits from OpenI/MSAdapter:master into master 1 year ago
  1.  +1   / -1   ms_adapter/pytorch/functional.py
  2.  +5   / -3   ms_adapter/pytorch/tensor.py
  3.  +2   / -2   ms_adapter/torchvision/models/__init__.py
  4.  +5   / -0   ms_adapter/torchvision/models/quantization/__init__.py
  5.  +9   / -0   ms_adapter/torchvision/models/quantization/googlenet.py
  6.  +11  / -0   ms_adapter/torchvision/models/quantization/inception.py
  7.  +4   / -0   ms_adapter/torchvision/models/quantization/mobilenet.py
  8.  +8   / -0   ms_adapter/torchvision/models/quantization/mobilenetv2.py
  9.  +9   / -0   ms_adapter/torchvision/models/quantization/mobilenetv3.py
  10. +16  / -0   ms_adapter/torchvision/models/quantization/resnet.py
  11. +21  / -0   ms_adapter/torchvision/models/quantization/shufflenetv2.py
  12. +3   / -0   ms_adapter/torchvision/models/utils.py
  13. +1   / -0   ms_adapter/torchvision/models/video/__init__.py
  14. +343 / -0   ms_adapter/torchvision/models/video/resnet.py
  15. +3   / -2   testing/ut/pytorch/tensor/test_tensor.py
  16. +46  / -1   testing/ut/torchvision/models/test_models.py

ms_adapter/pytorch/functional.py (+1, -1)

@@ -1309,7 +1309,7 @@ def float_power(input, exponent, *, out=None):

 def floor_divide(input, other, *, out=None):
     # ms.ops.floor_divide doesn't round the quotient towards 0
-    # use ms.ops.div with rounding_mode trunc instead
+    # same behavior as torch version lower than 1.13
     input = cast_to_ms_tensor(input)
     other = cast_to_ms_tensor(other)
     output = ms.ops.div(input, other, rounding_mode='trunc')
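To make the rounding distinction concrete, here is a small illustrative check (ours, not part of the PR; it assumes a working MindSpore install): 'trunc' rounds quotients towards zero, 'floor' rounds down, and the two differ only for negative quotients. 'trunc' matches torch.floor_divide in torch versions below 1.13.

    # Illustrative only: compare 'trunc' and 'floor' rounding in ms.ops.div.
    import mindspore as ms

    a = ms.Tensor([-7.0, 7.0])
    b = ms.Tensor([2.0, 2.0])

    print(ms.ops.div(a, b, rounding_mode='trunc'))  # [-3.  3.]  towards zero
    print(ms.ops.div(a, b, rounding_mode='floor'))  # [-4.  3.]  rounds down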


ms_adapter/pytorch/tensor.py (+5, -3)

@@ -3260,6 +3260,8 @@ class Tensor(ms.Tensor):
         return cast_to_adapter_tensor(output)

     def floor_divide(self, value):
+        # ms.ops.floor_divide doesn't round the quotient towards 0
+        # same behavior as torch version lower than 1.13
         input = cast_to_ms_tensor(self)
         value = cast_to_ms_tensor(value)
         output = ms.ops.div(input, value, rounding_mode='trunc')

@@ -3371,10 +3373,10 @@ def cast_to_adapter_tensor(outputs):
         if isinstance(outputs, (ms.Tensor, Tensor_)):
             outputs = inner.convert_to_adapter_tensor(outputs)
         elif isinstance(outputs, (tuple, list)):
-            outputs = list(outputs)
+            outputs_list = list(outputs)
             for id, value in enumerate(outputs):
-                outputs[id] = _cast(value)
-            outputs = tuple(outputs)
+                outputs_list[id] = _cast(value)
+            outputs = tuple(outputs_list)
         return outputs

     outputs = _cast(outputs)
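As a standalone sketch of the pattern introduced here (the names below are ours, not the adapter's API): converted values are written into a fresh list while the original sequence is enumerated, and the tuple is rebuilt from that copy, so the object being iterated is never reassigned or mutated mid-loop.

    # Minimal sketch of the copy-then-rebuild pattern (hypothetical names).
    def convert_all(values, convert):
        values_list = list(values)           # fresh, writable copy
        for i, value in enumerate(values):   # iterate the untouched original
            values_list[i] = convert(value)
        return tuple(values_list)            # rebuild the tuple from the copy

    print(convert_all((1, 2, 3), lambda v: v * 10))  # (10, 20, 30)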


ms_adapter/torchvision/models/__init__.py (+2, -2)

@@ -13,5 +13,5 @@ from .mnasnet import *
 from .shufflenetv2 import *
 from . import segmentation
 from . import detection
-# from . import video
-# from . import quantization
+from . import video
+from . import quantization

ms_adapter/torchvision/models/quantization/__init__.py (+5, -0)

@@ -0,0 +1,5 @@
from .mobilenet import *
from .resnet import *
from .googlenet import *
from .inception import *
from .shufflenetv2 import *

ms_adapter/torchvision/models/quantization/googlenet.py (+9, -0)

@@ -0,0 +1,9 @@
__all__ = ['QuantizableGoogLeNet', 'googlenet']

def googlenet():
    raise NotImplementedError("Quantization models are not currently supported")

class QuantizableGoogLeNet(object):
    def __init__(self):
        raise NotImplementedError("Quantization models are not currently supported")
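These stubs keep `from . import quantization` importable while failing loudly on use. A hypothetical usage sketch (assuming the MSAdapter package is installed):

    # Illustrative only: the stub names import cleanly but raise when called.
    from ms_adapter.torchvision.models import quantization

    try:
        quantization.googlenet()
    except NotImplementedError as e:
        print(e)  # Quantization models are not currently supported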


ms_adapter/torchvision/models/quantization/inception.py (+11, -0)

@@ -0,0 +1,11 @@
__all__ = [
    "QuantizableInception3",
    "inception_v3",
]

def inception_v3():
    raise NotImplementedError("Quantization models are not currently supported")

class QuantizableInception3(object):
    def __init__(self):
        raise NotImplementedError("Quantization models are not currently supported")

ms_adapter/torchvision/models/quantization/mobilenet.py (+4, -0)

@@ -0,0 +1,4 @@
from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, __all__ as mv3_all

__all__ = mv2_all + mv3_all
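A quick illustration (ours) of what this re-export aggregation yields:

    # Illustrative only: mobilenet.py re-exports both the v2 and v3 names.
    from ms_adapter.torchvision.models.quantization import mobilenet

    print(mobilenet.__all__)
    # ['QuantizableMobileNetV2', 'mobilenet_v2',
    #  'QuantizableMobileNetV3', 'mobilenet_v3_large']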

ms_adapter/torchvision/models/quantization/mobilenetv2.py (+8, -0)

@@ -0,0 +1,8 @@
__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']

class QuantizableMobileNetV2(object):
    def __init__(self):
        raise NotImplementedError("Quantization models are not currently supported")

def mobilenet_v2():
    raise NotImplementedError("Quantization models are not currently supported")

ms_adapter/torchvision/models/quantization/mobilenetv3.py (+9, -0)

@@ -0,0 +1,9 @@
__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large']


class QuantizableMobileNetV3(object):
    def __init__(self):
        raise NotImplementedError("Quantization models are not currently supported")

def mobilenet_v3_large():
    raise NotImplementedError("Quantization models are not currently supported")

ms_adapter/torchvision/models/quantization/resnet.py (+16, -0)

@@ -0,0 +1,16 @@
__all__ = ['QuantizableResNet', 'resnet18', 'resnet50',
           'resnext101_32x8d']

class QuantizableResNet(object):
    def __init__(self):
        raise NotImplementedError("Quantization models are not currently supported")

def resnet18():
    raise NotImplementedError("Quantization models are not currently supported")

def resnet50():
    raise NotImplementedError("Quantization models are not currently supported")

def resnext101_32x8d():
    raise NotImplementedError("Quantization models are not currently supported")


ms_adapter/torchvision/models/quantization/shufflenetv2.py (+21, -0)

@@ -0,0 +1,21 @@
__all__ = [
    'QuantizableShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]


class QuantizableShuffleNetV2(object):
    def __init__(self):
        raise NotImplementedError("Quantization models are not currently supported")

def shufflenet_v2_x0_5():
    raise NotImplementedError("Quantization models are not currently supported")

def shufflenet_v2_x1_0():
    raise NotImplementedError("Quantization models are not currently supported")

def shufflenet_v2_x1_5():
    raise NotImplementedError("Quantization models are not currently supported")

def shufflenet_v2_x2_0():
    raise NotImplementedError("Quantization models are not currently supported")

ms_adapter/torchvision/models/utils.py (+3, -0)

@@ -51,6 +51,9 @@ ckpt_file_link = {
     'retinanet_resnet50_fpn_coco':'https://openi.pcl.ac.cn/OpenI/MSAdapter/modelmanage/show_model_info?name=retinanet_resnet50_fpn_co',
     'ssd300_vgg16_coco':'https://openi.pcl.ac.cn/OpenI/MSAdapter/modelmanage/show_model_info?name=ssd300_vgg16_coco',
     'vgg16_features':'https://openi.pcl.ac.cn/OpenI/MSAdapter/modelmanage/show_model_info?name=vgg16_features',
+    'r3d_18':'https://openi.pcl.ac.cn/OpenI/MSAdapter/modelmanage/show_model_info?name=r3d_18',
+    'mc3_18':'https://openi.pcl.ac.cn/OpenI/MSAdapter/modelmanage/show_model_info?name=mc3_18',
+    'r2plus1d_18':'https://openi.pcl.ac.cn/OpenI/MSAdapter/modelmanage/show_model_info?name=r2plus1d_18',
 }




ms_adapter/torchvision/models/video/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .resnet import *

ms_adapter/torchvision/models/video/resnet.py (+343, -0)

@@ -0,0 +1,343 @@
import ms_adapter.pytorch.nn as nn
from ms_adapter import unsupported_attr
from ..utils import check_ckpt_file
import ms_adapter.pytorch as torch


__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18']


model_urls = {
    'r3d_18': 'r3d_18-b3b3357e.ckpt',
    'mc3_18': 'mc3_18-a90a0ba3.ckpt',
    'r2plus1d_18': 'r2plus1d_18-91a641e6.ckpt',
}


class Conv3DSimple(nn.Conv3d):
    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super(Conv3DSimple, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False)

    @staticmethod
    def get_downsample_stride(stride):
        return stride, stride, stride


class Conv2Plus1D(nn.Sequential):

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes,
                 stride=1,
                 padding=1):
        super(Conv2Plus1D, self).__init__(
            nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                      stride=(1, stride, stride), padding=(0, padding, padding),
                      bias=False),
            nn.BatchNorm3d(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1),
                      stride=(stride, 1, 1), padding=(padding, 0, 0),
                      bias=False))

    @staticmethod
    def get_downsample_stride(stride):
        return stride, stride, stride


class Conv3DNoTemporal(nn.Conv3d):

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):

        super(Conv3DNoTemporal, self).__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False)

    @staticmethod
    def get_downsample_stride(stride):
        return 1, stride, stride


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            nn.BatchNorm3d(planes)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):

        super(Bottleneck, self).__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True)
        )

        # 1x1x1
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion)
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem
    """
    def __init__(self):
        super(BasicStem, self).__init__(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                      padding=(1, 3, 3), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different from the default one as it uses separated 3D convolution
    """
    def __init__(self):
        super(R2Plus1dStem, self).__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7),
                      stride=(1, 2, 2), padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1),
                      stride=(1, 1, 1), padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))


class VideoResNet(nn.Module):

    def __init__(self, block, conv_makers, layers,
                 stem, num_classes=400,
                 zero_init_residual=False):
        """Generic resnet video generator.

        Args:
            block (nn.Module): resnet building block
            conv_makers (list(functions)): generator function for each layer
            layers (List[int]): number of blocks per layer
            stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super(VideoResNet, self).__init__()
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        self._initialize_weights()

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(self, block, conv_builder, planes, blocks, stride=1):
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def _video_resnet(arch, pretrained=False, progress=True, **kwargs):
    unsupported_attr(progress)
    model = VideoResNet(**kwargs)

    if pretrained:
        check_ckpt_file(model_urls[arch])
        state_dict = torch.load(model_urls[arch])
        model.load_state_dict(state_dict)
    return model


def r3d_18(pretrained=False, progress=True, **kwargs):
    """Construct 18 layer Resnet3D model as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R3D-18 network
    """

    return _video_resnet('r3d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] * 4,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)


def mc3_18(pretrained=False, progress=True, **kwargs):
    """Constructor for 18 layer Mixed Convolution network as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: MC3 Network definition
    """
    return _video_resnet('mc3_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,
                         layers=[2, 2, 2, 2],
                         stem=BasicStem, **kwargs)


def r2plus1d_18(pretrained=False, progress=True, **kwargs):
    """Constructor for the 18 layer deep R(2+1)D network as in
    https://arxiv.org/abs/1711.11248

    Args:
        pretrained (bool): If True, returns a model pre-trained on Kinetics-400
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        nn.Module: R(2+1)D-18 network
    """
    return _video_resnet('r2plus1d_18',
                         pretrained, progress,
                         block=BasicBlock,
                         conv_makers=[Conv2Plus1D] * 4,
                         layers=[2, 2, 2, 2],
                         stem=R2Plus1dStem, **kwargs)
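For reference, a minimal smoke test of the new constructors (ours, not part of the PR): pretrained=False skips the checkpoint download, and the expected input layout is (N, C, T, H, W).

    # Minimal sketch: build R3D-18 without weights and run one forward pass.
    import ms_adapter.pytorch as torch
    from ms_adapter.torchvision.models.video import r3d_18

    net = r3d_18(pretrained=False)
    net.eval()
    clip = torch.rand(1, 3, 16, 112, 112)  # batch, channels, frames, H, W
    out = net(clip)
    print(out.shape)  # (1, 400): logits over the 400 Kinetics classes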

testing/ut/pytorch/tensor/test_tensor.py (+3, -2)

@@ -4521,8 +4521,9 @@ def test_floor_divide():

     assert np.allclose(torch_out1.numpy(), ms_out1.numpy())
     assert torch_out1.numpy().dtype == ms_out1.numpy().dtype
-    assert np.allclose(torch_out2.numpy(), ms_out2.numpy())
-    assert torch_out2.numpy().dtype == ms_out2.numpy().dtype
+    if torch.__version__ < '1.13.0':
+        assert np.allclose(torch_out2.numpy(), ms_out2.numpy())
+        assert torch_out2.numpy().dtype == ms_out2.numpy().dtype

 def test_floor_divide_():
     a = np.array([4.0, 3.0]).astype(np.float64)


testing/ut/torchvision/models/test_models.py (+46, -1)

@@ -384,6 +384,48 @@ def test_ssd300_vgg16():
     out_m = net_m([mtorch.rand(3, 300, 400), mtorch.rand(3, 500, 400)])
     print(out_m)

+def test_r3d_18():
+    from ms_adapter.torchvision.models.video import r3d_18 as r3d_18_ms
+    from torchvision.models.video import r3d_18 as r3d_18_t
+    import torch
+    import ms_adapter.pytorch as mtorch
+    net_t = r3d_18_t(pretrained=True)
+    net_t.eval()
+    out_t = net_t(torch.rand(64, 3, 32, 32, 32))
+    print(out_t)
+    net_m = r3d_18_ms(pretrained=True)
+    net_m.eval()
+    out_m = net_m(mtorch.rand(64, 3, 32, 32, 32))
+    print(out_m)
+
+def test_r2plus1d_18():
+    from ms_adapter.torchvision.models.video import r2plus1d_18 as r2plus1d_18_ms
+    from torchvision.models.video import r2plus1d_18 as r2plus1d_18_t
+    import torch
+    import ms_adapter.pytorch as mtorch
+    net_t = r2plus1d_18_t(pretrained=True)
+    net_t.eval()
+    out_t = net_t(torch.rand(64, 3, 32, 32, 32))
+    print(out_t)
+    net_m = r2plus1d_18_ms(pretrained=True)
+    net_m.eval()
+    out_m = net_m(mtorch.rand(64, 3, 32, 32, 32))
+    print(out_m)
+
+def test_mc3_18():
+    from ms_adapter.torchvision.models.video import mc3_18 as mc3_18_ms
+    from torchvision.models.video import mc3_18 as mc3_18_t
+    import torch
+    import ms_adapter.pytorch as mtorch
+    net_t = mc3_18_t(pretrained=True)
+    net_t.eval()
+    out_t = net_t(torch.rand(64, 3, 32, 32, 32))
+    print(out_t)
+    net_m = mc3_18_ms(pretrained=True)
+    net_m.eval()
+    out_m = net_m(mtorch.rand(64, 3, 32, 32, 32))
+    print(out_m)
+

 if __name__ == '__main__':
     test_alexnet()

@@ -434,4 +476,7 @@ if __name__ == '__main__':
     test_retinanet_resnet50_fpn()
     test_ssdlite320_mobilenet_v3_large()
     test_ssd300_vgg16()
-    ## End Maybe it needs to be tested on a GPU
+    # End Maybe it needs to be tested on a GPU
+    test_r3d_18()
+    test_r2plus1d_18()
+    test_mc3_18()
