#419 nn.MultiheadAttention

Merged
zoulq merged 10 commits from lzh_multihead into master 1 year ago
  1. msadapter/pytorch/nn/functional.py (+55, -2)
  2. msadapter/pytorch/nn/modules/activation.py (+104, -105)
  3. testing/ut/pytorch/nn/test_activation.py (+296, -5)

msadapter/pytorch/nn/functional.py (+55, -2)

@@ -3,7 +3,7 @@
"""Functional interface"""
import math
import warnings
from typing import Iterable
from typing import Iterable, Optional
from functools import lru_cache
import numpy as np
import mindspore as ms
@@ -101,7 +101,9 @@ __all__ = [
'max_pool2d',

'fold',
'unfold'
'unfold',

'multi_head_attention_forward'
]

@constexpr
@@ -2313,3 +2315,54 @@ def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
if ndim == 2:
output = output.squeeze(0)
return cast_to_adapter_tensor(output)

def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
):
query = cast_to_ms_tensor(query)
key = cast_to_ms_tensor(key)
value = cast_to_ms_tensor(value)
key_padding_mask = cast_to_ms_tensor(key_padding_mask)
attn_mask = cast_to_ms_tensor(attn_mask)
static_k = cast_to_ms_tensor(static_k)
static_v = cast_to_ms_tensor(static_v)
# TODO: older versions of torch do not have the is_causal arg
k_is_v = key is value
q_is_k = query is key
attn_output, attn_output_weights = ms.ops.function.nn_func.multi_head_attention_forward(
query, key, value, embed_dim_to_check, num_heads,
in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p,
out_proj_weight, out_proj_bias, training=training,
key_padding_mask=key_padding_mask, attn_mask=attn_mask,
use_separate_proj_weight=use_separate_proj_weight,
q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v,
average_attn_weights=average_attn_weights, k_is_v=k_is_v, q_is_k=q_is_k)
if need_weights:
output = attn_output, attn_output_weights
else:
output = attn_output, None
return cast_to_adapter_tensor(output)

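Aside (not part of the diff): a minimal sketch of calling the new functional entry point with (L, N, E) inputs. The projection tensors below are made up for illustration; in normal use they come from nn.MultiheadAttention's parameters. The `ms_torch` import alias and `from_numpy` usage follow the unit tests in this PR.

import numpy as np
import msadapter.pytorch as ms_torch
import msadapter.pytorch.nn.functional as F

L, N, E, num_heads = 4, 2, 8, 2
query = key = value = ms_torch.from_numpy(np.random.rand(L, N, E).astype(np.float32))
# hypothetical projection parameters, normally owned by nn.MultiheadAttention
in_proj_weight = ms_torch.from_numpy(np.random.rand(3 * E, E).astype(np.float32))
in_proj_bias = ms_torch.from_numpy(np.zeros(3 * E, dtype=np.float32))
out_proj_weight = ms_torch.from_numpy(np.random.rand(E, E).astype(np.float32))
out_proj_bias = ms_torch.from_numpy(np.zeros(E, dtype=np.float32))

attn_output, attn_weights = F.multi_head_attention_forward(
    query, key, value, E, num_heads,
    in_proj_weight, in_proj_bias, None, None, False, 0.0,
    out_proj_weight, out_proj_bias, training=False)
# attn_output: (L, N, E); attn_weights: expected (N, L, S), averaged over heads by default
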
msadapter/pytorch/nn/modules/activation.py (+104, -105)

@@ -4,17 +4,21 @@ import warnings
import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.function.nn_func import multi_head_attention_forward
from mindspore.common import dtype as mstype
import mindspore as ms
from mindspore import nn
from mindspore._checkparam import Validator as validator

from msadapter.pytorch.functional import empty
from msadapter.pytorch.nn.parameter import Parameter
import msadapter.pytorch.nn.functional as ms_torch_nn_func
from msadapter.pytorch.tensor import Tensor, tensor, cast_to_ms_tensor, cast_to_adapter_tensor
from msadapter.utils import unsupported_attr
from msadapter.pytorch.common._inner import _inplace_assign, _inplace_limit_pynative
from .module import Module
from .linear import Linear
from ..init import constant_, xavier_normal_, xavier_uniform_

__all__ = ['ReLU', 'Hardtanh', 'ReLU6', 'SiLU', 'Hardswish', 'LeakyReLU', 'Sigmoid', 'LogSigmoid', 'ELU', 'RReLU',
'SELU', 'CELU', 'GELU', 'Mish', 'Softshrink', 'Tanh', 'Tanhshrink','Threshold', 'Softmax', 'LogSoftmax',
@@ -406,122 +410,117 @@ class Hardsigmoid(Module):


class MultiheadAttention(Module):
def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, \
add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
super(MultiheadAttention, self).__init__()
if bias is not True:
raise ValueError(f"`bias` can only be set to 'True', but got {bias}")

if add_bias_kv:
raise ValueError(f"`add_bias_kv` can only be set to 'False', but got {add_bias_kv}")

if add_zero_attn:
raise ValueError(f"`add_zero_attn` can only be set to 'False', but got {add_zero_attn}")

unsupported_attr(kdim)
unsupported_attr(vdim)
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
kdim=None, vdim=None, batch_first=False, device=None, dtype=None):
unsupported_attr(device)
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

self.num_heads = num_heads
self.dropout = dropout
self.add_bias_kv = add_bias_kv
self.add_zero_attn = add_zero_attn
self.kdim = kdim
self.vdim = vdim
self.batch_first = batch_first
self.dtype = dtype

self.reduce_mean = ms.ops.ReduceMean()

def forward(self, query, key, value, key_padding_mask=None,
need_weights: bool=True, attn_mask=None,
average_attn_weights: bool=True):
unsupported_attr(key_padding_mask)
unsupported_attr(average_attn_weights)
if need_weights is True:
raise ValueError("Until now, `need_weights`='True' is not supported")

query = self._batch_tensor(query, 'query')
key = self._batch_tensor(key, 'key')
value = self._batch_tensor(value, 'value')
_batch_size = query.shape[0]
_src_seq_length = query.shape[1]
_tgt_seq_length = key.shape[1]

if attn_mask:
_attn_mask = self._process_mask(attn_mask, _batch_size)
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError("The init argument 'embed_dim' must be divisible by 'num_heads'.")

if self._qkv_same_embed_dim is False:
self.q_proj_weight = Parameter(empty((embed_dim, embed_dim), dtype=dtype))
self.k_proj_weight = Parameter(empty((embed_dim, self.kdim), dtype=dtype))
self.v_proj_weight = Parameter(empty((embed_dim, self.vdim), dtype=dtype))
self.in_proj_weight = None
else:
_attn_mask = ms.ops.ones((_batch_size, _src_seq_length, _tgt_seq_length), mstype.float32)

self.ms_multihead_attention = ms.nn.transformer.MultiHeadAttention(
batch_size = _batch_size,
src_seq_length = _src_seq_length,
tgt_seq_length = _tgt_seq_length,
hidden_size=self.embed_dim,
num_heads=self.num_heads,
hidden_dropout_rate=self.dropout,
attention_dropout_rate=self.dropout,
compute_dtype=mstype.float32,
softmax_compute_type=mstype.float32,
param_init_type=mstype.float32,
use_past=False)
out, attn_output_weights = self.ms_multihead_attention(query, key, value, _attn_mask)

if not self.batch_first:
# ms default is (batch, seq, feature), batch_first
out = ms.ops.transpose(out, (1, 0, 2))

# if need_weights:
# if average_attn_weights:
# attn_output_weights = self.reduce_mean(attn_output_weights, 1)

# if _batch_size == 1:
# attn_output_weights = self.reduce_mean(attn_output_weights, 0)
# else:
# attn_output_weights = None

if _batch_size == 1:
out = self.reduce_mean(out, 0)

# TODO
# Until Now, attn_output_weights is not the same as pytorch
attn_output_weights = None
return cast_to_adapter_tensor(out), cast_to_adapter_tensor(attn_output_weights)

def _batch_tensor(self, x, x_name: str):
x = cast_to_ms_tensor(x)
_rank = ms.ops.rank(x)
if _rank == 2:
out = ms.ops.expand_dims(x, 0)
return out

if _rank == 3:
if not self.batch_first:
out = ms.ops.transpose(x, (1, 0 ,2))
else:
out = x
return out

raise ValueError(f"For MultiheadAttention, rank of {x_name} should be 2 or 3, but got {_rank}")
self.in_proj_weight = Parameter(empty((3 * embed_dim, embed_dim), dtype=dtype))
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None

def _process_mask(self, mask, batch_size):
mask = cast_to_ms_tensor(mask)
_rank = ms.ops.rank(mask)
if _rank == 2:
out = ms.ops.expand_dims(mask, 0)
return out
if bias:
self.in_proj_bias = Parameter(empty(3 * embed_dim, dtype=dtype))
else:
self.in_proj_bias = None
self.out_proj = Linear(embed_dim, embed_dim, bias=bias, dtype=dtype)

if _rank == 3:
if mask.shape[0] != batch_size:
warnings.warn("Until now, `attn_mask` can only support shape (N, L, S)"
"when `attn_mask` shape is (N * num_heads, L, S), pick the first (N, L, S) mask")
if add_bias_kv:
self.bias_k = Parameter(empty((1, 1, embed_dim), dtype=dtype))
self.bias_v = Parameter(empty((1, 1, embed_dim), dtype=dtype))
else:
self.bias_k = self.bias_v = None

mask = mask[:batch_size,:]
return mask
self.add_zero_attn = add_zero_attn

raise ValueError(f"For MultiheadAttention, rank of mask should be 2 or 3, but got {_rank}")
self._reset_parameters()

def _reset_parameters(self):
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)

if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.)
constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)

def __setstate__(self, state):
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
if '_qkv_same_embed_dim' not in state:
state['_qkv_same_embed_dim'] = True

super(MultiheadAttention, self).__setstate__(state)

def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None,
average_attn_weights=True):
query = cast_to_ms_tensor(query)
key = cast_to_ms_tensor(key)
value = cast_to_ms_tensor(value)
key_padding_mask = cast_to_ms_tensor(key_padding_mask)
attn_mask = cast_to_ms_tensor(attn_mask)

is_batched = query.dim() == 3
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if key is value:
if query is key:
query = key = value = query.swapaxes(1, 0)
else:
query, key = [x.swapaxes(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.swapaxes(1, 0) for x in (query, key, value)]

if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights)
else:
attn_output, attn_output_weights = multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask, average_attn_weights=average_attn_weights)
if self.batch_first and is_batched:
attn_output = attn_output.swapaxes(1, 0)
if need_weights:
return cast_to_adapter_tensor(attn_output), cast_to_adapter_tensor(attn_output_weights)
return (cast_to_adapter_tensor(attn_output),)

class PReLU(Module):
def __init__(self, num_parameters=1, init=0.25, device=None, dtype=None):


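For context (not part of the diff), a minimal usage sketch of the rewritten module, assuming the import paths used elsewhere in MSAdapter's tests. With batch_first=True the inputs are (N, L, E)/(N, S, E), and the returned weights are head-averaged because average_attn_weights defaults to True.

import numpy as np
import msadapter.pytorch as ms_torch
from msadapter.pytorch import nn

embed_dim, num_heads, N, L, S = 8, 2, 3, 4, 5
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

query = ms_torch.from_numpy(np.random.rand(N, L, embed_dim).astype(np.float32))
key = value = ms_torch.from_numpy(np.random.rand(N, S, embed_dim).astype(np.float32))

attn_output, attn_weights = mha(query, key, value)
# attn_output: (N, L, embed_dim); attn_weights expected to be (N, L, S)

out_only = mha(query, key, value, need_weights=False)
# note: per the forward() above, this implementation returns a 1-tuple when need_weights=False
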
testing/ut/pytorch/nn/test_activation.py (+296, -5)

@@ -8,6 +8,7 @@ import numpy as np
from mindspore import context
import mindspore as ms
import torch
import pytest

context.set_context(mode=ms.GRAPH_MODE)

@@ -510,8 +511,9 @@ def test_hardsigmoid():
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype

#TODO: multiheadattention need reconstruct
lzh commented 1 year ago
Review
Added the test cases from test/nn/test_multihead_attention.py.
'''
def test_multi_head_attention1():
context.set_context(mode=ms.PYNATIVE_MODE)
_embed_dim = 20
_target_seq_length = 6
_batch_size = 5
@@ -538,6 +540,7 @@ def test_multi_head_attention1():
assert ms_output[0].shape == torch_output[0].shape

def test_multi_head_attention2():
context.set_context(mode=ms.PYNATIVE_MODE)
_embed_dim = 20
_target_seq_length = 6
_batch_size = 5
@@ -562,7 +565,7 @@ def test_multi_head_attention2():
ms_output = ms_net(ms_query, ms_key, ms_val, need_weights=False)

assert ms_output[0].shape == torch_output[0].shape
'''
def test_prelu():
input = np.array([[[[0.1, 0.6], [0.9, 0.9]]]]).astype(np.float32)
weight_init = 0.25
@@ -614,6 +617,283 @@ def test_softmax2d():
assert np.allclose(ms_output1.asnumpy(), torch_output1.numpy())
assert np.allclose(ms_output2.asnumpy(), torch_output2.numpy())

def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None,
average_attn_weights=False):
""" Numpy-based reference implementation of scaled dot attention
for testing"""

QKT = _batchmatmul(
Q,
np.transpose(K, axes=[0, 1, 3, 2])
/ np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head)
)
b1, b2, s1, s2 = QKT.shape
if unseen_mask is not None or key_padding_mask is not None:
# assert s1 == s2
for i in range(b1):
for j in range(b2):
for m in range(s1):
for n in range(s2):
if unseen_mask is not None and unseen_mask[m][n] == 0:
QKT[i, j, m, n] = -np.inf
if key_padding_mask is not None and key_padding_mask[i][n]:
QKT[i, j, m, n] = -np.inf

reference = _softmax(QKT)
ref_attn_weight = reference
if average_attn_weights:
ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2
reference = _batchmatmul(reference, V)
return reference, ref_attn_weight

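Aside (not part of the diff): the helper above is a brute-force form of standard scaled dot-product attention, softmax(Q K^T / sqrt(d_head)) V, with masked positions set to -inf before the softmax. A compact single-head numpy equivalent, for reference:

import numpy as np

def scaled_dot_attention(Q, K, V):
    # Q: (L, d_head), K and V: (S, d_head)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])                # (L, S)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights                            # (L, d_head), (L, S)
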
def _batchmatmul(a, b): # batchmatmul over 4 dim matrix
""" Numpy-based batch matrix multiply over 4 dim matrix"""
assert a.shape[0] == b.shape[0]
assert a.shape[1] == b.shape[1]
retval = np.zeros(
(a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32
)
for i in range(a.shape[0]):
for j in range(a.shape[1]):
retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :])
return retval

def _softmax(x): # softmax over 4 dim matrix
""" Numpy-based reference softmax over 4 dim matrix"""
np.seterr(invalid='ignore')
output = np.zeros(x.shape, dtype=np.float64)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
for k in range(x.shape[2]):
x_curr = x[i, j, k, :]
e_x = np.exp(x_curr - np.amax(x_curr))
output[i, j, k, :] = e_x / np.sum(e_x)
return output

def _split_heads_ref(X, dims, nheads, d_head):
X_split = np.reshape(X, dims[:2] + [nheads, d_head])
X_split_transposed = np.transpose(X_split, [0, 2, 1, 3])
reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head])
return reference

def _combine_heads_ref(X, dims, nheads, d_head):
X_transposed = np.transpose(X, [0, 2, 1, 3])
reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head])
return reference

def _fc(X, X_weight, X_bias):
X_fc_b = X_bias.detach().numpy()
X_fc_w = X_weight.detach().numpy()
return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b

def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False,
saved_kv=False, same_embed_dim=False,
average_attn_weights=False):
for _ in range(5):
batch_sz, seq_len = [np.random.randint(2, 10) for r in range(2)]
d_head = np.random.randint(3, 10)
nheads = np.random.randint(2, 5) * 2
d_model = d_head * nheads
if same_embed_dim:
kv_dim = d_model
else:
kv_dim = np.random.randint(5, 20)
dims = [batch_sz, seq_len, kv_dim]

saved_k = None
saved_k_tensor = None
saved_v = None
saved_v_tensor = None
if saved_kv:
saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head)
saved_k_tensor = ms_torch.from_numpy(saved_k).to(ms_torch.float32)
saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head)
saved_v_tensor = ms_torch.from_numpy(saved_v).to(ms_torch.float32)

key_padding_mask = None
key_padding_mask_tensor = None
if add_key_padding_mask:
seq_mask = np.random.randint(0, 2, (1, seq_len))
key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
key_padding_mask_tensor = ms_torch.from_numpy(key_padding_mask)
decoder_state = np.random.rand(batch_sz, d_model)
K = np.random.rand(*dims)
V = K
Q = np.expand_dims(decoder_state, 1)
attn_mask = np.random.randint(0, 2, size=(1, seq_len))
attn_mask_tensor = ms_torch.from_numpy(attn_mask).float()
attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
attn_mask_tensor = attn_mask_tensor.double()

decoder_state_tensor = ms_torch.from_numpy(decoder_state).to(ms_torch.float32)
source_hid_tensor = ms_torch.from_numpy(K).to(ms_torch.float32).transpose(0, 1)

multihead_attn_module = nn.MultiheadAttention(d_model, nheads,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
kdim=kv_dim, vdim=kv_dim)
if add_bias_kv:
bias_k = multihead_attn_module.bias_k.detach().numpy()
bias_v = multihead_attn_module.bias_v.detach().numpy()
else:
bias_k = None
bias_v = None

_Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1)
_V = source_hid_tensor
_K = source_hid_tensor

if multihead_attn_module._qkv_same_embed_dim:
result, result_weight = nn.multi_head_attention_forward(
_Q, _K, _V,
d_model, nheads,
multihead_attn_module.in_proj_weight, multihead_attn_module.in_proj_bias,
multihead_attn_module.bias_k, multihead_attn_module.bias_v,
multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
static_k=saved_k_tensor, static_v=saved_v_tensor,
average_attn_weights=average_attn_weights,
)
else:
result, result_weight = nn.multi_head_attention_forward(
_Q, _K, _V,
d_model, nheads,
None, multihead_attn_module.in_proj_bias,
multihead_attn_module.bias_k, multihead_attn_module.bias_v,
multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
True, multihead_attn_module.q_proj_weight,
multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight,
static_k=saved_k_tensor, static_v=saved_v_tensor,
average_attn_weights=average_attn_weights,
)

result = result.squeeze(0).detach().numpy()

if multihead_attn_module._qkv_same_embed_dim:
q_proj_weight = multihead_attn_module.in_proj_weight[:d_model]
k_proj_weight = multihead_attn_module.in_proj_weight[d_model:(d_model * 2)]
v_proj_weight = multihead_attn_module.in_proj_weight[(d_model * 2):]
else:
q_proj_weight = multihead_attn_module.q_proj_weight
k_proj_weight = multihead_attn_module.k_proj_weight
v_proj_weight = multihead_attn_module.v_proj_weight

Q_fc = _fc(Q, q_proj_weight, multihead_attn_module.in_proj_bias[:d_model])
K_fc = _fc(K, k_proj_weight, multihead_attn_module.in_proj_bias[d_model:(d_model * 2)])
V_fc = _fc(V, v_proj_weight, multihead_attn_module.in_proj_bias[(d_model * 2):])

if add_bias_kv:
K_fc = np.concatenate((K_fc, np.repeat(bias_k, K_fc.shape[0], axis=0)), axis=1)
V_fc = np.concatenate((V_fc, np.repeat(bias_v, V_fc.shape[0], axis=0)), axis=1)
if attn_mask is not None:
attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)
if key_padding_mask is not None:
key_padding_mask = np.concatenate(
(key_padding_mask, np.full((batch_sz, 1), False, dtype=bool)), axis=1)
dims[1] += 1
Q_split = _split_heads_ref(
Q_fc, [batch_sz, 1, d_model], nheads, d_head
)

if saved_k is not None:
K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head])
else:
K_split = _split_heads_ref(K_fc, dims, nheads, d_head)

if saved_v is not None:
V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
else:
V_split = _split_heads_ref(V_fc, dims, nheads, d_head)

if add_zero_attn:
dims[1] += 1
K_split = np.concatenate(
(K_split, np.zeros([K_split.shape[0], K_split.shape[1], 1, K_split.shape[3]])), axis=2)
V_split = np.concatenate(
(V_split, np.zeros([V_split.shape[0], V_split.shape[1], 1, V_split.shape[3]])), axis=2)

if attn_mask is not None:
attn_mask = np.concatenate((attn_mask, np.ones([1, 1])), axis=1)

if key_padding_mask is not None:
key_padding_mask = np.concatenate(
(key_padding_mask, np.full((batch_sz, 1), False, dtype=bool)), axis=1)
# TODO: average_attn_weights set to both True and False
attn_heads, ref_attn_weight = _scaled_dot_attn_ref(
Q=Q_split,
K=K_split,
V=V_split,
dims=Q_split.shape,
unseen_mask=attn_mask,
key_padding_mask=key_padding_mask
)
combined_attn_heads = _combine_heads_ref(
X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
)

reference = _fc(combined_attn_heads, multihead_attn_module.out_proj.weight,
multihead_attn_module.out_proj.bias)
reference = np.squeeze(reference, axis=1)

# result = reference
# TODO: check if it's the same as self.assertEqual(tuple(result.shape), (batch_sz, d_model))
assert tuple(result.shape) == (batch_sz, d_model)
print("*********************** result ************************")
print(result)
print("*********************** reference ************************")
print(reference)
np.testing.assert_allclose(result, reference, atol=1e-5)

# result_weight = ref_attn_weight
result_weight = result_weight.detach().numpy()
assert tuple(result_weight.shape) == tuple(ref_attn_weight.shape)
np.testing.assert_allclose(result_weight, ref_attn_weight, atol=1e-5)

def test_multihead_attn_add_bias_kv():
_multihead_attn_test_helper(add_bias_kv=True)
# TODO: average_attn_weights set to both True and False

def test_multihead_attn_add_zero_attn():
_multihead_attn_test_helper(add_zero_attn=True)

def test_multihead_attn_no_masking():
_multihead_attn_test_helper()

def test_multihead_attn_key_padding_mask():
_multihead_attn_test_helper(add_key_padding_mask=False)

def test_multihead_attn_saved_kv():
_multihead_attn_test_helper(saved_kv=True)

def test_multihead_attn_add_bias_kv_zero_attn():
_multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=True,
add_zero_attn=True)

def test_multihead_attn_all_arguments1():
_multihead_attn_test_helper(add_key_padding_mask=False, add_zero_attn=True, saved_kv=True)

def test_multihead_attn_all_arguments2():
# expected to raise error: The bias_k cannot be added to static_k
with pytest.raises(ValueError):
_multihead_attn_test_helper(add_key_padding_mask=True, add_bias_kv=True,
add_zero_attn=True, saved_kv=True)

def test_multihead_attn_all_arguments3():
_multihead_attn_test_helper(add_key_padding_mask=False, add_zero_attn=True,
saved_kv=True, same_embed_dim=True)

def test_multihead_attn_no_bias():
embed_dim = 8
num_heads = 4
mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False)

# Verify that bias=False applies to both in and out projection layers.
assert mha.in_proj_bias is None
assert mha.out_proj.bias is None

if __name__ == '__main__':
test_relu1()
@@ -648,9 +928,20 @@ if __name__ == '__main__':
test_softsign()
test_glu()
test_hardshrink()
#test_multi_head_attention1()
#test_multi_head_attention2()
test_multi_head_attention1()
test_multi_head_attention2()
test_prelu()
test_softplus()
test_softmax2d()
test_prelu_grad()
test_prelu_grad()
test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn
test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv
test_multihead_attn_no_masking() # Test MultiheadAttention without masking
# TODO: set add_key_padding_mask to True after the ms bug is fixed
test_multihead_attn_key_padding_mask() # Test MultiheadAttention with src lengths
test_multihead_attn_saved_kv() # Test MultiheadAttention with static kv.
test_multihead_attn_add_bias_kv_zero_attn() # Test MultiheadAttention with bias_kv and zero_attn.
test_multihead_attn_all_arguments1() # Test MultiheadAttention with all the arguments.
test_multihead_attn_all_arguments2() # Test MultiheadAttention with all the arguments.
test_multihead_attn_all_arguments3() # Test MultiheadAttention with all the arguments.
test_multihead_attn_no_bias()
