import torch
import torch.nn as nn
from torch import Tensor


class AttentionBlock(nn.Module):
    """The attention block of `UNet`.

    Applies CBAM-style channel and spatial attention: the channel branch
    pools over the last dimension and projects each pooled value back to
    ``d_embed``, so the block assumes the input's last dimension has size
    ``d_embed``. The spatial branch pools over channels and fuses the
    avg/max maps with a 1x1 convolution.
    """

    def __init__(
        self,
        *,
        d_embed: int,
    ) -> None:
        super().__init__()

        # Pool the last dimension of the input down to a single value.
        self.max = nn.AdaptiveMaxPool1d(1)
        self.avg = nn.AdaptiveAvgPool1d(1)

        # Shared projection that expands each pooled scalar back to d_embed.
        self.mlp = nn.Linear(1, d_embed)
        self.sigmoid = nn.Sigmoid()

        # Fuses the stacked avg/max maps into a single spatial attention map.
        self.spatial_conv = nn.Sequential(
            nn.Conv1d(2, 1, 1),
            nn.InstanceNorm1d(1),
            nn.PReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Channel attention: pool the sequence dimension to a scalar per
        # channel, expand it back to d_embed with the shared linear layer,
        # and gate with a sigmoid.
        x_channel_max = self.mlp(self.max(x))
        x_channel_avg = self.mlp(self.avg(x))
        x_channel = self.sigmoid(x_channel_max + x_channel_avg)

        # Spatial attention: pool over the channel dimension instead, then
        # stack the avg/max maps as two channels.
        x_spatial = torch.transpose(x, 1, 2)
        x_spatial_max = self.max(x_spatial)
        x_spatial_avg = self.avg(x_spatial)
        x_spatial = torch.cat([x_spatial_avg, x_spatial_max], dim=2)

        x_spatial = torch.transpose(x_spatial, 1, 2)
        x_spatial = self.spatial_conv(x_spatial)
        x_spatial = self.sigmoid(x_spatial)

        # Rescale the input by both attention maps; the spatial map
        # broadcasts across channels.
        output = x * x_spatial * x_channel
        return output
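

# A minimal usage sketch, assuming the block is applied after a convolution
# inside a UNet-style stage with a residual connection. `ResidualAttentionStage`
# and its parameters are illustrative assumptions, not part of this codebase.
class ResidualAttentionStage(nn.Module):
    """Conv1d -> AttentionBlock with a residual connection (illustrative)."""

    def __init__(self, *, channels: int, d_embed: int) -> None:
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.attention = AttentionBlock(d_embed=d_embed)

    def forward(self, x: Tensor) -> Tensor:
        # Both the convolution and the attention block preserve the input
        # shape, so the residual addition is well-defined.
        return x + self.attention(self.conv(x))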


if __name__ == "__main__":
    # Smoke test: run a random batch through the block.
    n = 32
    d_in = 18
    d_embed = 48
    model = AttentionBlock(
        d_embed=d_embed,
    )
    x = torch.randn(n, d_in, d_embed)
    print(model(x).shape)  # torch.Size([32, 18, 48])
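    # The block is shape-preserving, so the output should match the input;
    # the illustrative stage above should preserve the shape as well.
    assert model(x).shape == x.shape
    stage = ResidualAttentionStage(channels=d_in, d_embed=d_embed)
    assert stage(x).shape == x.shape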