#292 add some funcs

Merged
zoulq merged 92 commits from tanh01 into master 1 year ago
  1. ms_adapter/pytorch/functional.py (+10 / -7)
  2. ms_adapter/pytorch/nn/functional.py (+211 / -1)
  3. ms_adapter/pytorch/nn/modules/__init__.py (+19 / -1)
  4. ms_adapter/pytorch/nn/modules/activation.py (+26 / -1)
  5. ms_adapter/pytorch/nn/modules/adaptive.py (+193 / -0)
  6. ms_adapter/pytorch/nn/modules/channelshuffle.py (+23 / -0)
  7. ms_adapter/pytorch/nn/modules/fold.py (+43 / -0)
  8. ms_adapter/pytorch/nn/modules/loss.py (+110 / -1)
  9. ms_adapter/pytorch/nn/modules/padding.py (+9 / -5)
  10. ms_adapter/pytorch/nn/modules/transformer.py (+0 / -0)
  11. testing/ut/pytorch/functional/test_math.py (+21 / -6)
  12. testing/ut/pytorch/nn/functional/test_conv.py (+111 / -0)
  13. testing/ut/pytorch/nn/functional/test_linear.py (+25 / -1)
  14. testing/ut/pytorch/nn/functional/test_loss.py (+174 / -0)
  15. testing/ut/pytorch/nn/test_activation.py (+35 / -0)
  16. testing/ut/pytorch/nn/test_adaptive.py (+63 / -0)
  17. testing/ut/pytorch/nn/test_channelshuffle.py (+38 / -0)
  18. testing/ut/pytorch/nn/test_fold.py (+54 / -0)
  19. testing/ut/pytorch/nn/test_loss.py (+180 / -4)
  20. testing/ut/pytorch/nn/test_padding.py (+23 / -0)
  21. testing/ut/pytorch/nn/test_transformer.py (+0 / -0)

+ 10
- 7
ms_adapter/pytorch/functional.py View File

@@ -1349,9 +1349,6 @@ def devide(input, other, *, rounding_mode=None, out=None):
_out_limit_pynative(out, "devide")
return div(input, other, rounding_mode=rounding_mode, out=out)

#Todo: not found class Digamma
#def digamma(input, *, out=None):


def erf(input, *, out=None):
input = cast_to_ms_tensor(input)
@@ -1486,10 +1483,6 @@ def lerp(input, end, weight, *, out=None):
return _out_inplace_assign(out, output, "lerp")


#Todo
#def lgamma(input, *, out=None):


def logaddexp(input, other, *, out=None):
input = cast_to_ms_tensor(input)
other = cast_to_ms_tensor(other)
@@ -2351,3 +2344,13 @@ def isreal(input, *, out=None):
input = cast_to_ms_tensor(input)
output = ms.ops.isreal(input)
return _out_inplace_assign(out, output, "isreal")

def lgamma(input, *, out=None):
input = cast_to_ms_tensor(input)
output = ms.ops.lgamma(input)
return _out_inplace_assign(out, output, "lgamma")

def digamma(input, *, out=None):
input = cast_to_ms_tensor(input)
output = ms.ops.digamma(input)
return _out_inplace_assign(out, output, "digamma")
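For reference, digamma(x) is the derivative of lgamma(x). A stdlib-only numerical sketch of that relationship (illustrative only, not exercising the adapter itself):

import math

# Approximate digamma as a central difference of lgamma and compare with a
# known value: digamma(1) = -0.5772156649... (negative Euler-Mascheroni constant).
def approx_digamma(x, h=1e-6):
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2 * h)

print(approx_digamma(1.0))   # ~ -0.577216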

+ 211
- 1
ms_adapter/pytorch/nn/functional.py View File

@@ -51,16 +51,28 @@ __all__ = [
'threshold_',
'hardshrink',

'conv1d',
'conv2d',
'conv3d',

'normalize',
'local_response_norm',

'l1_loss',
'cross_entropy',
'ctc_loss',
'gaussian_nll_loss',
'hinge_embedding_loss',
'margin_ranking_loss',
'multilabel_margin_loss',
'multilabel_soft_margin_loss',
'nll_loss',
'kl_div',
'binary_cross_entropy',
'binary_cross_entropy_with_logits',
'upsample_nearest',
'poisson_nll_loss',
'triplet_margin_with_distance_loss',

'pairwise_distance',
'cosine_similarity',
@@ -82,6 +94,9 @@ __all__ = [

'embedding',
'max_pool2d',

'fold',
'unfold'
]

@constexpr
@@ -570,6 +585,68 @@ def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-1
result = ms.ops.cross_entropy(input, target, weight, ignore_index, reduction, label_smoothing)
return cast_to_adapter_tensor(result)

def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False):
log_probs = cast_to_ms_tensor(log_probs)
targets = cast_to_ms_tensor(targets)
#TODO: length do not support tuple
if not isinstance(input_lengths, Tensor) or not isinstance(target_lengths, Tensor):
raise TypeError("'input_lengths' and 'target_lengths' only support Tensor now")
if isinstance(input_lengths, Tensor) and isinstance(target_lengths, Tensor):
input_lengths = cast_to_ms_tensor(input_lengths)
target_lengths = cast_to_ms_tensor(target_lengths)

if targets.dtype not in {ms.int32, ms.int64} \
or not (targets.dtype == input_lengths.dtype and targets.dtype == target_lengths.dtype):
targets = targets.astype(ms.int64)
input_lengths = input_lengths.astype(ms.int64)
target_lengths = target_lengths.astype(ms.int64)
result, _ = ms.ops.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)
return cast_to_adapter_tensor(result)

def gaussian_nll_loss(input, target, var, full=False, eps=1e-06, reduction='mean'):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
var = cast_to_ms_tensor(var)
rlt = ms.ops.gaussian_nll_loss(input, target, var, full, eps, reduction)
return cast_to_adapter_tensor(rlt)

def hinge_embedding_loss(input, target, margin=1.0, size_average=None, reduce=None, reduction='mean'):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)
rlt = ms.ops.hinge_embedding_loss(input, target, float(margin), reduction)
return cast_to_adapter_tensor(rlt)

def margin_ranking_loss(input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean'):
input1 = cast_to_ms_tensor(input1)
input2 = cast_to_ms_tensor(input2)
target = cast_to_ms_tensor(target)
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)
rlt = ms.ops.margin_ranking_loss(input1, input2, target, float(margin), reduction)
return cast_to_adapter_tensor(rlt)

def multilabel_margin_loss(input, target, size_average=None, reduce=None, reduction='mean'):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
if target.dtype != ms.int32:
target = target.astype(ms.int32)
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)
rlt = ms.ops.MultilabelMarginLoss(reduction)(input, target)
zoulq commented 1 year ago
Review
Why is this interface commented out? Which functionality is unsupported?
kcl commented 1 year ago
Review
CPU is not supported, and it was not tested locally.
zoulq commented 1 year ago
Review
mindspore.ops.multi_margin_loss: could this interface be used instead? I have verified that it is also supported on GPU.
kcl commented 1 year ago
Review
I misread it; that one was duplicate work and has already been deleted. It is MultiLabelMarginLoss that does not support CPU.
return cast_to_adapter_tensor(rlt)

def multilabel_soft_margin_loss(input, target, weight=None, size_average=None, reduce=None, reduction='mean'):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
if isinstance(weight, Tensor):
weight = cast_to_ms_tensor(weight)
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)
rlt = ms.ops.multilabel_soft_margin_loss(input, target, weight, reduction)
return cast_to_adapter_tensor(rlt)

def nll_loss(input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction="mean"):
"""
@@ -1022,7 +1099,7 @@ def multi_margin_loss(
input,
target,
p=1,
margin=1.0,
margin=1,
weight=None,
size_average=None,
reduce=None,
@@ -1516,6 +1593,20 @@ def linear(input, weight, bias=None):
output = cast_to_adapter_tensor(output)
return output

def bilinear(input1, input2, weight, bias=None):
input1 = cast_to_ms_tensor(input1)
input2 = cast_to_ms_tensor(input2)
weight = cast_to_ms_tensor(weight)
x = ms.ops.matmul(input1.reshape(-1, input1.shape[-1]),
weight.permute(1, 0, 2).reshape(weight.shape[1], -1))
x = ms.ops.mul(x, ms.ops.tile(input2.reshape(-1, input2.shape[-1]), (1, weight.shape[0])))
x = x.reshape(x.shape[0], weight.shape[0], -1)
x = ms.ops.reduce_sum(x, -1)
if bias is not None:
bias = cast_to_ms_tensor(bias)
x = ms.ops.bias_add(x, bias)
output = x.reshape(*input1.shape[:-1], -1)
return cast_to_adapter_tensor(output)
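The matmul/tile/reduce_sum sequence above computes the bilinear form out[..., k] = sum over i, j of input1[..., i] * weight[k, i, j] * input2[..., j] + bias[k], since no single ms.ops call is used for it here. A NumPy einsum sketch of the same semantics (illustrative only, shapes as in test_bilinear):

import numpy as np

x1 = np.random.randn(2, 2, 3)       # (..., in1_features)
x2 = np.random.randn(2, 2, 5)       # (..., in2_features)
weight = np.random.randn(4, 3, 5)   # (out_features, in1_features, in2_features)
bias = np.random.randn(4)

# out[..., k] = sum_{i,j} x1[..., i] * weight[k, i, j] * x2[..., j] + bias[k]
out = np.einsum('...i,kij,...j->...k', x1, weight, x2) + bias
print(out.shape)                    # (2, 2, 4)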

def lp_pool1d(input, norm_type, kernel_size, stride = None, ceil_mode = False):
input = cast_to_ms_tensor(input)
@@ -1935,3 +2026,122 @@ def prelu(input, weight):
weight = cast_to_ms_tensor(weight)
output = ms.ops.prelu(input, weight)
return cast_to_adapter_tensor(output)


def poisson_nll_loss(input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None,
reduction='mean'):
input_ms = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
pi = 3.141592653589793
if reduce is not None or size_average is not None:
reduction = _get_reduce_string(size_average, reduce)
if reduction not in ('none', 'mean', 'sum'):
raise ValueError(reduction + " is not valid")

if log_input:
ret = ms.ops.exp(input_ms) - target * input_ms
else:
ret = input_ms - target * ms.ops.log(input_ms + eps)
if full:
cond = ms.ops.gt(target, 1)
out = target * ms.ops.log(target) - target + 0.5 * ms.ops.log(2 * pi * target)
out = ms.ops.select(cond, out, ms.ops.zeros_like(input_ms))
ret = ret + out
if reduction == "mean":
ret = ms.ops.mean(ret)
elif reduction == "sum":
ret = ms.ops.sum(ret)
return cast_to_adapter_tensor(ret)
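poisson_nll_loss above follows the standard formula: with log_input=True the per-element loss is exp(input) - target * input, and full=True adds the Stirling term target * log(target) - target + 0.5 * log(2 * pi * target) wherever target > 1. A minimal NumPy reference of that formula (illustrative only):

import numpy as np

def poisson_nll_ref(x, y, full=False, reduction='mean'):
    # loss_i = exp(x_i) - y_i * x_i, plus the Stirling term for y_i > 1 when full=True
    loss = np.exp(x) - y * x
    if full:
        safe_y = np.where(y > 1, y, 1.0)   # avoid log of non-positive targets
        stirling = safe_y * np.log(safe_y) - safe_y + 0.5 * np.log(2 * np.pi * safe_y)
        loss = loss + np.where(y > 1, stirling, 0.0)
    return loss.mean() if reduction == 'mean' else loss.sum()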


def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None, margin=1.0,
swap=False, reduction='mean'):
distance_function = distance_function if distance_function is not None else pairwise_distance

anchor = cast_to_ms_tensor(anchor)
positive = cast_to_ms_tensor(positive)
negative = cast_to_ms_tensor(negative)
positive_dist = distance_function(anchor, positive)
negative_dist = distance_function(anchor, negative)

if swap:
swap_dist = distance_function(positive, negative)
negative_dist = ms.ops.minimum(negative_dist, swap_dist)

output = ms.ops.clamp(positive_dist - negative_dist + margin, min=0.0)

if reduction == "mean":
ret = output.mean()
elif reduction == "sum":
ret = output.sum()
else:
ret = output
return cast_to_adapter_tensor(ret)
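triplet_margin_with_distance_loss computes max(d(anchor, positive) - d(anchor, negative) + margin, 0); with swap=True the negative distance becomes the smaller of d(anchor, negative) and d(positive, negative). A NumPy sketch using a plain L2 distance (illustrative; the default distance_function above is pairwise_distance):

import numpy as np

def l2(a, b):
    return np.linalg.norm(a - b, axis=-1)

def triplet_ref(anchor, positive, negative, margin=1.0, swap=False):
    pos = l2(anchor, positive)
    neg = l2(anchor, negative)
    if swap:
        neg = np.minimum(neg, l2(positive, negative))
    return np.clip(pos - neg + margin, 0.0, None).mean()   # reduction='mean'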


def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
# Todo: not support float64, change to float32 now
input_ms = cast_to_ms_tensor(input)
weight_ms = cast_to_ms_tensor(weight)
is_float64 = False
if input_ms.dtype in (ms.float64, ms.double):
input_ms = input_ms.astype(ms.float32)
weight_ms = weight_ms.astype(ms.float32)
is_float64 = True
if isinstance(stride, int):
stride = (stride, stride, stride)
elif len(stride)==1:
stride = (stride[0], stride[0], stride[0])
pad_mode = "pad"
if isinstance(padding, int):
padding = (padding, padding, padding, padding, padding, padding)
elif isinstance(padding, tuple):
if len(padding)==1:
padding = (padding[0], padding[0], padding[0], padding[0], padding[0], padding[0])
else:
padding = (padding[0], padding[0], padding[1], padding[1], padding[2], padding[2])
else:
pad_mode = padding
padding = 0
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
elif len(dilation) == 1:
dilation = (dilation[0], dilation[0], dilation[0])

output = ms.ops.conv3d(input_ms, weight_ms, pad_mode, padding, stride, dilation, groups)
if bias is not None:
# TODO: ms.ops.biasadd also not support float64
if bias.dtype != output.dtype:
bias = bias.astype(output.dtype)
output = ms.ops.bias_add(output, bias)

if is_float64:
output = output.astype(ms.float64)

return cast_to_adapter_tensor(output)


zoulq commented 1 year ago
Review
According to the current documentation, ms.ops.unfold does not support the GPU platform; please mark this with a TODO.
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
# TODO: do not support on GPU
input_ms = cast_to_ms_tensor(input)
output = ms.ops.unfold(input_ms, kernel_size, dilation, padding, stride)
output = output.reshape(output.shape[0], output.shape[1], -1)
return cast_to_adapter_tensor(output)


def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
# TODO: do not support on Ascend
input_ms = cast_to_ms_tensor(input)
ndim = input_ms.ndim
if ndim == 2:
input_ms = input_ms.expand_dims(0)
shape = input_ms.shape
if isinstance(kernel_size, int):
zoulq commented 1 year ago
Review
ms.ops.fold does not currently support the Ascend platform; this needs to be noted.
shape_tmp = kernel_size * kernel_size
else:
shape_tmp = kernel_size[0] * kernel_size[1]
input_ms = input_ms.reshape(shape[0], -1, shape_tmp, shape[2])
output = ms.ops.fold(input_ms, ms.Tensor(output_size), kernel_size, dilation, padding, stride)
if ndim == 2:
output = output.squeeze(0)
return cast_to_adapter_tensor(output)
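These two wrappers mostly do shape bookkeeping: the ms.ops.unfold result is flattened back to PyTorch's three-dimensional (N, C * prod(kernel_size), L) layout, and fold regroups that middle dimension into (N, C, prod(kernel_size), L) before calling ms.ops.fold. A layout sketch for the unfold case used in the unit tests (illustrative, assuming dilation 1):

import numpy as np

# input (7, 8, 9, 10), kernel (2, 3), padding (1, 2), stride (2, 1)
n, c, h, w = 7, 8, 9, 10
kh, kw = 2, 3
ph, pw = 1, 2
sh, sw = 2, 1
out_h = (h + 2 * ph - kh) // sh + 1                   # 5 window positions along H
out_w = (w + 2 * pw - kw) // sw + 1                   # 12 window positions along W
unfolded = np.zeros((n, c * kh * kw, out_h * out_w))  # (7, 48, 60), PyTorch unfold layout
regrouped = unfolded.reshape(n, c, kh * kw, -1)       # (7, 8, 6, 60), layout passed to ms.ops.fold
print(unfolded.shape, regrouped.shape)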

+ 19
- 1
ms_adapter/pytorch/nn/modules/__init__.py View File

@@ -19,6 +19,9 @@ from .dropout import Dropout, Dropout1d, Dropout2d, Dropout3d, AlphaDropout, Fea
from .upsampling import *
from .normalization import *
from .pixel_shuffle import *
from .channelshuffle import *
from .fold import *
from .adaptive import AdaptiveLogSoftmaxWithLoss

__all__ = [
'Linear',
@@ -39,6 +42,8 @@ __all__ = [
'LazyConvTranspose1d',
'LazyConvTranspose2d',
'LazyConvTranspose3d',
'Fold',
'Unfold',

'BatchNorm1d',
'BatchNorm2d',
@@ -93,12 +98,14 @@ __all__ = [
'Tanh',
'Tanhshrink',
'Threshold',
'Softplus',
'Softsign',
'Softmax',
'LogSoftmax',
'Softmax2d',
'Softmin',
'GLU',
'AdaptiveLogSoftmaxWithLoss',

'MultiheadAttention',
'Hardsigmoid',
@@ -132,6 +139,14 @@ __all__ = [
'CosineEmbeddingLoss',
'MultiMarginLoss',
'TripletMarginLoss',
'PoissonNLLLoss',
'GaussianNLLLoss',
'HingeEmbeddingLoss',
'MultiLabelMarginLoss',
'MultiLabelSoftMarginLoss',
'TripletMarginWithDistanceLoss',
'MarginRankingLoss',
'CTCLoss',

'LogSigmoid',
'ELU',
@@ -140,6 +155,7 @@ __all__ = [
'ConstantPad3d',
'ReflectionPad1d',
'ReflectionPad2d',
'ReflectionPad3d',
'ZeroPad2d',
'ReplicationPad1d',
'ReplicationPad2d',
@@ -166,4 +182,6 @@ __all__ = [

'PixelShuffle',
'PixelUnshuffle',

'ChannelShuffle'
]

+ 26
- 1
ms_adapter/pytorch/nn/modules/activation.py View File

@@ -18,7 +18,8 @@ from .module import Module

__all__ = ['ReLU', 'Hardtanh', 'ReLU6', 'SiLU', 'Hardswish', 'LeakyReLU', 'Sigmoid', 'LogSigmoid', 'ELU', 'RReLU',
'SELU', 'CELU', 'GELU', 'Mish', 'Softshrink', 'Tanh', 'Tanhshrink','Threshold', 'Softmax', 'LogSoftmax',
'Softmin', 'Softsign', 'GLU', 'Hardshrink', 'MultiheadAttention', 'Hardsigmoid', 'PReLU']
'Softmin', 'Softsign', 'GLU', 'Hardshrink', 'MultiheadAttention', 'Hardsigmoid', 'PReLU', 'Softplus',
'Softmax2d']


class ReLU(Module):
@@ -565,3 +566,27 @@ class PReLU(Module):

def extra_repr(self) -> str:
return 'num_parameters={}'.format(self.num_parameters)


class Softplus(Module):
def __init__(self, beta=1, threshold=20):
super(Softplus, self).__init__()
self.beta = beta
self.threshold = threshold

def forward(self, input):
return ms_torch_nn_func.softplus(input, self.beta, self.threshold)

def extra_repr(self):
return 'beta={}, threshold={}'.format(self.beta, self.threshold)


class Softmax2d(Module):
def __init__(self):
super(Softmax2d, self).__init__()
self.softmax2d = ms.nn.Softmax2d()

def forward(self, input):
if input.dim() not in {3, 4}:
raise RuntimeError("Softmax2d requires a 3D or 4D tensor as input")
return self.softmax2d(input)

+ 193
- 0
ms_adapter/pytorch/nn/modules/adaptive.py View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import namedtuple

import mindspore as ms
from ms_adapter.pytorch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor
from ms_adapter.utils import unsupported_attr
from .container import Sequential, ModuleList
from .linear import Linear
from .module import Module
from ..functional import log_softmax

_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])

class AdaptiveLogSoftmaxWithLoss(Module):
def __init__(self, in_features, n_classes, cutoffs, div_value=4., head_bias=False, device=None, dtype=None):
super(AdaptiveLogSoftmaxWithLoss, self).__init__()
unsupported_attr(device)
cutoffs = list(cutoffs)
# #TODO: pylint
# if (cutoffs != sorted(cutoffs)) \
# or (min(cutoffs) <= 0) \
# or (max(cutoffs) > (n_classes - 1)) \
# or (len(set(cutoffs)) != len(cutoffs)) \
# or any([int(c) != c for c in cutoffs]):
#
# raise ValueError("cutoffs should be a sequence of unique, positive "
# "integers sorted in an increasing order, where "
# "each value is between 1 and n_classes-1")

self.in_features = in_features
self.n_classes = n_classes
self.cutoffs = cutoffs + [n_classes]
self.div_value = div_value
self.head_bias = head_bias
self.dtype = dtype

self.shortlist_size = self.cutoffs[0]
self.n_clusters = len(self.cutoffs) - 1
self.head_size = self.shortlist_size + self.n_clusters

self.head = Linear(self.in_features, self.head_size, bias=self.head_bias, dtype=self.dtype)
self.tail = ModuleList()

for i in range(self.n_clusters):

hsz = int(self.in_features // (self.div_value ** (i + 1)))
osz = self.cutoffs[i + 1] - self.cutoffs[i]

projection = Sequential(
Linear(self.in_features, hsz, bias=False, dtype=self.dtype),
Linear(hsz, osz, bias=False, dtype=self.dtype),
)

self.tail.append(projection)

def reset_parameters(self):
self.head.reset_parameters()
for i2h, h2o in self.tail:
i2h.reset_parameters()
h2o.reset_parameters()

def forward(self, input_, target_):
input_ = cast_to_ms_tensor(input_)
#target_ = cast_to_ms_tensor(target_)
targ_dim = target_.dim()

if targ_dim == 1:
if input_.shape[0] != target_.shape[0]:
raise RuntimeError('Input and target should have the same size '
'in the batch dimension.')
if input_.dim() != 2:
raise RuntimeError('1D target tensor expects 2D input tensors, '
'but found inputs with size', input_.shape)
elif targ_dim == 0:
if input_.dim() != 1:
raise RuntimeError('0D target tensor expects 1D input tensors, '
'but found inputs with size', input_.shape)
else:
raise RuntimeError('0D or 1D target tensor expected, '
'multi-target not supported')

is_batched = targ_dim > 0
input = input_ if is_batched else input_.unsqueeze(0)
target = target_ if is_batched else target_.unsqueeze(0)

used_rows = 0
batch_size = target.shape[0]

output = input.new_zeros(batch_size)
#gather_inds = ms.numpy.empty(batch_size, target.dtype)
gather_inds = target.new_empty(batch_size)

cutoff_values = [0] + self.cutoffs
for i in range(len(cutoff_values) - 1):

low_idx = cutoff_values[i]
high_idx = cutoff_values[i + 1]

target_mask = (target >= low_idx) & (target < high_idx)
row_indices = target_mask.nonzero().squeeze()

if row_indices.numel() == 0:
continue

if i == 0:
#gather_inds.index_copy_(0, row_indices, target[target_mask])
gather_inds = index_copy_0dim(gather_inds, row_indices, target[target_mask])

else:
relative_target = target[target_mask] - low_idx
#input_subset = input.index_select(0, row_indices)
input_subset = ms.ops.gather(input, row_indices, 0)

cluster_output = self.tail[i - 1](input_subset)
cluster_index = self.shortlist_size + i - 1

gather_inds.index_fill_(0, row_indices, cluster_index)
cluster_logprob = log_softmax(cluster_output, dim=1)
local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
#output.index_copy_(0, row_indices, local_logprob.squeeze(1))
output = index_copy_0dim(output, row_indices, local_logprob.squeeze(1))

used_rows += row_indices.numel()

if used_rows != batch_size:
raise RuntimeError("Target values should be in [0, {}], "
"but values in range [{}, {}] "
"were found. ".format(self.n_classes - 1,
target.min().item(),
target.max().item()))

head_output = self.head(input)
head_logprob = log_softmax(head_output, dim=1)
output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
loss = (-output).mean()
if not is_batched:
output = output.squeeze(0)

output = cast_to_adapter_tensor(output)
loss = cast_to_adapter_tensor(loss)
return _ASMoutput(output, loss)

def _get_full_log_prob(self, input, head_output):
input = cast_to_ms_tensor(input)
head_output = cast_to_ms_tensor(head_output)
out = input.new_empty((head_output.shape[0], self.n_classes))
head_logprob = log_softmax(head_output, dim=1)

out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
cluster_output = self.tail[i](input)
cluster_logprob = log_softmax(cluster_output, dim=1)
output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

out[:, start_idx:stop_idx] = output_logprob

return cast_to_adapter_tensor(out)

def log_prob(self, input):
input = cast_to_ms_tensor(input)
head_output = self.head(input)
out = self._get_full_log_prob(input, head_output)
return cast_to_adapter_tensor(out)


def predict(self, input):
input = cast_to_ms_tensor(input)
head_output = self.head(input)
output = ms.ops.argmax(head_output, axis=1)
not_in_shortlist = (output >= self.shortlist_size)
any_in_shortlist = (output < self.shortlist_size)

if not not_in_shortlist:
return cast_to_adapter_tensor(output)

elif not any_in_shortlist:
log_prob = self._get_full_log_prob(input, head_output)
return cast_to_adapter_tensor(ms.ops.argmax(log_prob, axis=1))

else:
log_prob = self._get_full_log_prob(input[not_in_shortlist],
head_output[not_in_shortlist])
output[not_in_shortlist] = ms.ops.argmax(log_prob, axis=1)
return cast_to_adapter_tensor(output)


def index_copy_0dim(input, index, tensor):
for i in range(len(index)):
input[index[i]] = tensor[i]
return input
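index_copy_0dim at the bottom of the file is a loop-based stand-in for an in-place index_copy_ along dimension 0. In NumPy terms (illustrative only):

import numpy as np

buf = np.zeros(6)
index = np.array([1, 3, 4])
src = np.array([10.0, 20.0, 30.0])
buf[index] = src     # same effect as index_copy_0dim(buf, index, src)
print(buf)           # [ 0. 10.  0. 20. 30.  0.]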

+ 23
- 0
ms_adapter/pytorch/nn/modules/channelshuffle.py View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import mindspore.nn as nn
from ms_adapter.pytorch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor
from .module import Module

__all__ = ['ChannelShuffle']


class ChannelShuffle(Module):
def __init__(self, groups):
super(ChannelShuffle, self).__init__()
self.groups = groups
self.channel_shuffle = nn.ChannelShuffle(self.groups)

def forward(self, input):
input = cast_to_ms_tensor(input)
out = self.channel_shuffle(input)
return cast_to_adapter_tensor(out)

def extra_repr(self):
return 'groups={}'.format(self.groups)
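Channel shuffle with groups=g reshapes the channel dimension into (g, C/g), swaps the two factors, and flattens back, which is the standard definition of the operation. A NumPy sketch of that permutation (illustrative only):

import numpy as np

x = np.arange(2 * 9 * 3 * 4).reshape(2, 9, 3, 4)   # (N, C, H, W), C divisible by groups
g = 3
n, c, h, w = x.shape
shuffled = x.reshape(n, g, c // g, h, w).transpose(0, 2, 1, 3, 4).reshape(n, c, h, w)
print(shuffled.shape)    # (2, 9, 3, 4)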

+ 43
- 0
ms_adapter/pytorch/nn/modules/fold.py View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from ms_adapter.pytorch.nn.functional import fold, unfold
from .module import Module

__all__ = ['Fold', 'Unfold']

class Fold(Module):
#TODO: do not support on Ascend
def __init__(self, output_size, kernel_size, dilation=1, padding=0, stride=1):
super(Fold, self).__init__()
self.output_size = output_size
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding
self.stride = stride

def forward(self, input):
return fold(input, self.output_size, self.kernel_size, self.dilation, self.padding, self.stride)

def extra_repr(self):
return 'output_size={output_size}, kernel_size={kernel_size}, ' \
'dilation={dilation}, padding={padding}, stride={stride}'.format(
**self.__dict__
)


class Unfold(Module):
# TODO: do not support on GPU
def __init__(self, kernel_size, dilation=1, padding=0, stride=1):
super(Unfold, self).__init__()
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding
self.stride = stride

def forward(self, input):
return unfold(input, self.kernel_size, self.dilation, self.padding, self.stride)

def extra_repr(self):
return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \
' stride={stride}'.format(**self.__dict__)

+ 110
- 1
ms_adapter/pytorch/nn/modules/loss.py View File

@@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-
import warnings

import mindspore as ms
import mindspore.nn as nn
from ms_adapter.utils import unsupported_attr
from ms_adapter.pytorch.tensor import Tensor
from ms_adapter.pytorch.tensor import Tensor, cast_to_ms_tensor, cast_to_adapter_tensor
import ms_adapter.pytorch.nn.functional as F
from .module import Module

@@ -13,6 +15,7 @@ __all__ = [
'L1Loss',
'MSELoss',
'CrossEntropyLoss',
'CTCLoss',
'NLLLoss',
'KLDivLoss',
'BCELoss',
@@ -22,6 +25,13 @@ __all__ = [
'CosineEmbeddingLoss',
'MultiMarginLoss',
'TripletMarginLoss',
'PoissonNLLLoss',
'GaussianNLLLoss',
'HingeEmbeddingLoss',
'MarginRankingLoss',
'MultiLabelMarginLoss',
'MultiLabelSoftMarginLoss',
'TripletMarginWithDistanceLoss',
]

class _Loss(Module):
@@ -214,3 +224,102 @@ class TripletMarginLoss(_Loss):
def forward(self, anchor, positive, negative):
return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p,
eps=self.eps, swap=self.swap, reduction=self.reduction)


class PoissonNLLLoss(_Loss):
def __init__(self, log_input=True, full=False, size_average=None, eps=1e-8, reduce=None, reduction='mean'):
super(PoissonNLLLoss, self).__init__(size_average, reduce, reduction)
self.log_input = log_input
self.full = full
self.eps = eps

def forward(self, log_input, target):
return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full,
eps=self.eps, reduction=self.reduction)


class GaussianNLLLoss(_Loss):
def __init__(self, *, full=False, eps=1e-6, reduction='mean'):
super(GaussianNLLLoss, self).__init__(None, None, reduction)
self.full = full
self.eps = eps
self.gaussian_nll_loss = nn.GaussianNLLLoss(full=self.full, eps=self.eps, reduction=self.reduction)

def forward(self, input, target, var):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
var = cast_to_ms_tensor(var)
out = self.gaussian_nll_loss(input, target, var)
return cast_to_adapter_tensor(out)


class MarginRankingLoss(_Loss):
def __init__(self, margin=0., size_average=None, reduce=None, reduction='mean'):
super(MarginRankingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin
self.margin_ranking_loss = nn.MarginRankingLoss(self.margin, self.reduction)

def forward(self, input1, input2, target):
return self.margin_ranking_loss(input1, input2, target)


class HingeEmbeddingLoss(_Loss):
def __init__(self, margin=1.0, size_average=None, reduce=None, reduction='mean'):
super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin
self.hinge_embedding_loss = nn.HingeEmbeddingLoss(margin=self.margin, reduction=self.reduction)

def forward(self, input, target):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
out = self.hinge_embedding_loss(input, target)
return cast_to_adapter_tensor(out)


class MultiLabelMarginLoss(_Loss):
def __init__(self, size_average=None, reduce=None, reduction='mean'):
super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction)
self.multilabel_margin_loss = ms.ops.MultilabelMarginLoss(reduction=self.reduction)

def forward(self, input, target):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
if target.dtype == ms.int64:
target = target.astype(ms.int32)
out = self.multilabel_margin_loss(input, target)
return cast_to_adapter_tensor(out)


class MultiLabelSoftMarginLoss(_WeightedLoss):
def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'):
super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction)
self.multilabel_soft_margin_loss = nn.MultiLabelSoftMarginLoss(weight=self.weight, reduction=self.reduction)

def forward(self, input, target):
input = cast_to_ms_tensor(input)
target = cast_to_ms_tensor(target)
out = self.multilabel_soft_margin_loss(input, target)
return cast_to_adapter_tensor(out)

class TripletMarginWithDistanceLoss(_Loss):
def __init__(self, *, distance_function=None,
margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction)
self.distance_function = distance_function
self.margin = margin
self.swap = swap

def forward(self, anchor, positive, negative):
return F.triplet_margin_with_distance_loss(anchor, positive, negative,
distance_function=self.distance_function,
margin=self.margin, swap=self.swap, reduction=self.reduction)

class CTCLoss(_Loss):
def __init__(self, blank=0, reduction='mean', zero_infinity=False):
super(CTCLoss, self).__init__(reduction=reduction)
self.blank = blank
self.zero_infinity = zero_infinity

def forward(self, log_probs, targets, input_lengths, target_lengths):
return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction,
self.zero_infinity)

+ 9
- 5
ms_adapter/pytorch/nn/modules/padding.py View File

@@ -4,7 +4,7 @@ from mindspore import nn
from ms_adapter.pytorch.tensor import cast_to_ms_tensor, cast_to_adapter_tensor
from .module import Module

__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d',
__all__ = ['ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d', 'ReflectionPad1d', 'ReflectionPad2d', 'ReflectionPad3d',
'ZeroPad2d', 'ReplicationPad1d', 'ReplicationPad2d', 'ReplicationPad3d']


@@ -226,10 +226,14 @@ class ReflectionPad3d(_ReflectionPadNd):

"""

# def __init__(self, padding):
# super(ReflectionPad3d, self).__init__(padding)
# TODO: mindspore don't has nn.ReflectionPad3d API now.
# self.pad_fun = nn.ReflectionPad3d(self.padding)
def __init__(self, padding):
super(ReflectionPad3d, self).__init__(padding)
self.pad_fun = nn.ReflectionPad3d(self.padding)

def forward(self, input):
input = cast_to_ms_tensor(input)
output = self.pad_fun(input)
return cast_to_adapter_tensor(output)


class ZeroPad2d(ConstantPad2d):


+ 0
- 0
ms_adapter/pytorch/nn/modules/transformer.py View File


+ 21
- 6
testing/ut/pytorch/functional/test_math.py View File

@@ -679,11 +679,25 @@ def test_digamma():
torch_out = torch.digamma(torch_tensor)

ms_tensor = ms_torch.tensor(np_array)
#ms_out = ms_torch.digamma(ms_tensor)
ms_out = ms_torch.digamma(ms_tensor)

#assert np.allclose(ms_out.asnumpy(), torch_out.numpy())
#assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
#assert ms_out.asnumpy().shape == torch_out.numpy().shape
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_lgamma():
np_array = np.random.rand(1, 1, 1, 1, 2, 3, 2).astype(np.float64) - 0.5
np_array = np_array * 20
torch_tensor = torch.tensor(np_array)
torch_out = torch.lgamma(torch_tensor)

ms_tensor = ms_torch.tensor(np_array)
ms_out = ms_torch.lgamma(ms_tensor)

assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_erf():
@@ -1595,7 +1609,7 @@ if __name__ == '__main__':
test_sign2()
test_pow1()
test_pow2()
test_pow3()
#test_pow3()
test_pow4()
test_exp1()
test_exp2()
@@ -1628,7 +1642,8 @@ if __name__ == '__main__':
test_copysign()
test_cosh()
test_deg2rad()
# test_digamma()
test_digamma()
test_lgamma()
test_erf()
test_erfc()
test_erfinv()


+ 111
- 0
testing/ut/pytorch/nn/functional/test_conv.py View File

@@ -145,6 +145,112 @@ def test_conv2d3():
assert ms_out.asnumpy().shape == torch_out.numpy().shape




def test_conv3d1():
np_input = np.random.randn(1, 4, 5, 5, 5).astype(np.float32)
np_weight = np.random.randn(8, 4, 3, 3, 3).astype(np.float32)

torch_tensor = torch.tensor(np_input)
torch_weight = torch.tensor(np_weight)
torch_out = torch.nn.functional.conv3d(torch_tensor, torch_weight)

ms_tensor = ms_torch.tensor(np_input)
ms_weight = ms_torch.tensor(np_weight)
ms_out = ms_torch.nn.functional.conv3d(ms_tensor, ms_weight)

if ms.get_context('device_target') == 'Ascend':
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-2, atol=1e-2)
else:
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_conv3d2():
out_channel = 1
np_input = np.random.randn(3, 8, 9, 9, 9).astype(np.float32)
np_weight = np.random.randn(out_channel, 8, 3, 5, 4).astype(np.float32)
np_bias = np.ones(out_channel).astype(np.float32)*0.5

torch_tensor = torch.tensor(np_input)
torch_weight = torch.tensor(np_weight)
torch_bias = torch.tensor(np_bias)
torch_out = torch.nn.functional.conv3d(torch_tensor, torch_weight, torch_bias, 7, 3, (2,), 1)

ms_tensor = ms_torch.tensor(np_input)
ms_weight = ms_torch.tensor(np_weight)
ms_bias = ms_torch.tensor(np_bias)
ms_out = ms_torch.nn.functional.conv3d(ms_tensor, ms_weight, ms_bias, 7, 3, (2,), 1)

if ms.get_context('device_target') == 'Ascend':
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-2, atol=1e-2)
else:
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_conv3d3():
out_channel = 8
np_input = np.random.randn(4, 16, 18, 18, 18).astype(np.float64)
np_weight = np.random.randn(out_channel, 4, 5, 5, 5).astype(np.float64)
np_bias = np.ones(out_channel).astype(np.float32)*0.5

torch_tensor = torch.tensor(np_input)
torch_weight = torch.tensor(np_weight)
torch_bias = torch.tensor(np_bias)
torch_out = torch.nn.functional.conv3d(torch_tensor, torch_weight, torch_bias, (2, 2, 1), (4, 6, 2), (1, 3, 3), 4)

ms_tensor = ms_torch.tensor(np_input)
ms_weight = ms_torch.tensor(np_weight)
ms_bias = ms_torch.tensor(np_bias)
ms_out = ms_torch.nn.functional.conv3d(ms_tensor, ms_weight, ms_bias, (2, 2, 1), (4, 6, 2), (1, 3, 3), 4)

if ms.get_context('device_target') == 'Ascend':
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-2, atol=1e-2)
else:
assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_unfold():
np_input = np.random.randn(7, 8, 9, 10)

torch_tensor = torch.tensor(np_input)
torch_out = torch.nn.functional.unfold(torch_tensor, (2, 3), 1, (1, 2), (2, 1))

ms_tensor = ms_torch.tensor(np_input)
ms_out = ms_torch.nn.functional.unfold(ms_tensor, (2, 3), 1, (1, 2), (2, 1))

assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_fold():
np_input1 = np.random.randn(7, 8, 24)
np_input2 = np.random.randn(18, 6)

torch_tensor1 = torch.tensor(np_input1)
torch_tensor2 = torch.tensor(np_input2)
torch_out1 = torch.nn.functional.fold(torch_tensor1, (4, 5), (2, 2), 1, (1, 2), (2, 1))
torch_out2 = torch.nn.functional.fold(torch_tensor2, (7, 4), 3, 2, 2, 3)

ms_tensor1 = ms_torch.tensor(np_input1)
ms_tensor2 = ms_torch.tensor(np_input2)
ms_out1 = ms_torch.nn.functional.fold(ms_tensor1, (4, 5), (2, 2), 1, (1, 2), (2, 1))
ms_out2 = ms_torch.nn.functional.fold(ms_tensor2, (7, 4), 3, 2, 2, 3)

assert np.allclose(ms_out1.asnumpy(), torch_out1.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out1.asnumpy().dtype == torch_out1.numpy().dtype
assert ms_out1.asnumpy().shape == torch_out1.numpy().shape
assert np.allclose(ms_out2.asnumpy(), torch_out2.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out2.asnumpy().dtype == torch_out2.numpy().dtype
assert ms_out2.asnumpy().shape == torch_out2.numpy().shape


if __name__ == '__main__':
test_conv1d1()
test_conv1d2()
@@ -152,3 +258,8 @@ if __name__ == '__main__':
test_conv2d1()
test_conv2d2()
test_conv2d3()
test_conv3d1()
test_conv3d2()
test_conv3d3()
test_unfold()
test_fold()

+ 25
- 1
testing/ut/pytorch/nn/functional/test_linear.py View File

@@ -59,7 +59,31 @@ def test_linear3():
assert ms_out.numpy().dtype == torch_out.numpy().dtype


def test_bilinear():
data1 = np.random.randn(2, 2, 3)
data2 = np.random.randn(2, 2, 5)
weight = np.random.randn(4, 3, 5)
bias = np.random.randn(4)

torch_input1 = torch.tensor(data1)
torch_input2 = torch.tensor(data2)
torch_weight = torch.tensor(weight)
torch_bias = torch.tensor(bias)
torch_out = torch.nn.functional.bilinear(torch_input1, torch_input2, torch_weight, torch_bias)

ms_input1 = ms_torch.tensor(data1)
ms_input2 = ms_torch.tensor(data2)
ms_weight = ms_torch.tensor(weight)
ms_bias = ms_torch.tensor(bias)
ms_out = ms_torch.nn.functional.bilinear(ms_input1, ms_input2, ms_weight, ms_bias)

assert np.allclose(ms_out.asnumpy(), torch_out.numpy())
assert ms_out.shape == torch_out.shape
assert ms_out.numpy().dtype == torch_out.numpy().dtype


if __name__ == '__main__':
test_linear1()
test_linear2()
test_linear3()
test_linear3()
test_bilinear()

+ 174
- 0
testing/ut/pytorch/nn/functional/test_loss.py View File

@@ -0,0 +1,174 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import torch
import mindspore as ms
from mindspore import context

import ms_adapter.pytorch as ms_torch

context.set_context(mode=ms.PYNATIVE_MODE)

def test_ctc_loss():
np_log_probs = np.random.randn(24, 2, 10)
np_targets = np.random.rand(2, 10)*10
np_input_length = np.array([8, 10])
np_target_length = np.array([5, 6])

torch_log_probs = torch.tensor(np_log_probs)
torch_targets = torch.tensor(np_targets)
torch_input_length = torch.tensor(np_input_length)
torch_target_length = torch.tensor(np_target_length)
result_torch = torch.nn.functional.ctc_loss(torch_log_probs, torch_targets, torch_input_length, torch_target_length)

ms_log_probs = ms_torch.tensor(np_log_probs)
ms_targets = ms_torch.tensor(np_targets)
ms_input_length = ms_torch.tensor(np_input_length)
ms_target_length = ms_torch.tensor(np_target_length)
result_ms = ms_torch.nn.functional.ctc_loss(ms_log_probs, ms_targets, ms_input_length, ms_target_length)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape

def test_gaussian_nll_loss():
np_input = np.random.randn(5, 6, 7, 8)
np_targets = np.random.rand(5, 6, 7, 8)
np_var = np.ones([5, 6, 7, 8])

torch_log_probs = torch.tensor(np_input)
torch_targets = torch.tensor(np_targets)
torch_var = torch.tensor(np_var)
result_torch = torch.nn.functional.gaussian_nll_loss(torch_log_probs, torch_targets, torch_var)

ms_log_probs = ms_torch.tensor(np_input)
ms_targets = ms_torch.tensor(np_targets)
ms_var = ms_torch.tensor(np_var)
result_ms = ms_torch.nn.functional.gaussian_nll_loss(ms_log_probs, ms_targets, ms_var)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape

def test_hinge_embedding_loss():
np_input = np.random.randn(5, 6, 7, 8)
np_targets = np.sign(np.random.rand(5, 6, 7, 8))

torch_log_probs = torch.tensor(np_input)
torch_targets = torch.tensor(np_targets)
result_torch = torch.nn.functional.hinge_embedding_loss(torch_log_probs, torch_targets)

ms_log_probs = ms_torch.tensor(np_input)
ms_targets = ms_torch.tensor(np_targets)
result_ms = ms_torch.nn.functional.hinge_embedding_loss(ms_log_probs, ms_targets)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape


def test_margin_ranking_loss():
np_input1 = np.random.randn(5, 6, 7, 8)
np_input2 = np.random.randn(5, 6, 7, 8)
np_targets = np.sign(np.random.rand(5, 6, 7, 8))

torch_input1 = torch.tensor(np_input1)
torch_input2 = torch.tensor(np_input2)
torch_targets = torch.tensor(np_targets)
result_torch = torch.nn.functional.margin_ranking_loss(torch_input1, torch_input2, torch_targets)

ms_input1 = ms_torch.tensor(np_input1)
ms_input2 = ms_torch.tensor(np_input2)
ms_targets = ms_torch.tensor(np_targets)
result_ms = ms_torch.nn.functional.margin_ranking_loss(ms_input1, ms_input2, ms_targets)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape

# def test_multilabel_margin_loss():
# np_input = np.random.randn(5, 10)
# np_targets = np.random.rand(5, 10)*10-1
#
# torch_input = torch.tensor(np_input)
# torch_targets = torch.tensor(np_targets, dtype=torch.int64)
# result_torch = torch.nn.functional.multilabel_margin_loss(torch_input, torch_targets)
#
# ms_input = ms_torch.tensor(np_input)
# ms_targets = ms_torch.tensor(np_targets, dtype=ms.int64)
# result_ms = ms_torch.nn.functional.multilabel_margin_loss(ms_input, ms_targets)
#
# assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
# assert result_ms.shape == result_torch.shape


def test_multilabel_soft_margin_loss():
np_input = np.array([[0.3, 0.6, 0.6], [0.9, 0.4, 0.2]])
np_targets = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])

torch_input = torch.tensor(np_input)
torch_targets = torch.tensor(np_targets)
result_torch = torch.nn.functional.multilabel_soft_margin_loss(torch_input, torch_targets)

ms_input = ms_torch.tensor(np_input)
ms_targets = ms_torch.tensor(np_targets)
result_ms = ms_torch.nn.functional.multilabel_soft_margin_loss(ms_input, ms_targets)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape

def test_multi_margin_loss():
np_input = np.array([[0.1, 0.2, 0.4, 0.8]])
np_targets = np.array([3])

torch_input = torch.tensor(np_input)
torch_targets = torch.tensor(np_targets, dtype=torch.int64)
result_torch = torch.nn.functional.multi_margin_loss(torch_input, torch_targets)

ms_input = ms_torch.tensor(np_input)
ms_targets = ms_torch.tensor(np_targets, dtype=ms.int64)
result_ms = ms_torch.nn.functional.multi_margin_loss(ms_input, ms_targets)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape

def test_huber_loss():
np_input = np.array([1.0, 2, 10, 2])
np_targets = np.array([1.0, 5, 1, 20])

torch_input = torch.tensor(np_input)
torch_targets = torch.tensor(np_targets)
result_torch = torch.nn.functional.huber_loss(torch_input, torch_targets)

ms_input = ms_torch.tensor(np_input)
ms_targets = ms_torch.tensor(np_targets)
result_ms = ms_torch.nn.functional.huber_loss(ms_input, ms_targets)

assert np.allclose(result_ms.asnumpy(), result_torch.numpy())
assert result_ms.shape == result_torch.shape

def test_triplet_margin_loss():
np_anc = np.random.randn(100, 128)
np_pos = np.random.randn(100, 128)
np_neg = np.random.randn(100, 128)

torch_anc = torch.tensor(np_anc, requires_grad=True)
torch_pos = torch.tensor(np_pos, requires_grad=True)
torch_neg = torch.tensor(np_neg, requires_grad=True)
result_torch = torch.nn.functional.triplet_margin_loss(torch_anc, torch_pos, torch_neg)

ms_anc = ms_torch.tensor(np_anc, requires_grad=True)
ms_pos = ms_torch.tensor(np_pos, requires_grad=True)
ms_neg = ms_torch.tensor(np_neg, requires_grad=True)
result_ms = ms_torch.nn.functional.triplet_margin_loss(ms_anc, ms_pos, ms_neg)

assert np.allclose(result_ms.asnumpy(), result_torch.detach().numpy())
assert result_ms.shape == result_torch.shape

if __name__ == '__main__':
test_ctc_loss()
test_gaussian_nll_loss()
test_hinge_embedding_loss()
test_margin_ranking_loss()
# test_multilabel_margin_loss()
test_multilabel_soft_margin_loss()
test_multi_margin_loss()
test_huber_loss()
test_triplet_margin_loss()

+ 35
- 0
testing/ut/pytorch/nn/test_activation.py View File

@@ -573,6 +573,39 @@ def test_prelu():
assert np.allclose(torch_out.detach().numpy(), ms_torch_out.detach().numpy())


def test_softplus():
ms_net = nn.Softplus(beta=2, threshold=15)
torch_net = torch.nn.Softplus(beta=2, threshold=15)
data = np.random.randn(2, 3, 4, 5).astype(np.float32)*50

torch_input = torch.Tensor(data)
torch_output = torch_net(torch_input)

ms_input = Tensor(data)
ms_output = ms_net(ms_input)
assert np.allclose(ms_output.asnumpy(), torch_output.numpy(), atol=1e-5)


def test_softmax2d():
ms_net = nn.Softmax2d()
torch_net = torch.nn.Softmax2d()
data1 = np.random.randn(2, 3, 4, 5).astype(np.float32)*50
data2 = np.random.randn(3, 4, 5).astype(np.float32)*50

torch_input1 = torch.Tensor(data1)
torch_input2 = torch.Tensor(data2)
torch_output1 = torch_net(torch_input1)
torch_output2 = torch_net(torch_input2)

ms_input1 = Tensor(data1)
ms_input2 = Tensor(data2)
ms_output1 = ms_net(ms_input1)
ms_output2 = ms_net(ms_input2)

assert np.allclose(ms_output1.asnumpy(), torch_output1.numpy())
assert np.allclose(ms_output2.asnumpy(), torch_output2.numpy())


if __name__ == '__main__':
test_relu1()
test_relu2()
@@ -609,3 +642,5 @@ if __name__ == '__main__':
test_multi_head_attention1()
test_multi_head_attention2()
test_prelu()
test_softplus()
test_softmax2d()

+ 63
- 0
testing/ut/pytorch/nn/test_adaptive.py View File

@@ -0,0 +1,63 @@
import numpy as np
import torch

import ms_adapter.pytorch as ms_torch
from ms_adapter.pytorch.nn import AdaptiveLogSoftmaxWithLoss
from mindspore import context
import mindspore as ms
from ms_adapter.pytorch.nn import Parameter

#context.set_context(mode=ms.GRAPH_MODE)
context.set_context(mode=ms.PYNATIVE_MODE)


def test_adaptive_logsoftmax_withloss():
n = 50
in_fea = 200
n_class = 100
cutoffs = [10, 30, 60]
seed = 1000
np.random.seed(seed)
data = np.random.rand(n, in_fea).astype(np.float32)*10
target = np.random.rand(n)*n_class
target = target.astype(np.int64)

headweight = np.random.rand(13, 200)
weight00 = np.random.rand(50, 200)
weight01 = np.random.rand(20, 50)
weight10 = np.random.rand(12, 200)
weight11 = np.random.rand(30, 12)
weight20 = np.random.rand(3, 200)
weight21 = np.random.rand(40, 3)

torch_net = torch.nn.AdaptiveLogSoftmaxWithLoss(in_fea, n_class, cutoffs)
ms_net = AdaptiveLogSoftmaxWithLoss(in_fea, n_class, cutoffs)
torch_net.head.weight = torch.nn.Parameter(torch.tensor(headweight, dtype=torch.float32))
torch_net.tail[0][0].weight = torch.nn.Parameter(torch.tensor(weight00, dtype=torch.float32))
torch_net.tail[0][1].weight = torch.nn.Parameter(torch.tensor(weight01, dtype=torch.float32))
torch_net.tail[1][0].weight = torch.nn.Parameter(torch.tensor(weight10, dtype=torch.float32))
torch_net.tail[1][1].weight = torch.nn.Parameter(torch.tensor(weight11, dtype=torch.float32))
torch_net.tail[2][0].weight = torch.nn.Parameter(torch.tensor(weight20, dtype=torch.float32))
torch_net.tail[2][1].weight = torch.nn.Parameter(torch.tensor(weight21, dtype=torch.float32))
ms_net.head.weight = Parameter(ms_torch.tensor(headweight, dtype=torch.float32))
ms_net.tail[0][0].weight = Parameter(ms_torch.tensor(weight00, dtype=torch.float32))
ms_net.tail[0][1].weight = Parameter(ms_torch.tensor(weight01, dtype=torch.float32))
ms_net.tail[1][0].weight = Parameter(ms_torch.tensor(weight10, dtype=torch.float32))
ms_net.tail[1][1].weight = Parameter(ms_torch.tensor(weight11, dtype=torch.float32))
ms_net.tail[2][0].weight = Parameter(ms_torch.tensor(weight20, dtype=torch.float32))
ms_net.tail[2][1].weight = Parameter(ms_torch.tensor(weight21, dtype=torch.float32))

torch_input = torch.tensor(data)
torch_target = torch.tensor(target).long()
ms_input = ms_torch.tensor(data)
ms_target = ms_torch.tensor(target)

torch_out, torch_loss = torch_net(torch_input, torch_target)
ms_out, ms_loss = ms_net(ms_input, ms_target)

# assert torch_out.shape == ms_out.shape
# assert np.allclose(torch_loss.detach().numpy(), ms_loss.asnumpy())
# assert np.allclose(torch_out.detach().numpy(), ms_out.asnumpy())

if __name__ == '__main__':
test_adaptive_logsoftmax_withloss()

+ 38
- 0
testing/ut/pytorch/nn/test_channelshuffle.py View File

@@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import torch
import mindspore as ms
from mindspore import context

import ms_adapter.pytorch as ms_torch

context.set_context(mode=ms.GRAPH_MODE)

def test_channelshuffle():
np_data1 = np.random.randn(2, 9, 3, 4)
np_data2 = np.random.randn(1, 6, 12, 2, 2)

torch_input1 = torch.tensor(np_data1)
torch_input2 = torch.tensor(np_data2)
ms_input1 = ms_torch.tensor(np_data1)
ms_input2 = ms_torch.tensor(np_data2)

torch_shuffle = torch.nn.ChannelShuffle(3)
ms_shuffle = ms_torch.nn.ChannelShuffle(3)

torch_out1 = torch_shuffle(torch_input1)
torch_out2 = torch_shuffle(torch_input2)
ms_out1 = ms_shuffle(ms_input1)
ms_out2 = ms_shuffle(ms_input2)

assert np.allclose(torch_out1.detach().numpy(), ms_out1.numpy())
assert torch_out1.detach().numpy().dtype == ms_out1.numpy().dtype
assert torch_out1.shape == ms_out1.shape
assert np.allclose(torch_out2.detach().numpy(), ms_out2.numpy())
assert torch_out2.detach().numpy().dtype == ms_out2.numpy().dtype
assert torch_out2.shape == ms_out2.shape

if __name__ == '__main__':
test_channelshuffle()

+ 54
- 0
testing/ut/pytorch/nn/test_fold.py View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import mindspore as ms
import ms_adapter.pytorch as ms_torch
import torch
import numpy as np
from mindspore import context

def test_unfold():
np_input = np.random.randn(7, 8, 9, 10)

torch_tensor = torch.tensor(np_input)
torch_unfold = torch.nn.Unfold((2, 3), 1, (1, 2), (2, 1))
torch_out = torch_unfold(torch_tensor)

ms_tensor = ms_torch.tensor(np_input)
ms_unfold = ms_torch.nn.Unfold((2, 3), 1, (1, 2), (2, 1))
ms_out = ms_unfold(ms_tensor)

assert np.allclose(ms_out.asnumpy(), torch_out.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype
assert ms_out.asnumpy().shape == torch_out.numpy().shape


def test_fold():
np_input1 = np.random.randn(7, 8, 24)
np_input2 = np.random.randn(18, 6)

torch_tensor1 = torch.tensor(np_input1)
torch_tensor2 = torch.tensor(np_input2)
torch_fold1 = torch.nn.Fold((4, 5), (2, 2), 1, (1, 2), (2, 1))
torch_fold2 = torch.nn.Fold((7, 4), 3, 2, 2, 3)
torch_out1 = torch_fold1(torch_tensor1)
torch_out2 = torch_fold2(torch_tensor2)

ms_tensor1 = ms_torch.tensor(np_input1)
ms_tensor2 = ms_torch.tensor(np_input2)
ms_fold1 = ms_torch.nn.Fold((4, 5), (2, 2), 1, (1, 2), (2, 1))
ms_fold2 = ms_torch.nn.Fold((7, 4), 3, 2, 2, 3)
ms_out1 = ms_fold1(ms_tensor1)
ms_out2 = ms_fold2(ms_tensor2)

assert np.allclose(ms_out1.asnumpy(), torch_out1.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out1.asnumpy().dtype == torch_out1.numpy().dtype
assert ms_out1.asnumpy().shape == torch_out1.numpy().shape
assert np.allclose(ms_out2.asnumpy(), torch_out2.numpy(), rtol=1e-3, atol=1e-5)
assert ms_out2.asnumpy().dtype == torch_out2.numpy().dtype
assert ms_out2.asnumpy().shape == torch_out2.numpy().shape


if __name__ == '__main__':
test_unfold()
test_fold()

+ 180
- 4
testing/ut/pytorch/nn/test_loss.py View File

@@ -563,12 +563,12 @@ def test_multi_margin_loss_none():
y = np.array([3])

torch_input1 = torch.tensor(x)
torch_input2 = torch.tensor(y)
torch_input2 = torch.tensor(y, dtype=torch.int64)
torch_loss = torch.nn.MultiMarginLoss(p=2, reduction='none')
result_torch = torch_loss(torch_input1, torch_input2)

ms_logits1 = ms_torch.tensor(x)
ms_logits2 = ms_torch.tensor(y)
ms_logits2 = ms_torch.tensor(y, dtype=ms.int64)
ms_loss = ms_torch.nn.MultiMarginLoss(p=2, reduction='none')
result_ms = ms_loss(ms_logits1, ms_logits2)

@@ -582,13 +582,13 @@ def test_multi_margin_loss_weight():
weight = np.array([0.2, 0.3, 0.4, 0.1])

torch_input1 = torch.tensor(x)
torch_input2 = torch.tensor(y)
torch_input2 = torch.tensor(y, dtype=torch.int64)
torch_weight = torch.tensor(weight)
torch_loss = torch.nn.MultiMarginLoss(weight=torch_weight)
result_torch = torch_loss(torch_input1, torch_input2)

ms_logits1 = ms_torch.tensor(x)
ms_logits2 = ms_torch.tensor(y)
ms_logits2 = ms_torch.tensor(y, dtype=ms.int64)
ms_weight = ms_torch.tensor(weight)
ms_loss = ms_torch.nn.MultiMarginLoss(weight=ms_weight)
result_ms = ms_loss(ms_logits1, ms_logits2)
@@ -597,6 +597,173 @@ def test_multi_margin_loss_weight():
assert result_ms.asnumpy().dtype == result_torch.numpy().dtype
assert result_ms.shape == result_torch.shape

def test_poisson_nll_loss():
np_data = np.random.randn(5, 2)
np_target = np.random.randn(5, 2)

torch_input = torch.tensor(np_data, requires_grad=True)
torch_target = torch.tensor(np_target)
ms_input = ms_torch.tensor(np_data, requires_grad=True)
ms_target = ms_torch.tensor(np_target)

torch_loss = torch.nn.PoissonNLLLoss(reduction="sum")
ms_loss = ms_torch.nn.PoissonNLLLoss(reduction="sum")

torch_out = torch_loss(torch_input, torch_target)
ms_out = ms_loss(ms_input, ms_target)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

def test_gaussian_nll_loss():
np_data = np.random.randn(5, 2)
np_target = np.random.randn(5, 2)
np_var = np.ones((5, 2))

torch_input = torch.tensor(np_data, requires_grad=True)
torch_target = torch.tensor(np_target)
torch_var = torch.tensor(np_var, requires_grad=True)
ms_input = ms_torch.tensor(np_data, requires_grad=True)
ms_target = ms_torch.tensor(np_target)
ms_var = ms_torch.tensor(np_var, requires_grad=True)

torch_loss = torch.nn.GaussianNLLLoss()
ms_loss = ms_torch.nn.GaussianNLLLoss()

torch_out = torch_loss(torch_input, torch_target, torch_var)
ms_out = ms_loss(ms_input, ms_target, ms_var)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

def test_hinge_embedding_loss():
np_data = np.random.randn(5, 2)
np_target = np.sign(np.random.randn(5, 2))

torch_input = torch.tensor(np_data, requires_grad=True)
torch_target = torch.tensor(np_target)
ms_input = ms_torch.tensor(np_data, requires_grad=True)
ms_target = ms_torch.tensor(np_target)

torch_loss = torch.nn.HingeEmbeddingLoss(reduction="none")
ms_loss = ms_torch.nn.HingeEmbeddingLoss(reduction="none")

torch_out = torch_loss(torch_input, torch_target)
ms_out = ms_loss(ms_input, ms_target)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

# def test_multilabel_margin_loss():
# np_data = np.array([[0.1, 0.2, 0.4, 0.8], [0.2, 0.3, 0.5, 0.7]])
# np_target = np.array([[1, 2, 0, 3], [2, 3, -1, 1]])
#
# torch_input = torch.tensor(np_data, requires_grad=True)
# torch_target = torch.tensor(np_target, dtype=torch.int64)
# ms_input = ms_torch.tensor(np_data, requires_grad=True)
# ms_target = ms_torch.tensor(np_target, dtype=ms.int64)
#
# torch_loss = torch.nn.MultiLabelMarginLoss(reduction="none")
# ms_loss = ms_torch.nn.MultiLabelMarginLoss(reduction="none")
#
# torch_out = torch_loss(torch_input, torch_target)
# ms_out = ms_loss(ms_input, ms_target)
#
# assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
# assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
# assert torch_out.shape == ms_out.shape

def test_multilabel_soft_margin_loss():
np_data = np.random.randn(5, 2)
np_target = np.sign(np.random.randn(5, 2))

torch_input = torch.tensor(np_data, requires_grad=True)
torch_target = torch.tensor(np_target)
ms_input = ms_torch.tensor(np_data, requires_grad=True)
ms_target = ms_torch.tensor(np_target)

torch_loss = torch.nn.MultiLabelSoftMarginLoss(reduction="none")
ms_loss = ms_torch.nn.MultiLabelSoftMarginLoss(reduction="none")

torch_out = torch_loss(torch_input, torch_target)
ms_out = ms_loss(ms_input, ms_target)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

def test_triplet_margin_with_distance_loss():
np_anc = np.random.randn(5, 6, 7)
np_pos = np.random.randn(5, 6, 7)
np_neg = np.random.randn(5, 6, 7)

t_anc = torch.tensor(np_anc)
t_pos = torch.tensor(np_pos)
t_neg = torch.tensor(np_neg)
ms_anc = ms_torch.tensor(np_anc)
ms_pos = ms_torch.tensor(np_pos)
ms_neg = ms_torch.tensor(np_neg)

t_loss = torch.nn.TripletMarginWithDistanceLoss(distance_function=torch.nn.functional.cosine_similarity, swap=True)
ms_loss = ms_torch.nn.TripletMarginWithDistanceLoss(distance_function=ms_torch.nn.functional.cosine_similarity, swap=True)

torch_out = t_loss(t_anc, t_pos, t_neg)
ms_out = ms_loss(ms_anc, ms_pos, ms_neg)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

def test_ctc_loss():
np_data = np.random.randn(24, 2, 10)
np_target = np.random.rand(2, 10) * 10
np_input_length = np.array([8, 10])
np_target_length = np.array([5, 6])

torch_input = torch.tensor(np_data, requires_grad=True)
torch_target = torch.tensor(np_target)
torch_input_length = torch.tensor(np_input_length)
torch_target_length = torch.tensor(np_target_length)
ms_input = ms_torch.tensor(np_data, requires_grad=True)
ms_target = ms_torch.tensor(np_target)
ms_input_length = ms_torch.tensor(np_input_length)
ms_target_length = ms_torch.tensor(np_target_length)

torch_loss = torch.nn.CTCLoss(reduction="none")
ms_loss = ms_torch.nn.CTCLoss(reduction="none")

torch_out = torch_loss(torch_input, torch_target, torch_input_length, torch_target_length)
ms_out = ms_loss(ms_input, ms_target, ms_input_length, ms_target_length)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

def test_margin_ranking_loss():
np_data1 = np.random.randn(5, 2)
np_data2 = np.random.randn(5, 2)
np_target = np.sign(np.random.randn(5, 2))

torch_input1 = torch.tensor(np_data1, requires_grad=True)
torch_input2 = torch.tensor(np_data2, requires_grad=True)
torch_target = torch.tensor(np_target)
ms_input1 = ms_torch.tensor(np_data1, requires_grad=True)
ms_input2 = ms_torch.tensor(np_data2, requires_grad=True)
ms_target = ms_torch.tensor(np_target)

torch_loss = torch.nn.MarginRankingLoss(reduction="none")
ms_loss = ms_torch.nn.MarginRankingLoss(reduction="none")

torch_out = torch_loss(torch_input1, torch_input2, torch_target)
ms_out = ms_loss(ms_input1, ms_input2, ms_target)

assert np.allclose(torch_out.detach().numpy(), ms_out.numpy())
assert torch_out.detach().numpy().dtype == ms_out.numpy().dtype
assert torch_out.shape == ms_out.shape

if __name__ == '__main__':
test_smoothl1loss1()
test_smoothl1loss2()
@@ -638,3 +805,12 @@ if __name__ == '__main__':

test_multi_margin_loss_none()
test_multi_margin_loss_weight()

test_poisson_nll_loss()
test_gaussian_nll_loss()
test_hinge_embedding_loss()
# test_multilabel_margin_loss()
test_multilabel_soft_margin_loss()
test_triplet_margin_with_distance_loss()
test_ctc_loss()
test_margin_ranking_loss()

+ 23
- 0
testing/ut/pytorch/nn/test_padding.py View File

@@ -123,6 +123,28 @@ def test_reflection_pad_2d():
assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_reflection_pad_3d():
padding = 2
pt_input_4d = torch.ones(2, 3, 3, 4)
pt_pad_fun1 = torch.nn.ReflectionPad3d(padding)
pt_pad_out1 = pt_pad_fun1(pt_input_4d)
ms_input_4d = ms_pytorch.ones(2, 3, 3, 4)
ms_pad_fun1 = ms_pytorch.nn.ReflectionPad3d(padding)
ms_pad_out1 = ms_pad_fun1(ms_input_4d)
assert (pt_pad_out1.shape == ms_pad_out1.shape)
assert np.allclose(pt_pad_out1.numpy(), ms_pad_out1.asnumpy())

padding = (1, 1, 2, 0, 3, 2)
pt_input_5d = torch.ones(1, 2, 6, 4, 5)
pt_pad_fun2 = torch.nn.ReflectionPad3d(padding)
pt_pad_out2 = pt_pad_fun2(pt_input_5d)
ms_input_5d = ms_pytorch.ones(1, 2, 6, 4, 5)
ms_pad_fun2 = ms_pytorch.nn.ReflectionPad3d(padding)
ms_pad_out2 = ms_pad_fun2(ms_input_5d)
assert (pt_pad_out2.shape == ms_pad_out2.shape)
assert np.allclose(pt_pad_out2.numpy(), ms_pad_out2.asnumpy())


def test_zero_pad_2d():
padding = 2
pt_input_3d = torch.ones(1, 3, 3)
@@ -214,6 +236,7 @@ if __name__ == '__main__':
test_constant_pad_3d()
test_reflection_pad_1d()
test_reflection_pad_2d()
test_reflection_pad_3d()
test_zero_pad_2d()
test_replication_pad_1d()
test_replication_pad_2d()


+ 0
- 0
testing/ut/pytorch/nn/test_transformer.py View File

