#873 [WIP]add embedding_bag api

Open
zoulq wants to merge 1 commits from zlq2 into master
  1. +2
    -0
      doc/torch/ConstraintList.md
  2. +2
    -0
      doc/torch/ConstraintList_en.md
  3. +2
    -0
      doc/torch/SupportedList.md
  4. +2
    -1
      doc/torch/SupportedList_en.md
  5. +112
    -1
      mindtorch/torch/nn/functional.py
  6. +1
    -0
      mindtorch/torch/nn/modules/__init__.py
  7. +87
    -1
      mindtorch/torch/nn/modules/sparse.py
  8. +80
    -1
      testing/ut/pytorch/nn/functional/test_functional.py
  9. +135
    -2
      testing/ut/pytorch/nn/test_sparse.py

+ 2
- 0
doc/torch/ConstraintList.md View File

@@ -259,6 +259,7 @@
| nn.AdaptiveMaxPool1d | Ascend上不支持`return_indices` |
| nn.AdaptiveMaxPool2d | Ascend上不支持`return_indices` |
| nn.Embedding | 1.`scale_grad_by_freq`, `sparse`不支持; 2.`norm_type`只能为2 |
| nn.EmbeddingBag | 1. 不支持图模式; 2. 入參`scale_grad_by_freq`,`per_sample_weights`和`include_last_offset`不支持配置非默认值;3. 入参`sparse`取值为Ture时不支持CPU平台 |
| nn.Upsample | 不支持`recompute_scale_factor` |
| nn.RNN | 在图模式下,`input`不支持PackedSequence类型 |
| nn.GRU | 在图模式下,`input`不支持PackedSequence类型 |
@@ -308,6 +309,7 @@
| functional.pad | 当`padding_mode`为'reflect'时,不支持填充最后三维 |
| functional.upsample | `mode`仅支持`linear`、`bilinear`、`nearest` |
| functional.ctc_loss | 入參`input_lengths` 和 `target_lengths` 不支持tuple类型 |
| functional.embedding_bag | 1. 不支持图模式; 2. 入參`scale_grad_by_freq`,`per_sample_weights`和`include_last_offset`不支持配置非默认值;3. 入参`sparse`取值为Ture时不支持CPU平台 |

### <span id="jump6">torch.linalg</span>
| MindTorch接口 | 约束条件 |


+ 2
- 0
doc/torch/ConstraintList_en.md View File

@@ -260,6 +260,7 @@ English | [简体中文](ConstraintList.md)
| nn.AdaptiveMaxPool1d | `return_indices` not support on Ascend |
| nn.AdaptiveMaxPool2d | `return_indices` not support on Ascend |
| nn.Embedding | 1. `scale_grad_by_freq`, `sparse` is not supported; 2. `norm_type` can only be 2 |
| nn.EmbeddingBag | 1. Not support GRAPH mode; 2. `scale_grad_by_freq`,`per_sample_weights` and `include_last_offset` only support default value;3. `sparse` not support on CPU. |
| nn.Upsample | Not support `recompute_scale_factor` |
| nn.RNN | Under GRAPH mode, `input` not support PackedSequence type |
| nn.GRU | Under GRAPH mode, `input` not support PackedSequence type |
@@ -309,6 +310,7 @@ English | [简体中文](ConstraintList.md)
| functional.pad | when `padding_mode` is 'reflect', not support pad last 3 dimentions |
| functional.upsample | `mode` only supports setting to `linear`,`bilinear` and `nearest` |
| functional.ctc_loss | `input_lengths` and `target_lengths` do not support tuple |
| functional.embedding_bag | 1. Not support GRAPH mode; 2. `scale_grad_by_freq`,`per_sample_weights` and `include_last_offset` only support default value;3. `sparse` not support on CPU. |

### <span id="jump6">torch.linalg</span>
| MindTorch APIs | Constraint conditions |


