#419 nn.MultiheadAttention

Merged
zoulq merged 10 commits from lzh_multihead into master 1 year ago
lzh commented 1 year ago
lzh changed title from [WIP]nn.MultiheadAttention to nn.MultiheadAttention 1 year ago
zoulq reviewed 1 year ago
msadapter/pytorch/nn/modules/activation.py
@@ -452,2 +424,2 @@
if attn_mask:
_attn_mask = self._process_mask(attn_mask, _batch_size)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
zoulq commented 1 year ago
Use raise instead of assert.
lzh commented 1 year ago
Fixed.
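A minimal sketch of the requested change (the exact error message in the merged code may differ); note also that the explicit `is not None` check is the usual torch idiom for the mask, since truthiness of a multi-element tensor is ambiguous:

```python
if attn_mask is not None:  # explicit None check; tensor truthiness is ambiguous
    _attn_mask = self._process_mask(attn_mask, _batch_size)
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
    raise ValueError("embed_dim must be divisible by num_heads")
```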
zoulq reviewed 1 year ago
msadapter/pytorch/nn/modules/activation.py
@@ -454,0 +425,4 @@
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

if self._qkv_same_embed_dim is False:
self.q_proj_weight = Parameter(ms.ops.zeros((embed_dim, embed_dim), dtype=dtype))
zoulq commented 1 year ago
Why is the ms.ops interface used here?
lzh commented 1 year ago
Fixed.
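Presumably the fix builds these weights through the adapter-level constructor rather than calling MindSpore directly; a sketch under that assumption, following the layout of torch's reference implementation:

```python
# Assumption: `empty` is the same adapter-level constructor used for
# in_proj_bias below, so each Parameter stays an adapter tensor rather
# than a raw MindSpore one; kdim/vdim follow torch's reference layout.
self.q_proj_weight = Parameter(empty((embed_dim, embed_dim), dtype=dtype))
self.k_proj_weight = Parameter(empty((embed_dim, self.kdim), dtype=dtype))
self.v_proj_weight = Parameter(empty((embed_dim, self.vdim), dtype=dtype))
```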
zoulq reviewed 1 year ago
testing/ut/pytorch/nn/test_activation.py
@@ -511,3 +511,3 @@

#TODO: multiheadattention need reconstruct
'''
zoulq commented 1 year ago
For functional completeness testing, you can directly reference test cases such as test_multihead_attn_add_zero_attn in https://github.com/pytorch/pytorch/blob/master/test/test_nn.py.
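For reference, a shape-level sketch in the spirit of the upstream test the reviewer points to, written against the torch-style API that MSAdapter mirrors (shapes here are illustrative, not taken from the upstream test):

```python
import torch
import torch.nn as nn

def test_multihead_attn_add_zero_attn():
    embed_dim, num_heads = 8, 2
    mha = nn.MultiheadAttention(embed_dim, num_heads, add_zero_attn=True)
    query = torch.rand(3, 1, embed_dim)        # (tgt_len, batch, embed_dim)
    key = value = torch.rand(5, 1, embed_dim)  # (src_len, batch, embed_dim)
    out, weights = mha(query, key, value)
    assert out.shape == (3, 1, embed_dim)
    # add_zero_attn appends one all-zero key/value slot, so the averaged
    # attention weights span src_len + 1 positions
    assert weights.shape == (1, 3, 5 + 1)
```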
lzh reviewed 1 year ago
@@ -510,8 +511,9 @@ def test_hardsigmoid():
assert ms_out.asnumpy().dtype == torch_out.numpy().dtype

#TODO: multiheadattention need reconstruct
lzh commented 1 year ago
Added the test cases from test/nn/test_multihead_attention.py.
zoulq reviewed 1 year ago
msadapter/pytorch/nn/modules/activation.py
@@ -514,0 +441,4 @@
self.in_proj_bias = Parameter(empty(3 * embed_dim, dtype=dtype))
else:
self.in_proj_bias = None
# TODO: NonDynamicallyQuantizableLinear
zoulq commented 1 year ago
What functionality is this TODO for?
lzh commented 1 year ago
![image](/attachments/ebaa1592-b05f-4290-8b0f-7a93eef9b80a)
lzh commented 1 year ago
NonDynamicallyQuantizableLinear was introduced in torch to work around an uncommon error; MindSpore has never implemented it, so a plain Linear is used as a direct replacement.
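In torch, NonDynamicallyQuantizableLinear is a thin subclass of Linear whose only purpose is to keep the output projection out of dynamic quantization, so the substitution preserves behavior. A sketch of what the adapter side presumably looks like (the import path is an assumption):

```python
from msadapter.pytorch.nn import Linear  # assumed adapter import path

# torch wraps out_proj in NonDynamicallyQuantizableLinear only to dodge a
# rare dynamic-quantization scripting error; MSAdapter has no quantization
# path, so a plain Linear computes the same affine map
self.out_proj = Linear(embed_dim, embed_dim, bias=bias, dtype=dtype)
```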
zoulq reviewed 1 year ago
msadapter/pytorch/nn/modules/activation.py
@@ -525,0 +521,4 @@
attn_output = attn_output.swapaxes(1, 0)
if need_weights:
return attn_output, attn_output_weights
return (attn_output,)
zoulq commented 1 year ago
If the multi_head_attention_forward called above is a MindSpore interface, then the attn_output here would be a MindSpore tensor, right?
lzh commented 1 year ago
multi_head_attention_forward is already in the supportedList: ![image](/attachments/982f6f91-d6f9-4a5c-ad5e-7964510485e9)
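For context, the quoted tail of forward follows the torch contract: swapaxes(1, 0) converts the (tgt_len, batch, embed_dim) result back to batch-first layout, and the weights are returned only when requested. A hypothetical call site, assuming the adapter mirrors torch's constructor arguments:

```python
import msadapter.pytorch as torch  # assumption: the adapter mirrors torch's API
from msadapter.pytorch.nn import MultiheadAttention

mha = MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
q = k = v = torch.rand(1, 3, 8)  # (batch, seq_len, embed_dim)

out, weights = mha(q, k, v, need_weights=True)  # weights averaged over heads
(out,) = mha(q, k, v, need_weights=False)       # 1-tuple, matching the diff above
```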
zoulq merged commit aeae9881d5 into master 1 year ago
lzh deleted branch lzh_multihead 1 year ago
The pull request has been merged as aeae9881d5.