|
- from typing import Type
-
- import torch.nn as nn
- import torch.optim
- from torch import Tensor
-
- from src.models.modules import ModuleType, _make_nn_module
-
-
class UNet(nn.Module):
    """1D U-Net over per-feature embeddings of a flat input vector.

    An input of shape ``(n, d_in)`` is embedded to ``(n, d_in, d_embed)``
    (channels = ``d_in``, sequence length = ``d_embed``), run through a
    four-stage convolutional encoder and a three-stage decoder with skip
    concatenations, then projected back to ``(n, d_out)``.

    Note:
        `make_baseline` is the recommended constructor.
    """

    class EncoderBlock(nn.Module):
        """The encoder block of `UNet`: conv -> norm -> activation, with an
        optional residual (input + output) connection."""

        def __init__(
            self,
            *,
            d_in: int,
            d_out: int,
            kernel_size: int,
            normalization: ModuleType,
            activation: ModuleType,
            skip_connection: bool
        ) -> None:
            """
            Args:
                d_in: number of input channels.
                d_out: number of output channels.
                kernel_size: convolution kernel size (odd sizes keep the
                    sequence length unchanged with the padding below).
                normalization: normalization module spec, resolved by
                    `_make_nn_module` with ``d_out`` as its argument.
                activation: activation module spec, resolved by
                    `_make_nn_module`.
                skip_connection: if True, add the block input to the block
                    output. NOTE(review): this requires ``d_in == d_out``;
                    callers must only enable it when channel counts match.
            """
            super().__init__()
            # 'replicate' padding avoids zero-padding artifacts at the
            # sequence borders while preserving the length for odd kernels.
            self.conv = nn.Conv1d(
                d_in,
                d_out,
                kernel_size,
                padding=kernel_size // 2,
                padding_mode='replicate',
            )
            self.normalization = _make_nn_module(normalization, d_out)
            self.activation = _make_nn_module(activation)
            self.skip_connection = skip_connection

        def forward(self, x: Tensor) -> Tensor:
            """Apply conv -> norm -> activation (+ optional residual add)."""
            x_input = x
            x = self.conv(x)
            x = self.normalization(x)
            x = self.activation(x)
            if self.skip_connection:
                x = x_input + x
            return x

    class DecoderBlock(nn.Module):
        """The decoder block of `UNet`: concatenate the encoder skip tensor
        with the decoder tensor along channels, then transposed conv ->
        norm -> activation."""

        def __init__(
            self,
            *,
            d_in: int,
            d_out: int,
            kernel_size: int,
            normalization: ModuleType,
            activation: ModuleType,
        ) -> None:
            """
            Args:
                d_in: number of input channels (skip channels + decoder
                    channels after concatenation).
                d_out: number of output channels.
                kernel_size: transposed-convolution kernel size.
                normalization: normalization module spec, resolved by
                    `_make_nn_module` with ``d_out`` as its argument.
                activation: activation module spec, resolved by
                    `_make_nn_module`.
            """
            super().__init__()
            # stride is 1, so with this padding the sequence length is
            # preserved for odd kernel sizes.
            self.conv = nn.ConvTranspose1d(
                d_in, d_out, kernel_size, padding=kernel_size // 2
            )
            self.normalization = _make_nn_module(normalization, d_out)
            self.activation = _make_nn_module(activation)

        def forward(self, x_enc: Tensor, x_dec: Tensor) -> Tensor:
            """Fuse the encoder skip tensor with the decoder tensor.

            Args:
                x_enc: skip tensor from the matching encoder stage.
                x_dec: tensor from the previous decoder stage.
            """
            x = torch.cat([x_enc, x_dec], dim=1)
            x = self.conv(x)
            x = self.normalization(x)
            x = self.activation(x)
            return x

    def __init__(
        self,
        *,
        d_in: int,
        d_embed: int,
        d_out: int,
        n_layer: int,
        d_encode: int,
        normalization: ModuleType,
        activation: ModuleType,
        skip_connection: bool = False
    ) -> None:
        """
        Args:
            d_in: size of the flat input vector (becomes the channel count).
            d_embed: embedding width per input feature (sequence length).
            d_out: size of the output vector.
            n_layer: stored for bookkeeping; the architecture below is
                hard-coded to 4 encoder / 3 decoder stages (see note).
            d_encode: accepted for interface compatibility; unused (see note).
            normalization: normalization module spec for all blocks.
            activation: activation module spec for all blocks.
            skip_connection: passed to every encoder block. Defaults to
                False, which is required here because the encoder stages all
                change their channel count (a residual add would fail).

        Note:
            `make_baseline` is the recommended constructor.
        """
        super().__init__()
        self.embed_layer = nn.Linear(d_in, d_in * d_embed)
        self.d_in = d_in
        self.d_out = d_out
        self.d_embed = d_embed
        self.n_layer = n_layer

        # Fixed channel plan. The original code first derived a dims list
        # from `n_layer`/`d_encode` and then overwrote it with this constant
        # list, so the derivation was dead code and has been removed; the
        # parameters are kept so existing callers keep working.
        dims = [d_in, 64, 128, 256, 512, 1024]
        self.enc1 = UNet.EncoderBlock(d_in=dims[0], d_out=dims[1],
                                      kernel_size=7, normalization=normalization,
                                      activation=activation, skip_connection=skip_connection)
        self.enc2 = UNet.EncoderBlock(d_in=dims[1], d_out=dims[2],
                                      kernel_size=5, normalization=normalization,
                                      activation=activation, skip_connection=skip_connection)
        self.enc3 = UNet.EncoderBlock(d_in=dims[2], d_out=dims[3],
                                      kernel_size=5, normalization=normalization,
                                      activation=activation, skip_connection=skip_connection)
        self.enc4 = UNet.EncoderBlock(d_in=dims[3], d_out=dims[4],
                                      kernel_size=3, normalization=normalization,
                                      activation=activation, skip_connection=skip_connection)
        # Each decoder stage consumes the matching encoder skip tensor
        # concatenated with the previous decoder output, hence d_in is the
        # sum of the two channel counts.
        self.dec3 = UNet.DecoderBlock(d_in=dims[3] + dims[4], d_out=dims[3],
                                      kernel_size=3, normalization=normalization,
                                      activation=activation)
        self.dec2 = UNet.DecoderBlock(d_in=dims[2] + dims[3], d_out=dims[2],
                                      kernel_size=3, normalization=normalization,
                                      activation=activation)
        self.dec1 = UNet.DecoderBlock(d_in=dims[1] + dims[2], d_out=dims[1],
                                      kernel_size=3, normalization=normalization,
                                      activation=activation)

        # Per-position channel projection followed by a flat linear head.
        self.output_head = nn.Conv1d(dims[1], d_out, 3, padding=1,
                                     padding_mode='replicate')
        self.output = nn.Linear(d_out * d_embed, d_out)

    @classmethod
    def make_baseline(
        cls: Type['UNet'],
        *,
        d_in: int,
        d_embed: int,
        d_out: int,
        n_layer: int,
        d_encode: int,
    ) -> 'UNet':
        """Build a `UNet` with the baseline normalization and activation.

        Uses ``InstanceNorm1d`` and ``ReLU``; this is the recommended way to
        construct the model.
        """
        return cls(
            d_in=d_in,
            d_embed=d_embed,
            d_out=d_out,
            n_layer=n_layer,
            d_encode=d_encode,
            normalization='InstanceNorm1d',
            activation='ReLU'
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(n, d_in)`` inputs to ``(n, d_out)`` outputs.

        The flat input is embedded and reshaped to ``(n, d_in, d_embed)``,
        encoded, decoded with skip concatenations, and flattened back.
        """
        x = self.embed_layer(x)
        n, _ = x.size()
        x = torch.reshape(x, (n, self.d_in, self.d_embed))

        x_enc1 = self.enc1(x)
        x_enc2 = self.enc2(x_enc1)
        x_enc3 = self.enc3(x_enc2)
        x_enc4 = self.enc4(x_enc3)

        x_dec = self.dec3(x_enc3, x_enc4)
        x_dec = self.dec2(x_enc2, x_dec)
        x_dec = self.dec1(x_enc1, x_dec)
        x = self.output_head(x_dec)

        x = torch.reshape(x, (n, self.d_out * self.d_embed))
        x = self.output(x)
        return x
-
-
if __name__ == "__main__":
    # Smoke test: build a small UNet and push one random batch through it.
    n = 32
    d_in = 18
    d_embed = 48
    d_out = 2
    n_layer = 4
    d_encode = 32
    activation = 'ReLU'
    normalization = 'InstanceNorm1d'
    model = UNet(
        d_in=d_in,
        d_embed=d_embed,
        d_out=d_out,
        n_layer=n_layer,
        d_encode=d_encode,
        normalization=normalization,
        activation=activation
    )
    # Use the config variables rather than re-hardcoding 32 and 18, so the
    # smoke test stays consistent if the sizes above are changed.
    x = torch.randn(n, d_in)
    print(model(x).shape)  # expected: torch.Size([n, d_out])
|