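"""A drop-in variant of ``torch.nn.TransformerEncoderLayer`` in which the
position-wise feed-forward sub-layer is replaced by a bidirectional LSTM
followed by a linear projection back to the model dimension."""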
import copy
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.container import ModuleList
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTM
from torch.nn.modules.normalization import LayerNorm


class TransformerEncoderLayer(Module):
    """Encoder layer made up of self-attention and a BiLSTM-based feed-forward block."""

    def __init__(self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"):
        # NOTE: dim_feedforward is kept for signature compatibility with the
        # stock TransformerEncoderLayer but is unused here, since the
        # feed-forward network is replaced by the BiLSTM below.
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Improved part: a single-layer bidirectional LSTM replaces the usual
        # position-wise feed-forward network.
        self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
        self.dropout = Dropout(dropout)
        # The BiLSTM concatenates both directions, hence hidden_size * 2.
        self.linear = Linear(hidden_size * 2, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Default to ReLU when unpickling checkpoints saved without an activation.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # Self-attention sub-layer with residual connection and layer norm.
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # BiLSTM sub-layer: LSTM output -> activation -> dropout -> projection
        # back to d_model, again with residual connection and layer norm.
        src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0])))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


def _get_clones(module, N):
    """Return a ModuleList of N independent deep copies of ``module``."""
    return ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
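

# A minimal smoke-test sketch of how the layer might be used; the hyperparameter
# values below are illustrative assumptions, not from the original code. Both
# MultiheadAttention and LSTM default to sequence-first input, so src is
# shaped (seq_len, batch, d_model).
if __name__ == "__main__":
    layer = TransformerEncoderLayer(d_model=64, nhead=4, hidden_size=32,
                                    dim_feedforward=256, dropout=0.1)
    src = torch.randn(10, 2, 64)  # (seq_len=10, batch=2, d_model=64)
    out = layer(src)
    assert out.shape == src.shape  # residual structure preserves the shape

    # Stacking several independent copies, as _get_clones is meant to enable.
    for enc_layer in _get_clones(layer, 3):
        out = enc_layer(out)
    print(out.shape)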