+ 2
- 0
doc/torch/SupportedList.md View File

@@ -1046,6 +1046,7 @@
| nn.LSTM | 部分支持 | [功能存在限制](ConstraintList.md) |
| nn.GRU | 部分支持 | [功能存在限制](ConstraintList.md) |
| nn.Embedding | 部分支持 | [功能存在限制](ConstraintList.md) |
| nn.EmbeddingBag | 部分支持 | [功能存在限制](ConstraintList.md) |
| nn.KLDivLoss | 支持 | |
| nn.MultiLabelMarginLoss | 部分支持 | 暂不支持CPU后端 |
| nn.MultiMarginLoss | 支持 | |
@@ -1264,6 +1265,7 @@
| functional.avg_pool1d | 支持 | |
| functional.scaled_dot_product_attention | 支持 | |
| functional.pad | 部分支持 | 1. 暂不支持图模式 2. [功能存在限制](ConstraintList.md) |
| functional.embedding_bag | 部分支持 | [功能存在限制](ConstraintList.md) |


### <span id="jump6">torch.linalg</span>


+ 2
- 1
doc/torch/SupportedList_en.md View File

@@ -1044,6 +1044,7 @@ English | [简体中文](SupportedList.md)
| nn.LSTM | Partly supported | [Function is constrained](ConstraintList_en.md) |
| nn.GRU | Partly Supported | [Function is constrained](ConstraintList_en.md) |
| nn.Embedding | Partly supported | [Function is constrained](ConstraintList_en.md) |
| nn.EmbeddingBag | Partly supported | [Function is constrained](ConstraintList_en.md) |
| nn.KLDivLoss | Supported | |
| nn.MultiLabelMarginLoss | Partly supported | Currently not support on CPU |
| nn.MultiMarginLoss | Supported | |
@@ -1262,7 +1263,7 @@ English | [简体中文](SupportedList.md)
| functional.avg_pool1d | Supported | |
| functional.scaled_dot_product_attention | Supported | |
| functional.pad | Partly supported | 1.Currently not support on GRAPH mode. 2.[Function is constrained](ConstraintList_en.md) |
| functional.embedding_bag | Partly supported | [Function is constrained](ConstraintList_en.md) |

### <span id="jump6">torch.linalg</span>
| MindTorch APIs | Status | Restrictions |


+ 112
- 1
mindtorch/torch/nn/functional.py View File

@@ -106,7 +106,8 @@ all = [
'multi_head_attention_forward',
'scaled_dot_product_attention',

'prompt_flash_attention'
'prompt_flash_attention',
'embedding_bag'
]


@@ -2767,3 +2768,113 @@ def prompt_flash_attention(query, key, value, attn_mask=None, padding_mask=None,
output_ms = pfa_op(query_ms, key_ms, value_ms, attn_mask, padding_mask, actual_seq_lengths)
pfa_output = cast_to_adapter_tensor(output_ms[0])
return pfa_output


def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2., scale_grad_by_freq=False,
mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None):

if graph_mode_condition():
raise RuntimeError("`embedding_bag` unsupported on GRAPH_MODE Now.")

if scale_grad_by_freq is not False or per_sample_weights is not None or include_last_offset is not False:
raise NotImplementedError("Currently, values other than the default are not supported.")
Erpim commented 2 months ago
Review
报错信息是否再准确一些? 除这三个外其他带有默认值的入参也只支持默认值吗?
zoulq commented 2 months ago
Review
其他参数取值没有约束

input_ms = cast_to_ms_tensor(input)
weight_ms = cast_to_ms_tensor(weight)
offsets_ms = cast_to_ms_tensor(offsets)
if weight_ms.dtype == ms.int64 and input_ms.is_floating_point():
warning("Argument order of nn.functional.embedding_bag was changed. Usage `embedding_bag(weight, input, ...)` "
" is deprecated, and should now be `embedding_bag(input, weight, ...)`.")
weight_ms, input_ms = input_ms, weight_ms

