# -*- coding: utf-8 -*-
'''
# @Time: 2023/3/25 10:38
# @Author: DFTL
# @File: Cross_Attention.py
'''
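
# This module implements three cross-attention variants:
#   * Cross_Attention      -- single-head cross-attention in which event features query image features
#   * Cross_Attention_v2   -- the same attention plus a residual connection back to the image features
#   * Cross_MultiAttention -- multi-head cross-attention between a 2-D feature map and a token sequence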

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class Cross_Attention(nn.Module):
    def __init__(self, in_dim, out_dim, num_class=12):
        # num_class is unused here; the parameter is kept for API parity with Cross_Attention_v2
        super(Cross_Attention, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)

        # the attention output has out_dim features, so the projection must take out_dim as input
        self.linear = nn.Linear(out_dim, out_dim, bias=False)

    def forward(self, image, event):
        # event provides the queries; image provides the keys and values
        Q = self.query(event)
        K = self.key(image)
        V = self.value(image)

        # softmax(Q @ K^T / sqrt(d)) @ V
        att_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.out_dim ** 0.5)
        attn_weights = F.softmax(att_weights, dim=-1)

        output = attn_weights @ V

        # L2-normalize the feature (last) dimension before the output projection
        output = self.linear(F.normalize(output, dim=-1))

        return output
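
# Shape trace for Cross_Attention with 2-D inputs (a sketch; assumes image and
# event are both [N, in_dim], as in the demo at the bottom of this file):
#   Q, K, V           : [N, out_dim]
#   Q @ K^T / sqrt(d) : [N, N]  -> softmax over the last dim
#   attn @ V          : [N, out_dim]
# With batched 3-D inputs [batch, seq_len, in_dim] the same code yields
# [batch, seq_len, out_dim], since matmul and softmax act on the last two dims.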


class Cross_Attention_v2(nn.Module):  # cross-attention with a residual connection
    def __init__(self, in_dim, out_dim, num_class=12):
        super(Cross_Attention_v2, self).__init__()

        # the residual adds the raw image features to the attention output,
        # so the input and output dimensions must match
        assert in_dim == out_dim, "in_dim must equal out_dim for the residual connection"

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)

        self.linear = nn.Linear(in_dim, num_class, bias=False)

    def forward(self, image, event):
        # event provides the queries; image provides the keys and values
        Q = self.query(event)
        K = self.key(image)
        V = self.value(image)

        # softmax(Q @ K^T / sqrt(d)) @ V
        att_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.out_dim ** 0.5)
        attn_weights = F.softmax(att_weights, dim=-1)

        # residual connection back to the raw image features
        output = attn_weights @ V + image

        output = self.linear(F.normalize(output, dim=-1))

        return output


class Cross_MultiAttention(nn.Module):
    def __init__(self, in_channels, emb_dim, num_heads, att_dropout=0.0, dropout=0.0):
        # dropout is currently unused; att_dropout is applied to the attention weights
        super(Cross_MultiAttention, self).__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads

        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.depth = emb_dim // num_heads
        # standard multi-head scaling uses the per-head dimension, not emb_dim
        self.scale = self.depth ** -0.5

        self.att_dropout = nn.Dropout(att_dropout)

        self.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)

        self.Wq = nn.Linear(emb_dim, emb_dim)
        self.Wk = nn.Linear(emb_dim, emb_dim)
        self.Wv = nn.Linear(emb_dim, emb_dim)

        self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context, pad_mask=None):
        '''
        :param x: [batch_size, c, h, w]
        :param context: [batch_size, seq_len, emb_dim]
        :param pad_mask: [batch_size, h*w, seq_len], True at positions to mask out
        :return:
        '''
        batch_size, c, h, w = x.shape

        x = self.proj_in(x)                       # [batch_size, emb_dim, h, w]
        x = rearrange(x, 'b c h w -> b (h w) c')  # [batch_size, h*w, emb_dim], e.g. [3, 262144, 512] for a 512x512 map

        Q = self.Wq(x)        # [batch_size, h*w, emb_dim]
        K = self.Wk(context)  # [batch_size, seq_len, emb_dim]
        V = self.Wv(context)

        # split emb_dim into num_heads x depth and move heads before the token axis
        Q = Q.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)  # [batch_size, num_heads, h*w, depth]
        K = K.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)  # [batch_size, num_heads, seq_len, depth]
        V = V.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # [batch_size, num_heads, h*w, seq_len]
        att_weights = torch.einsum('bnid,bnjd -> bnij', Q, K)
        att_weights = att_weights * self.scale

        if pad_mask is not None:
            # with multiple heads the mask must be expanded to 4 dims:
            # [batch_size, h*w, seq_len] -> [batch_size, num_heads, h*w, seq_len]
            pad_mask = pad_mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            att_weights = att_weights.masked_fill(pad_mask, -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        att_weights = self.att_dropout(att_weights)
        out = torch.einsum('bnij, bnjd -> bnid', att_weights, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.emb_dim)  # [batch_size, h*w, emb_dim]

        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)  # [batch_size, emb_dim, h, w]
        out = self.proj_out(out)                                # [batch_size, in_channels, h, w]

        return out, att_weights
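
# Note: on PyTorch >= 2.0 the scaled matmul, masked softmax, dropout, and weighted
# sum above could be fused into one call to F.scaled_dot_product_attention(Q, K, V,
# attn_mask=~pad_mask, dropout_p=...); SDPA's boolean mask keeps True positions,
# the inverse of the masked_fill convention used here. The manual einsum version
# is kept so that every intermediate shape stays visible.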


if __name__ == "__main__":
    # toy 2-D inputs: 8 feature vectors of dimension 64 (no batch axis)
    image = torch.randn((8, 64))
    event = torch.randn((8, 64))

    cross_att = Cross_Attention(64, 64)
    out = cross_att(image, event)
    print(out.shape)  # torch.Size([8, 64])
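
    # Smoke tests for the other two modules; the shapes below are illustrative
    # assumptions, not values taken from the original experiments.
    cross_att_v2 = Cross_Attention_v2(64, 64, num_class=12)
    logits = cross_att_v2(image, event)
    print(logits.shape)  # torch.Size([8, 12])

    feature_map = torch.randn((2, 32, 16, 16))  # [batch_size, in_channels, h, w]
    context = torch.randn((2, 5, 128))          # [batch_size, seq_len, emb_dim]
    cross_mha = Cross_MultiAttention(in_channels=32, emb_dim=128, num_heads=8)
    out2, att = cross_mha(feature_map, context)
    print(out2.shape, att.shape)  # torch.Size([2, 32, 16, 16]) torch.Size([2, 8, 256, 5])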