if input_ms.dim() == 2:
if offsets is not None:
raise ValueError("if input is 2D, then offsets has to be None, as input is treated is a mini-batch of"
" fixed length sequences. However, found offsets of type {}".format(type(offsets)))

offsets_ms = ms.ops.arange(0, input_ms.numel(), input_ms.shape[1], dtype=input_ms.dtype)
input_ms = input_ms.reshape(-1)
elif input_ms.dim() == 1:
if offsets is None:
raise ValueError("offsets has to be a 1D Tensor but got None")
if offsets.dim() != 1:
raise ValueError("offsets has to be a 1D Tensor")
else:
raise ValueError("input has to be 1D or 2D Tensor," " but got Tensor of dimension {}".format(input_ms.dim()))

if mode == "max":
if sparse:
raise ValueError("max mode does not support sparse weights")

padding_idx = _get_embedding_padding_idx(weight_ms.shape, padding_idx)

if max_norm:
input_unique, _ = ms.ops.unique(input_ms)
for idx in input_unique:
row_tensor = weight_ms[idx]
row_norm = ms.ops.clip_by_norm(row_tensor, max_norm, norm_type)
weight_ms[idx] = row_norm

output_padding_idx = []
if padding_idx:
update_shape = (1, weight_ms.shape[-1])
update = ms.ops.zeros(update_shape, dtype=weight_ms.dtype)
indices = ms.ops.fill(ms.int32, update_shape, padding_idx)
weight_ms = ms.ops.tensor_scatter_elements(weight_ms, indices, update, axis=0)

input_np = input_ms.asnumpy()
output_padding_idx = np.where(input_np == padding_idx)[0]

#update weight, unsuppport graph mode
if isinstance(weight, Parameter):
weight.set_data(weight_ms)
else:
weight.assign_value(weight_ms)

if sparse:
# TODO:SparseGatherV2 unsupport on CPU
gather_op = ms.ops.SparseGatherV2()
else:
gather_op = ms.ops.Gather()

output = gather_op(weight_ms, input_ms, 0)

def _process(output, start, end):
if padding_idx:
index = []
for i in range(start, end):
if i not in output_padding_idx:
index += [i, ]
if len(index) == 0:
tensor_data = ms.ops.expand_dims(output[start], 0)
else:
index_tensor = ms.Tensor(index)
tensor_data = ms.ops.index_select(output, 0, index_tensor)
else:
tensor_data = output[start:end,]
if mode == "sum":
ret = tensor_data.sum(0)
elif mode == "mean":
ret = tensor_data.mean(0)
elif mode == "max":
ret = tensor_data.max(0)
else:
raise ValueError("mode has to be one of sum, mean or max")
return ret

#offset
offsets_np = offsets_ms.asnumpy()
b_len = len(offsets_np)
tmp = []
for i in range(b_len - 1):
start = int(offsets_np[i])
end = int(offsets_np[i + 1])
tmp1 = _process(output, start, end)
tmp.append(tmp1)
last_start = int(offsets_np[b_len - 1])
last_end = len(output)
tmp1 = _process(output, last_start, last_end)
tmp.append(tmp1)
ret = ms.ops.stack(tmp)

return cast_to_adapter_tensor(ret)

+ 1
- 0
mindtorch/torch/nn/modules/__init__.py View File

@@ -168,6 +168,7 @@ __all__ = [
'CosineSimilarity',

'Embedding',
'EmbeddingBag',

'PixelShuffle',
'PixelUnshuffle',


+ 87
- 1
mindtorch/torch/nn/modules/sparse.py View File

@@ -5,7 +5,7 @@ from mindtorch.utils import unsupported_attr
from mindtorch.torch.nn.modules.module import Module
from mindtorch.torch.nn.init import normal_

__all__ = ['Embedding']
__all__ = ['Embedding', 'EmbeddingBag']

class Embedding(Module):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
@@ -87,3 +87,89 @@ class Embedding(Module):
sparse=sparse)
embedding.weight.requires_grad = not freeze
return embedding

class EmbeddingBag(Module):
def __init__(self, num_embeddings, embedding_dim,
max_norm=None, norm_type=2., scale_grad_by_freq=False,
mode='mean', sparse=False, _weight=None,
include_last_offset=False, padding_idx=None,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super(EmbeddingBag, self).__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
if padding_idx is not None:
if padding_idx > 0:
if padding_idx >= self.num_embeddings:
raise ValueError('padding_idx must be within num_embeddings')
elif padding_idx < 0:
if padding_idx < -self.num_embeddings:
raise ValueError('padding_idx must be within num_embeddings')
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
if _weight is None:
self.weight = Parameter(empty((num_embeddings, embedding_dim), **factory_kwargs))
self.reset_parameters()
else:
if list(_weight.shape) != [num_embeddings, embedding_dim]:
raise RuntimeError('Shape of weight does not match num_embeddings and embedding_dim')
self.weight = Parameter(_weight)
self.mode = mode
self.sparse = sparse
self.include_last_offset = include_last_offset

def reset_parameters(self):
normal_(self.weight)
self._fill_padding_idx_with_zero()

def _fill_padding_idx_with_zero(self):
if self.padding_idx is not None:
self.weight[self.padding_idx] = 0

def forward(self, input, offsets=None, per_sample_weights=None):
return Adapter_F.embedding_bag(input, self.weight, offsets,
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset,
self.padding_idx)

def extra_repr(self):
s = '{num_embeddings}, {embedding_dim}'
if self.max_norm is not None:
s += ', max_norm={max_norm}'
if self.norm_type != 2:
s += ', norm_type={norm_type}'
if self.scale_grad_by_freq is not False:
s += ', scale_grad_by_freq={scale_grad_by_freq}'
s += ', mode={mode}'
if self.padding_idx is not None:
s += ', padding_idx={padding_idx}'
return s.format(**self.__dict__)

@classmethod
def from_pretrained(cls, embeddings, freeze=True, max_norm=None,
norm_type=2., scale_grad_by_freq=False,
mode='mean', sparse=False, include_last_offset=False,
padding_idx=None):
"""Creates EmbeddingBag instance from given 2-dimensional FloatTensor."""

if embeddings.dim() != 2:
raise ValueError('Embeddings parameter is expected to be 2-dimensional')

rows, cols = embeddings.shape
embeddingbag = cls(
num_embeddings=rows,
embedding_dim=cols,
_weight=embeddings,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
embeddingbag.weight.requires_grad = not freeze
return embeddingbag

+ 80
- 1
testing/ut/pytorch/nn/functional/test_functional.py View File

@@ -7,7 +7,7 @@ import torch
from mindtorch.torch.nn.functional import interpolate, adaptive_avg_pool2d
from mindtorch.utils import is_under_ascend_context
from ....utils import SKIP_ENV_GPU, set_mode_by_env_config, param_compare, SKIP_ENV_ASCEND, SKIP_ENV_ASCEND, \
SKIP_ENV_CPU, TestNet
SKIP_ENV_CPU, TestNet, SKIP_ENV_GRAPH_MODE
set_mode_by_env_config()


@@ -571,6 +571,79 @@ def test_mul1():
# ms_out = msa_net(ms_torch.tensor([True, False, True]), ms_torch.tensor([True, False, False]))
# param_compare(torch_out, ms_out)

@SKIP_ENV_GRAPH_MODE(reason="embedding_bag unsupport on Graph mode")
def test_embedding_bag_1():
init_value = np.arange(40).reshape(10, 4).astype(np.float32)
embedding_matrix = torch.tensor(init_value)
input = torch.tensor([1,2,4,5,4,3,2,9])
offsets = torch.tensor([0,4])
pt_output = torch.nn.functional.embedding_bag(input, embedding_matrix, offsets)

ms_embedding_matrix = ms_torch.tensor(init_value)
ms_input = ms_torch.tensor([1,2,4,5,4,3,2,9])
ms_offsets = ms_torch.tensor([0,4])
mas_embdding_net = TestNet(ms_torch.nn.functional.embedding_bag)
ms_output = mas_embdding_net(ms_input, ms_embedding_matrix, ms_offsets)
param_compare(pt_output, ms_output)

@SKIP_ENV_GRAPH_MODE(reason="embedding_bag unsupport on Graph mode")
def test_embedding_bag_2():
init_value = np.random.rand(10, 3).astype(np.float32)
embedding_matrix = torch.tensor(init_value)
input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
offsets = torch.tensor([0,4])
pt_output = torch.nn.functional.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')

ms_embedding_matrix = ms_torch.tensor(init_value)
ms_input = ms_torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
ms_offsets = ms_torch.tensor([0, 4])
mas_embdding_net = TestNet(ms_torch.nn.functional.embedding_bag)
ms_output = mas_embdding_net(ms_input, ms_embedding_matrix, ms_offsets, padding_idx=2, mode='sum')
param_compare(pt_output, ms_output)

@SKIP_ENV_GRAPH_MODE(reason="embedding_bag unsupport on Graph mode")
def test_embedding_bag_3():
init_value = np.random.rand(10, 3).astype(np.float32)
embedding_matrix = torch.tensor(init_value)
input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
pt_output = torch.nn.functional.embedding_bag(input, embedding_matrix, mode='max')

ms_embedding_matrix = ms_torch.tensor(init_value)
ms_input = ms_torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
mas_embdding_net = TestNet(ms_torch.nn.functional.embedding_bag)
ms_output = mas_embdding_net(ms_input, ms_embedding_matrix, mode='max')
param_compare(pt_output, ms_output)

@SKIP_ENV_GRAPH_MODE(reason="embedding_bag unsupport on Graph mode")
def test_embedding_bag_4():
init_value = np.random.rand(10, 3).astype(np.float32)
embedding_matrix = torch.tensor(init_value)
input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
pt_output = torch.nn.functional.embedding_bag(input, embedding_matrix, max_norm=1.0)

ms_embedding_matrix = ms_torch.tensor(init_value)
ms_input = ms_torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
mas_embdding_net = TestNet(ms_torch.nn.functional.embedding_bag)
ms_output = mas_embdding_net(ms_input, ms_embedding_matrix, max_norm=1.0)
param_compare(pt_output, ms_output)

@SKIP_ENV_GRAPH_MODE(reason="embedding_bag unsupport on Graph mode")
@SKIP_ENV_CPU(reason="Unsupport `sparse` on CPU")
def test_embedding_bag_5():
init_value = np.arange(30).reshape(10, 3).astype(np.float32)
embedding_matrix = torch.tensor(init_value)
input = torch.tensor([1,2,4,5,4,3,2,7])

offsets = torch.tensor([0, 2, 4, 6])
pt_output = torch.nn.functional.embedding_bag(input, embedding_matrix, offsets, sparse=True)

ms_embedding_matrix = ms_torch.tensor(init_value)
ms_input = ms_torch.tensor([1,2,4,5,4,3,2,7])
ms_offsets = ms_torch.tensor([0, 2, 4, 6])
mas_embdding_net = TestNet(ms_torch.nn.functional.embedding_bag)
ms_output = mas_embdding_net(ms_input, ms_embedding_matrix, ms_offsets, sparse=True)
print(pt_output.shape, ms_output.shape)
param_compare(pt_output, ms_output)

if __name__ == '__main__':
set_mode_by_env_config()
@@ -627,3 +700,9 @@ if __name__ == '__main__':
test_constant_pad_nd()
test_mul1()
# test_mul2()

test_embedding_bag_1()
test_embedding_bag_2()
test_embedding_bag_3()
test_embedding_bag_4()
test_embedding_bag_5()

+ 135
- 2
testing/ut/pytorch/nn/test_sparse.py View File

@@ -7,7 +7,7 @@ import torch
import numpy as np
from mindspore import context

from ...utils import SKIP_ENV_ASCEND, SKIP_ENV_CPU, SKIP_ENV_GPU, set_mode_by_env_config, param_compare
from ...utils import SKIP_ENV_ASCEND, SKIP_ENV_CPU, SKIP_ENV_GPU, set_mode_by_env_config, param_compare, SKIP_ENV_GRAPH_MODE
set_mode_by_env_config()


@@ -171,6 +171,129 @@ def test_embedding_output_with_padding_idx_fp64():

param_compare(result_ms, result_torch.detach())

@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
def test_embedding_bag_1():
index_np = np.array([1, 2, 1, 2, 3]).astype(np.int32)
weight_np = np.array([[2.3, 4.5, 5.2], [6.0, 7.1, 8.0], [3.5, 6., 2.9], [8.9, 4., 3.9]])

torch_input = torch.tensor(index_np, dtype=torch.long)
torch_weight = torch.tensor(weight_np)
torch_offsets = torch.tensor([0, 4], dtype=torch.long)
torch_embedding_sum = torch.nn.EmbeddingBag(4, 3, mode='sum', _weight=torch_weight)
torch_output = torch_embedding_sum(torch_input, torch_offsets)

ms_input = ms_torch.tensor(index_np, dtype=ms_torch.long)
ms_weight = ms_torch.tensor(weight_np)
ms_offsets = ms_torch.tensor([0, 4], dtype=ms_torch.long)
ms_embedding_sum = ms_torch.nn.EmbeddingBag(4, 3, mode='sum', _weight=ms_weight)
ms_output = ms_embedding_sum(ms_input, ms_offsets)

param_compare(torch_output.detach(), ms_output)

@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
def test_embedding_bag_2():
index_np = np.array([1, 2, 1, 0, 2, 3, 2, 3]).astype(np.int32)
weight_np = np.array([[2.3, 4.5, 5.2], [6.0, 7.1, 8.0], [3.5, 6., 2.9], [8.9, 4., 3.9]])

torch_input = torch.tensor(index_np, dtype=torch.long)
torch_weight = torch.tensor(weight_np)
torch_offsets = torch.tensor([0, 4], dtype=torch.long)
torch_embedding = torch.nn.EmbeddingBag(4, 3, mode='mean', _weight=torch_weight, padding_idx=2)
torch_output = torch_embedding(torch_input, torch_offsets)

ms_input = ms_torch.tensor(index_np, dtype=ms_torch.long)
ms_weight = ms_torch.tensor(weight_np)
ms_offsets = ms_torch.tensor([0, 4], dtype=ms_torch.long)
ms_embedding = ms_torch.nn.EmbeddingBag(4, 3, mode='mean', _weight=ms_weight, padding_idx=2)
ms_output = ms_embedding(ms_input, ms_offsets)
param_compare(torch_output.detach(), ms_output)

@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
def test_embedding_bag_grad():
index_np = np.array([[1, 2, 3], [0, 2, 1]]).astype(np.int32)
weight_np = np.array([[2.3, 4.5], [6.0, 7.1], [3.5, 6.], [8.9, 4]]).astype(np.float32)

ms_index = ms_torch.tensor(index_np)
ms_weight = ms_torch.tensor(weight_np)
net = ms_torch.nn.EmbeddingBag(4, 2, _weight=ms_weight)
result_ms = net(ms_index)
train_net = TrainNet(net)
train_net.set_grad()
grad_fn = ms.value_and_grad(train_net, grad_position=None, weights=train_net.trainable_params())
_, grads = grad_fn(ms_index)

assert not np.allclose(grads[0][1].asnumpy(), ms.ops.ZerosLike()(grads[0][1]).asnumpy())

torch_index = torch.tensor(index_np)
torch_weight = torch.tensor(weight_np)
result_torch = torch.nn.EmbeddingBag(4, 2, _weight=torch_weight)(torch_index)
param_compare(result_ms.detach(), result_torch.detach())

@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
def test_embedding_bag_functional():
a = ms_torch.tensor([[1, 3, 2], [0, 2, 1]], dtype=ms_torch.long)
embeddings = ms_torch.rand(4, 3, requires_grad=True)

embed_old = ms_torch.nn.EmbeddingBag(4, 3)
embed_old.weight = ms_torch.nn.Parameter(embeddings)
res_old = embed_old(a)

res_F = ms_torch.nn.functional.embedding_bag(a, embeddings)
param_compare(res_old, res_F)

embed_old = ms_torch.nn.EmbeddingBag(4, 3)
embed_old = embed_old.from_pretrained(embeddings, padding_idx=2)
res_old = embed_old(a)
res_F = ms_torch.nn.functional.embedding_bag(a, embeddings, padding_idx=2)
param_compare(res_old, res_F)

@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
def test_embeddingbag_from_pretrained():
a = ms_torch.tensor([[1., 2., 3.], [4., 5., 6.]])
embeddingbag = ms_torch.nn.EmbeddingBag.from_pretrained(a)
param_compare(a, embeddingbag.weight.data)

input = ms_torch.LongTensor([[0, 1]])
output = embeddingbag(input)
param_compare(a.mean(0, keepdim=True), output)

@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
def test_embeddingbag_from_pretrained_options():
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
opts = {
"max_norm": 2.,
"scale_grad_by_freq": False,
"mode": "max",
"sparse": False
}
embeddingbag = torch.nn.EmbeddingBag.from_pretrained(a, **opts)

input = torch.LongTensor([[0, 1]])
output = embeddingbag(input)
weight = embeddingbag.weight
out1 = weight.max(0, keepdim=True)[0]
param_compare(out1, output)

out2 = weight.ne(torch.arange(1, 7, dtype=a.dtype).view(2, 3)).all()
assert (out2 == True)

@SKIP_ENV_GPU(reason="EmbeddingBag unsupport empty input on Graph mode")
@SKIP_ENV_ASCEND(reason="EmbeddingBag unsupport empty input on Graph mode")
@SKIP_ENV_GRAPH_MODE(reason="EmbeddingBag unsupport on Graph mode")
@SKIP_ENV_CPU(reason="EmbeddingBag sparse=True unsupport on CPU.")
def test_embedding_bag_empty_input():
m = 4
n = 3
x = ms_torch.tensor([], dtype=ms_torch.int32) #sparse=True only support int16/int32/int64
for sparse in [True, False]:
Embed = ms_torch.nn.EmbeddingBag(m, n, sparse=sparse)
output1 = Embed(input=x, offsets=ms_torch.tensor([0]))
param_compare(output1, ms_torch.zeros_like(output1))

output2 = Embed(input=x, offsets=ms_torch.tensor([0, 0]))
param_compare(output2, ms_torch.zeros_like(output2))


if __name__ == '__main__':
set_mode_by_env_config()
test_embedding()
@@ -179,4 +302,14 @@ if __name__ == '__main__':
test_embedding_weight_grad_with_padding_idx()
test_embedding_output_with_padding_idx()
test_embedding_weight_grad_with_padding_idx_fp64()
test_embedding_output_with_padding_idx_fp64()
test_embedding_output_with_padding_idx_fp64()

test_embedding_bag_1()
test_embedding_bag_2()
test_embedding_bag_grad()
test_embedding_bag_functional()
test_embeddingbag_from_pretrained()
test_embeddingbag_from_pretrained_options()
test_embedding_bag_empty_input()



Loading…
Cancel
Save