From cf251a87b98baaba0422c226ee3cfb7989eaa0b7 Mon Sep 17 00:00:00 2001 From: yangge Date: Thu, 23 Mar 2023 22:21:23 +0800 Subject: [PATCH] conformer premature code without training adoption of the MSA --- official/nlp/conformer/conformer-msa/LICENSE | 201 ++++++++++ official/nlp/conformer/conformer-msa/model.py | 372 ++++++++++++++++++ official/nlp/conformer/conformer-msa/train.py | 276 +++++++++++++ official/nlp/conformer/conformer-msa/utils.py | 223 +++++++++++ .../nlp/conformer/conformer-pytorch/LICENSE | 201 ++++++++++ .../nlp/conformer/conformer-pytorch/README.md | 32 ++ .../nlp/conformer/conformer-pytorch/model.py | 368 +++++++++++++++++ .../nlp/conformer/conformer-pytorch/train.py | 252 ++++++++++++ .../nlp/conformer/conformer-pytorch/utils.py | 221 +++++++++++ 9 files changed, 2146 insertions(+) create mode 100644 official/nlp/conformer/conformer-msa/LICENSE create mode 100644 official/nlp/conformer/conformer-msa/model.py create mode 100644 official/nlp/conformer/conformer-msa/train.py create mode 100644 official/nlp/conformer/conformer-msa/utils.py create mode 100644 official/nlp/conformer/conformer-pytorch/LICENSE create mode 100644 official/nlp/conformer/conformer-pytorch/README.md create mode 100644 official/nlp/conformer/conformer-pytorch/model.py create mode 100644 official/nlp/conformer/conformer-pytorch/train.py create mode 100644 official/nlp/conformer/conformer-pytorch/utils.py diff --git a/official/nlp/conformer/conformer-msa/LICENSE b/official/nlp/conformer/conformer-msa/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/official/nlp/conformer/conformer-msa/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/official/nlp/conformer/conformer-msa/model.py b/official/nlp/conformer/conformer-msa/model.py new file mode 100644 index 0000000..bdf519a --- /dev/null +++ b/official/nlp/conformer/conformer-msa/model.py @@ -0,0 +1,372 @@ +import math +# import torch +# from torch import nn +# import torch.nn.functional as F +import mindspore as ms +import ms_adapter.pytorch as torch +import ms_adapter.pytorch.nn as nn +import ms_adapter.pytorch.nn.functional as F + +class PositionalEncoder(nn.Module): + ''' + Generate positional encodings used in the relative multi-head attention module. 
+ These encodings are the same as the original transformer model: https://arxiv.org/abs/1706.03762 + + Parameters: + max_len (int): Maximum sequence length (time dimension) + + Inputs: + len (int): Length of encodings to retrieve + + Outputs + Tensor (len, d_model): Positional encodings + ''' + def __init__(self, d_model, max_len=10000): + super(PositionalEncoder, self).__init__() + self.d_model = d_model + encodings = torch.zeros(max_len, d_model) + pos = torch.arange(0, max_len, dtype=torch.float32) + inv_freq = 1 / (10000 ** (torch.arange(0.0, d_model, 2.0) / d_model)) + encodings[:, 0::2] = torch.sin(pos[:, None] * inv_freq) + encodings[:, 1::2] = torch.cos(pos[:, None] * inv_freq) + self.register_buffer('encodings', encodings) + + def forward(self, len): + return self.encodings[:len, :] + +class RelativeMultiHeadAttention(nn.Module): + ''' + Relative Multi-Head Self-Attention Module. + Method proposed in Transformer-XL paper: https://arxiv.org/abs/1901.02860 + + Parameters: + d_model (int): Dimension of the model + num_heads (int): Number of heads to split inputs into + dropout (float): Dropout probability + positional_encoder (nn.Module): PositionalEncoder module + + Inputs: + x (Tensor): (batch_size, time, d_model) + mask (Tensor): (batch_size, time, time) Optional mask to zero out attention score at certain indices + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the attention module. + + ''' + def __init__(self, d_model=144, num_heads=4, dropout=0.1, positional_encoder=PositionalEncoder(144)): + super(RelativeMultiHeadAttention, self).__init__() + + #dimensions + assert d_model % num_heads == 0 + self.d_model = d_model + self.d_head = d_model // num_heads + self.num_heads = num_heads + + # Linear projection weights + self.W_q = nn.Linear(d_model, d_model) + self.W_k = nn.Linear(d_model, d_model) + self.W_v = nn.Linear(d_model, d_model) + self.W_pos = nn.Linear(d_model, d_model, bias=False) + self.W_out = nn.Linear(d_model, d_model) + + # Trainable bias parameters + self.u = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + self.v = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + torch.nn.init.xavier_uniform_(self.u) + torch.nn.init.xavier_uniform_(self.v) + + # etc + self.layer_norm = nn.LayerNorm(d_model, eps=6.1e-5) + self.positional_encoder = positional_encoder + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None): + batch_size, seq_length, _ = x.size() + + #layer norm and pos embeddings + x = self.layer_norm(x) + pos_emb = self.positional_encoder(seq_length) + pos_emb = pos_emb.repeat(batch_size, 1, 1) + + #Linear projections, split into heads + q = self.W_q(x).view(batch_size, seq_length, self.num_heads, self.d_head) + k = self.W_k(x).view(batch_size, seq_length, self.num_heads, self.d_head).permute(0, 2, 3, 1) # (batch_size, num_heads, d_head, time) + v = self.W_v(x).view(batch_size, seq_length, self.num_heads, self.d_head).permute(0, 2, 3, 1) # (batch_size, num_heads, d_head, time) + pos_emb = self.W_pos(pos_emb).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 3, 1) # (batch_size, num_heads, d_head, time) + + #Compute attention scores with relative position embeddings + AC = torch.matmul((q + self.u).transpose(1, 2), k) + BD = torch.matmul((q + self.v).transpose(1, 2), pos_emb) + BD = self.rel_shift(BD) + attn = (AC + BD) / math.sqrt(self.d_model) + + #Mask before softmax with large negative number + if mask is not None: + mask = mask.unsqueeze(1) + mask_value = -1e+30 if attn.dtype == torch.float32 
else -1e+4 + attn.masked_fill_(mask, mask_value) + + #Softmax + attn = F.softmax(attn, -1) + + #Construct outputs from values + output = torch.matmul(attn, v.transpose(2, 3)).transpose(1, 2) # (batch_size, time, num_heads, d_head) + output = output.contiguous().view(batch_size, -1, self.d_model) # (batch_size, time, d_model) + + #Output projections and dropout + output = self.W_out(output) + return self.dropout(output) + + + def rel_shift(self, emb): + ''' + Pad and shift form relative positional encodings. + Taken from Transformer-XL implementation: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py + ''' + batch_size, num_heads, seq_length1, seq_length2 = emb.size() + zeros = emb.new_zeros(batch_size, num_heads, seq_length1, 1) + padded_emb = torch.cat([zeros, emb], dim=-1) + padded_emb = padded_emb.view(batch_size, num_heads, seq_length2 + 1, seq_length1) + shifted_emb = padded_emb[:, :, 1:].view_as(emb) + return shifted_emb + + +class ConvBlock(nn.Module): + ''' + Conformer convolutional block. + + Parameters: + d_model (int): Dimension of the model + kernel_size (int): Size of kernel to use for depthwise convolution + dropout (float): Dropout probability + + Inputs: + x (Tensor): (batch_size, time, d_model) + mask: Unused + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the convolution module + + ''' + def __init__(self, d_model=144, kernel_size=31, dropout=0.1): + super(ConvBlock, self).__init__() + self.layer_norm = nn.LayerNorm(d_model, eps=6.1e-5) + kernel_size=31 + self.module = nn.Sequential( + nn.Conv1d(in_channels=d_model, out_channels=d_model * 2, kernel_size=1), # first pointwise with 2x expansion + nn.GLU(dim=1), + nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=kernel_size, padding='same', groups=d_model), # depthwise + nn.BatchNorm1d(d_model, eps=6.1e-5), + nn.SiLU(), # swish activation + nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=1), # second pointwise + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.layer_norm(x) + x = x.transpose(1, 2) # (batch_size, d_model, seq_len) + x = self.module(x) + return x.transpose(1, 2) + +class FeedForwardBlock(nn.Module): + ''' + Conformer feed-forward block. + + Parameters: + d_model (int): Dimension of the model + expansion (int): Expansion factor for first linear layer + dropout (float): Dropout probability + + Inputs: + x (Tensor): (batch_size, time, d_model) + mask: Unused + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the feed-forward module + + ''' + def __init__(self, d_model=144, expansion=4, dropout=0.1): + super(FeedForwardBlock, self).__init__() + self.module = nn.Sequential( + nn.LayerNorm(d_model, eps=6.1e-5), + nn.Linear(d_model, d_model * expansion), # expand to d_model * expansion + nn.SiLU(), # swish activation + nn.Dropout(dropout), + nn.Linear(d_model * expansion, d_model), # project back to d_model + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.module(x) + +class Conv2dSubsampling(nn.Module): + ''' + 2d Convolutional subsampling. + Subsamples time and freq domains of input spectrograms by a factor of 4, d_model times. 
+ + Parameters: + d_model (int): Dimension of the model + + Inputs: + x (Tensor): Input spectrogram (batch_size, time, d_input) + + Outputs: + Tensor (batch_size, time, d_model * (d_input // 4)): Output tensor from the conlutional subsampling module + + ''' + def __init__(self, d_model=144): + super(Conv2dSubsampling, self).__init__() + self.module = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=d_model, kernel_size=3, stride=2), + nn.ReLU(), + nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=3, stride=2), + nn.ReLU(), + ) + + def forward(self, x): + output = self.module(x.unsqueeze(1)) # (batch_size, 1, time, d_input) + batch_size, d_model, subsampled_time, subsampled_freq = output.size() + output = output.permute(0, 2, 1, 3) + output = output.contiguous().view(batch_size, subsampled_time, d_model * subsampled_freq) + return output + +class ConformerBlock(nn.Module): + ''' + Conformer Encoder Block. + + Parameters: + d_model (int): Dimension of the model + conv_kernel_size (int): Size of kernel to use for depthwise convolution + feed_forward_residual_factor (float): output_weight for feed-forward residual connections + feed_forward_expansion_factor (int): Expansion factor for feed-forward block + num_heads (int): Number of heads to use for multi-head attention + positional_encoder (nn.Module): PositionalEncoder module + dropout (float): Dropout probability + + Inputs: + x (Tensor): (batch_size, time, d_model) + mask (Tensor): (batch_size, time, time) Optional mask to zero out attention score at certain indices + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the conformer block. + + ''' + def __init__( + self, + d_model=144, + conv_kernel_size=31, + feed_forward_residual_factor=.5, + feed_forward_expansion_factor=4, + num_heads=4, + positional_encoder=PositionalEncoder(144), + dropout=0.1, + ): + super(ConformerBlock, self).__init__() + self.residual_factor = feed_forward_residual_factor + self.ff1 = FeedForwardBlock(d_model, feed_forward_expansion_factor, dropout) + self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout, positional_encoder) + self.conv_block = ConvBlock(d_model, conv_kernel_size, dropout) + self.ff2 = FeedForwardBlock(d_model, feed_forward_expansion_factor, dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=6.1e-5) + + def forward(self, x, mask=None): + x = x + (self.residual_factor * self.ff1(x)) + x = x + self.attention(x, mask=mask) + x = x + self.conv_block(x) + x = x + (self.residual_factor * self.ff2(x)) + return self.layer_norm(x) + + +class ConformerEncoder(nn.Module): + ''' + Conformer Encoder Module. 
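A note on input sizing before the parameter list: Conv2dSubsampling above applies two unpadded Conv2d layers with kernel_size=3 and stride=2, and each such conv maps a length L to (L - 1) // 2. Applied twice to the frequency axis, d_input=80 mel bins become ((80 - 1) // 2 - 1) // 2 = 19, which is why linear_proj below takes d_model * (((d_input - 1) // 2 - 1) // 2) input features; preprocess_example in utils.py computes input lengths with the same formula, and forward() slices the mask with [:, :-2:2, :-2:2] twice for the same reason. The snippet below is only a stand-alone shape check of that arithmetic, written against stock PyTorch with illustrative sizes:

```python
import torch
from torch import nn

# Two kernel-3, stride-2, unpadded convs subsample time and frequency by roughly 4x.
def conv_out_len(length: int) -> int:
    return (length - 1) // 2          # floor((L - 3) / 2) + 1 for kernel 3, stride 2

d_model, d_input, T = 144, 80, 400    # 400 frames ~= 4 s of 10 ms hops (illustrative)
subsample = nn.Sequential(
    nn.Conv2d(1, d_model, kernel_size=3, stride=2), nn.ReLU(),
    nn.Conv2d(d_model, d_model, kernel_size=3, stride=2), nn.ReLU(),
)
x = torch.randn(1, 1, T, d_input)                            # (batch, channel, time, freq)
out = subsample(x)
assert out.shape[2] == conv_out_len(conv_out_len(T))         # subsampled time = 99
assert out.shape[3] == conv_out_len(conv_out_len(d_input))   # subsampled freq = 19
print(out.shape)  # torch.Size([1, 144, 99, 19]) -> flattened to (1, 99, 144 * 19)
```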
+ + Parameters: + d_input (int): Dimension of the input + d_model (int): Dimension of the model + num_layers (int): Number of conformer blocks to use in the encoder + conv_kernel_size (int): Size of kernel to use for depthwise convolution + feed_forward_residual_factor (float): output_weight for feed-forward residual connections + feed_forward_expansion_factor (int): Expansion factor for feed-forward block + num_heads (int): Number of heads to use for multi-head attention + dropout (float): Dropout probability + + Inputs: + x (Tensor): input spectrogram of dimension (batch_size, time, d_input) + mask (Tensor): (batch_size, time, time) Optional mask to zero out attention score at certain indices + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the conformer encoder + + + ''' + def __init__( + self, + d_input=80, + d_model=144, + num_layers=16, + conv_kernel_size=31, + feed_forward_residual_factor=.5, + feed_forward_expansion_factor=4, + num_heads=4, + dropout=.1, + ): + super(ConformerEncoder, self).__init__() + self.conv_subsample = Conv2dSubsampling(d_model=d_model) + self.linear_proj = nn.Linear(d_model * (((d_input - 1) // 2 - 1) // 2), d_model) # project subsamples to d_model + self.dropout = nn.Dropout(p=dropout) + + # define global positional encoder to limit model parameters + positional_encoder = PositionalEncoder(d_model) + self.layers = nn.ModuleList([ConformerBlock( + d_model=d_model, + conv_kernel_size=conv_kernel_size, + feed_forward_residual_factor=feed_forward_residual_factor, + feed_forward_expansion_factor=feed_forward_expansion_factor, + num_heads=num_heads, + positional_encoder=positional_encoder, + dropout=dropout, + ) for _ in range(num_layers)]) + + def forward(self, x, mask=None): + x = self.conv_subsample(x) + if mask is not None: + mask = mask[:, :-2:2, :-2:2] #account for subsampling + mask = mask[:, :-2:2, :-2:2] #account for subsampling + assert mask.shape[1] == x.shape[1], f'{mask.shape} {x.shape}' + + x = self.linear_proj(x) + x = self.dropout(x) + + for layer in self.layers: + x = layer(x, mask=mask) + + return x + + +class LSTMDecoder(nn.Module): + ''' + LSTM Decoder + + Parameters: + d_encoder (int): Output dimension of the encoder + d_decoder (int): Hidden dimension of the decoder + num_layers (int): Number of LSTM layers to use in the decoder + num_classes (int): Number of output classes to predict + + Inputs: + x (Tensor): (batch_size, time, d_encoder) + + Outputs: + Tensor (batch_size, time, num_classes): Class prediction logits + + ''' + def __init__(self, d_encoder=144, d_decoder=320, num_layers=1, num_classes=29): + super(LSTMDecoder, self).__init__() + self.lstm = nn.LSTM(input_size=d_encoder, hidden_size=d_decoder, num_layers=num_layers, batch_first=True) + self.linear = nn.Linear(d_decoder, num_classes) + + def forward(self, x): + x, _ = self.lstm(x) + logits = self.linear(x) + return logits diff --git a/official/nlp/conformer/conformer-msa/train.py b/official/nlp/conformer/conformer-msa/train.py new file mode 100644 index 0000000..90cc283 --- /dev/null +++ b/official/nlp/conformer/conformer-msa/train.py @@ -0,0 +1,276 @@ +import os +import gc +import argparse +import torchaudio + +# import torch +# from torch import nn +# import torch.nn.functional as F +# from torch.utils.data import DataLoader + +import ms_adapter.pytorch.nn as nn +import ms_adapter.pytorch.nn.functional as F +from ms_adapter.pytorch.utils.data import DataLoader +import mindspore as ms + +ms.set_context(device_target='CPU') + +from torchmetrics.text.wer 
import WordErrorRate +# from torch.cuda.amp import autocast, GradScaler +from model import ConformerEncoder, LSTMDecoder +from utils import * +import ms_adapter.pytorch as torch + + +parser = argparse.ArgumentParser("conformer") +parser.add_argument('--data_dir', type=str, default='./data', help='location to download data') +parser.add_argument('--checkpoint_path', type=str, default='model_best.pt', help='path to store/load checkpoints') +parser.add_argument('--load_checkpoint', action='store_true', default=False, help='resume training from checkpoint') +parser.add_argument('--train_set', type=str, default='train-clean-100', help='train dataset') +parser.add_argument('--test_set', type=str, default='test-clean', help='test dataset') +parser.add_argument('--batch_size', type=int, default=1, help='batch size') +parser.add_argument('--warmup_steps', type=float, default=10000, help='Multiply by sqrt(d_model) to get max_lr') +parser.add_argument('--peak_lr_ratio', type=int, default=0.05, help='Number of warmup steps for LR scheduler') +parser.add_argument('--gpu', type=int, default=0, help='gpu device id (optional)') +parser.add_argument('--epochs', type=int, default=50, help='num of training epochs') +parser.add_argument('--report_freq', type=int, default=100, help='training objective report frequency') +parser.add_argument('--layers', type=int, default=8, help='total number of layers') +parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model') +parser.add_argument('--use_amp', action='store_true', default=False, help='use mixed precision to train') +parser.add_argument('--attention_heads', type=int, default=4, help='number of heads to use for multi-head attention') +parser.add_argument('--d_input', type=int, default=80, help='dimension of the input (num filter banks)') +# parser.add_argument('--d_encoder', type=int, default=144, help='dimension of the encoder') +parser.add_argument('--d_encoder', type=int, default=20, help='dimension of the encoder') +# parser.add_argument('--d_decoder', type=int, default=320, help='dimension of the decoder') +parser.add_argument('--d_decoder', type=int, default=80, help='dimension of the decoder') +parser.add_argument('--encoder_layers', type=int, default=16, help='number of conformer blocks in the encoder') +parser.add_argument('--decoder_layers', type=int, default=1, help='number of decoder layers') +parser.add_argument('--conv_kernel_size', type=int, default=31, help='size of kernel for conformer convolution blocks') +parser.add_argument('--feed_forward_expansion_factor', type=int, default=4, help='expansion factor for conformer feed forward blocks') +parser.add_argument('--feed_forward_residual_factor', type=int, default=.5, help='residual factor for conformer feed forward blocks') +parser.add_argument('--dropout', type=float, default=.1, help='dropout factor for conformer model') +parser.add_argument('--weight_decay', type=float, default=1e-6, help='model weight decay (corresponds to L2 regularization)') +parser.add_argument('--variational_noise_std', type=float, default=.0001, help='std of noise added to model weights for regularization') +parser.add_argument('--num_workers', type=int, default=2, help='num_workers for the dataloader') +parser.add_argument('--smart_batch', type=bool, default=True, help='Use smart batching for faster training') +parser.add_argument('--accumulate_iters', type=int, default=1, help='Number of iterations to accumulate gradients') +args = parser.parse_args() + +def main(): + # Load 
Data + if not os.path.isdir(args.data_dir): + os.mkdir(args.data_dir) + train_data = torchaudio.datasets.LIBRISPEECH(root=args.data_dir, url=args.train_set) + test_data = torchaudio.datasets.LIBRISPEECH(args.data_dir, url=args.test_set) + + if args.smart_batch: + print('Sorting training data for smart batching...') + sorted_train_inds = [ind for ind, _ in sorted(enumerate(train_data), key=lambda x: x[1][0].shape[1])] + sorted_test_inds = [ind for ind, _ in sorted(enumerate(test_data), key=lambda x: x[1][0].shape[1])] + train_loader = DataLoader(dataset=train_data, + pin_memory=True, + num_workers=args.num_workers, + batch_sampler=BatchSampler(sorted_train_inds, batch_size=args.batch_size), + collate_fn=lambda x: preprocess_example(x, 'train')) + + test_loader = DataLoader(dataset=test_data, + pin_memory=True, + num_workers=args.num_workers, + batch_sampler=BatchSampler(sorted_test_inds, batch_size=args.batch_size), + collate_fn=lambda x: preprocess_example(x, 'valid')) + else: + train_loader = DataLoader(dataset=train_data, + pin_memory=True, + num_workers=args.num_workers, + batch_size=args.batch_size, + shuffle=True, + collate_fn=lambda x: preprocess_example(x, 'train')) + + test_loader = DataLoader(dataset=test_data, + pin_memory=True, + num_workers=args.num_workers, + batch_size=args.batch_size, + shuffle=False, + collate_fn=lambda x: preprocess_example(x, 'valid')) + + + # Declare Models + + encoder = ConformerEncoder( + d_input=args.d_input, + d_model=args.d_encoder, + num_layers=args.encoder_layers, + conv_kernel_size=args.conv_kernel_size, + dropout=args.dropout, + feed_forward_residual_factor=args.feed_forward_residual_factor, + feed_forward_expansion_factor=args.feed_forward_expansion_factor, + num_heads=args.attention_heads) + + decoder = LSTMDecoder( + d_encoder=args.d_encoder, + d_decoder=args.d_decoder, + num_layers=args.decoder_layers) + char_decoder = GreedyCharacterDecoder().eval() + criterion = nn.CTCLoss(blank=28, zero_infinity=True) + optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=5e-4, betas=(.9, .98), eps=1e-05 if args.use_amp else 1e-09, weight_decay=args.weight_decay) + scheduler = TransformerLrScheduler(optimizer, args.d_encoder, args.warmup_steps) + + # Print model size + model_size(encoder, 'Encoder') + model_size(decoder, 'Decoder') + + gc.collect() + + # GPU Setup + if torch.cuda.is_available(): + print('Using GPU') + gpu = True + # torch.cuda.set_device(args.gpu) + criterion = criterion.cuda() + encoder = encoder.cuda() + decoder = decoder.cuda() + char_decoder = char_decoder.cuda() + torch.cuda.empty_cache() + else: + gpu = False + + gpu = False + + # Mixed Precision Setup + if args.use_amp: + print('Using Mixed Precision') + # grad_scaler = GradScaler(enabled=args.use_amp) + + # Initialize Checkpoint + if args.load_checkpoint: + start_epoch, best_loss = load_checkpoint(encoder, decoder, optimizer, scheduler, args.checkpoint_path) + print(f'Resuming training from checkpoint starting at epoch {start_epoch}.') + else: + start_epoch = 0 + best_loss = float('inf') + + # Train Loop + # optimizer.zero_grad() + for epoch in range(start_epoch, args.epochs): + # torch.cuda.empty_cache() + + #variational noise for regularization + add_model_noise(encoder, std=args.variational_noise_std, gpu=gpu) + add_model_noise(decoder, std=args.variational_noise_std, gpu=gpu) + + # Train/Validation loops + wer, loss = train(encoder, decoder, char_decoder, optimizer, scheduler, criterion, train_loader, args, gpu=gpu) + valid_wer, valid_loss = 
validate(encoder, decoder, char_decoder, criterion, test_loader, args, gpu=gpu) + print(f'Epoch {epoch} - Valid WER: {valid_wer}%, Valid Loss: {valid_loss}, Train WER: {wer}%, Train Loss: {loss}') + + # Save checkpoint + if valid_loss <= best_loss: + print('Validation loss improved, saving checkpoint.') + best_loss = valid_loss + save_checkpoint(encoder, decoder, optimizer, scheduler, valid_loss, epoch+1, args.checkpoint_path) + +def train(encoder, decoder, char_decoder, optimizer, scheduler, criterion, train_loader, args, gpu=True): + ''' Run a single training epoch ''' + + wer = WordErrorRate() + error_rate = AvgMeter() + avg_loss = AvgMeter() + text_transform = TextTransform() + + encoder.train() + decoder.train() + for i, batch in enumerate(train_loader): + spectrograms, labels, input_lengths, label_lengths, references, mask = batch + print(spectrograms) + print(labels) + print(type(spectrograms)) + print('---------------') + scheduler.step() + gc.collect() + + # Move to GPU + if gpu: + spectrograms = spectrograms.cuda() + labels = labels.cuda() + input_lengths = torch.tensor(input_lengths).cuda() + label_lengths = torch.tensor(label_lengths).cuda() + mask = mask.cuda() + + # Update models + with autocast(enabled=args.use_amp): + outputs = encoder(spectrograms, mask) + outputs = decoder(outputs) + loss = criterion(F.log_softmax(outputs, dim=-1).transpose(0, 1), labels, input_lengths, label_lengths) + # grad_scaler.scale(loss).backward() + if (i+1) % args.accumulate_iters == 0: + # grad_scaler.step(optimizer) + # grad_scaler.update() + optimizer.zero_grad() + avg_loss.update(loss.detach().item()) + + # Predict words, compute WER + inds = char_decoder(outputs.detach()) + predictions = [] + for sample in inds: + predictions.append(text_transform.int_to_text(sample)) + error_rate.update(wer(predictions, references) * 100) + + # Print metrics and predictions + if (i+1) % args.report_freq == 0: + print(f'Step {i+1} - Avg WER: {error_rate.avg}%, Avg Loss: {avg_loss.avg}') + print('Sample Predictions: ', predictions) + del spectrograms, labels, input_lengths, label_lengths, references, outputs, inds, predictions + return error_rate.avg, avg_loss.avg + +def validate(encoder, decoder, char_decoder, criterion, test_loader, args, gpu=True): + ''' Evaluate model on test dataset. ''' + + avg_loss = AvgMeter() + error_rate = AvgMeter() + wer = WordErrorRate() + text_transform = TextTransform() + + encoder.eval() + decoder.eval() + for i, batch in enumerate(test_loader): + gc.collect() + spectrograms, labels, input_lengths, label_lengths, references, mask = batch + + # Move to GPU + if gpu: + spectrograms = spectrograms.cuda() + labels = labels.cuda() + input_lengths = torch.tensor(input_lengths).cuda() + label_lengths = torch.tensor(label_lengths).cuda() + mask = mask.cuda() + + with torch.no_grad(): + with autocast(enabled=args.use_amp): + outputs = encoder(spectrograms, mask) + outputs = decoder(outputs) + loss = criterion(F.log_softmax(outputs, dim=-1).transpose(0, 1), labels, input_lengths, label_lengths) + avg_loss.update(loss.item()) + + inds = char_decoder(outputs.detach()) + predictions = [] + for sample in inds: + predictions.append(text_transform.int_to_text(sample)) + error_rate.update(wer(predictions, references) * 100) + return error_rate.avg, avg_loss.avg + + +def add_model_noise(model, std=0.0001, gpu=True): + ''' + Add variational noise to model weights: https://ieeexplore.ieee.org/abstract/document/548170 + STD may need some fine tuning... 
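One further caveat: in stock PyTorch, an in-place add_ on a leaf parameter that requires grad normally has to run inside torch.no_grad(), otherwise it raises a RuntimeError; the commented-out no_grad line below hints at this, and whether the ms_adapter.pytorch shim enforces the same rule is not verified here. Note also that the two branches of the "if gpu" test below are currently identical. A minimal sketch of the guarded version, written against stock PyTorch and using an illustrative helper name:

```python
import torch
from torch import nn

def add_weight_noise(model: nn.Module, std: float = 1e-4) -> None:
    """Add zero-mean Gaussian noise to every parameter (variational noise)."""
    with torch.no_grad():                          # required for in-place ops on leaf params
        for param in model.parameters():
            noise = torch.randn_like(param) * std  # randn_like keeps device and dtype
            param.add_(noise)

model = nn.Linear(4, 4)
before = model.weight.clone()
add_weight_noise(model, std=1e-4)
print((model.weight - before).abs().mean())        # perturbation on the order of 1e-4
```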
+ ''' + # with torch.no_grad(): + for param in model.parameters(): + if gpu: + param.add_(torch.randn(param.size()) * std) + else: + param.add_(torch.randn(param.size()) * std) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/official/nlp/conformer/conformer-msa/utils.py b/official/nlp/conformer/conformer-msa/utils.py new file mode 100644 index 0000000..6923be3 --- /dev/null +++ b/official/nlp/conformer/conformer-msa/utils.py @@ -0,0 +1,223 @@ +import torchaudio +import torch +import torch.nn as nn +# import ms_adapter.pytorch as torch +# import ms_adapter.pytorch.nn as nn +import os +import random + +class TextTransform: + ''' Map characters to integers and vice versa ''' + def __init__(self): + self.char_map = {} + for i, char in enumerate(range(65, 91)): + self.char_map[chr(char)] = i + self.char_map["'"] = 26 + self.char_map[' '] = 27 + self.index_map = {} + for char, i in self.char_map.items(): + self.index_map[i] = char + + def text_to_int(self, text): + ''' Map text string to an integer sequence ''' + int_sequence = [] + for c in text: + ch = self.char_map[c] + int_sequence.append(ch) + return int_sequence + + def int_to_text(self, labels): + ''' Map integer sequence to text string ''' + string = [] + for i in labels: + if i == 28: # blank char + continue + else: + string.append(self.index_map[i]) + return ''.join(string) + + +def get_audio_transforms(): + + # 10 time masks with p=0.05 + # The actual conformer paper uses a variable time_mask_param based on the length of each utterance. + # For simplicity, we approximate it with just a fixed value. + time_masks = [torchaudio.transforms.TimeMasking(time_mask_param=15, p=0.05) for _ in range(10)] + train_audio_transform = nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80, hop_length=160), #80 filter banks, 25ms window size, 10ms hop + torchaudio.transforms.FrequencyMasking(freq_mask_param=27), + *time_masks, + ) + + valid_audio_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80, hop_length=160) + + return train_audio_transform, valid_audio_transform + +class BatchSampler(object): + ''' Sample contiguous, sorted indices. Leads to less padding and faster training. 
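The intuition: train.py sorts utterance indices by spectrogram length when --smart_batch is set, and this sampler always yields a contiguous window of that sorted list, so every batch contains similarly sized utterances and pad_sequence adds far fewer padded frames than random batching would. The toy calculation below (plain Python, purely illustrative lengths) makes the saving concrete:

```python
import random

random.seed(0)
lengths = [random.randint(100, 1600) for _ in range(64)]   # frames per utterance (toy values)
batch_size = 8

def padded_frames(order):
    """Total frames after padding each batch to its longest utterance."""
    total = 0
    for i in range(0, len(order), batch_size):
        batch = [lengths[j] for j in order[i:i + batch_size]]
        total += max(batch) * len(batch)
    return total

unsorted_order = list(range(len(lengths)))
sorted_order = sorted(unsorted_order, key=lambda j: lengths[j])
print(padded_frames(unsorted_order), 'vs', padded_frames(sorted_order))  # sorted is far smaller
```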
''' + def __init__(self, sorted_inds, batch_size): + self.sorted_inds = sorted_inds + self.batch_size = batch_size + + def __iter__(self): + inds = self.sorted_inds.copy() + while len(inds): + to_take = min(self.batch_size, len(inds)) + start_ind = random.randint(0, len(inds) - to_take) + batch_inds = inds[start_ind:start_ind + to_take] + del inds[start_ind:start_ind + to_take] + yield batch_inds + +def preprocess_example(data, data_type="train"): + ''' Process raw LibriSpeech examples ''' + text_transform = TextTransform() + train_audio_transform, valid_audio_transform = get_audio_transforms() + spectrograms = [] + labels = [] + references = [] + input_lengths = [] + label_lengths = [] + for (waveform, _, utterance, _, _, _) in data: + # Generate spectrogram for model input + if data_type == 'train': + spec = train_audio_transform(waveform).squeeze(0).transpose(0, 1) # (1, time, freq) + else: + spec = valid_audio_transform(waveform).squeeze(0).transpose(0, 1) # (1, time, freq) + spectrograms.append(spec) + + # Labels + references.append(utterance) # Actual Sentence + label = torch.Tensor(text_transform.text_to_int(utterance)) # Integer representation of sentence + labels.append(label) + + # Lengths (time) + input_lengths.append(((spec.shape[0] - 1) // 2 - 1) // 2) # account for subsampling of time dimension + label_lengths.append(len(label)) + + # Pad batch to length of longest sample + spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + # Padding mask (batch_size, time, time) + mask = torch.ones(spectrograms.shape[0], spectrograms.shape[1], spectrograms.shape[1]) + for i, l in enumerate(input_lengths): + mask[i, :, :l] = 0 + + return spectrograms, labels, input_lengths, label_lengths, references, mask.bool() + +class TransformerLrScheduler(): + ''' + Transformer LR scheduler from "Attention is all you need." https://arxiv.org/abs/1706.03762 + multiplier and warmup_steps taken from conformer paper: https://arxiv.org/abs/2005.08100 + ''' + def __init__(self, optimizer, d_model, warmup_steps, multiplier=5): + self._optimizer = optimizer + self.d_model = d_model + self.warmup_steps = warmup_steps + self.n_steps = 0 + self.multiplier = multiplier + + def step(self): + self.n_steps += 1 + lr = self._get_lr() + for param_group in self._optimizer.param_groups: + param_group['lr'] = lr + + def _get_lr(self): + return self.multiplier * (self.d_model ** -0.5) * min(self.n_steps ** (-0.5), self.n_steps * (self.warmup_steps ** (-1.5))) + + +def model_size(model, name): + ''' Print model size in num_params and MB''' + param_size = 0 + num_params = 0 + for param in model.parameters(): + num_params += param.nelement() + param_size += param.nelement() * param.element_size() + buffer_size = 0 + for buffer in model.buffers(): + num_params += buffer.nelement() + buffer_size += buffer.nelement() * buffer.element_size() + + size_all_mb = (param_size + buffer_size) / 1024**2 + print(f'{name} - num_params: {round(num_params / 1000000, 2)}M, size: {round(size_all_mb, 2)}MB') + + +class GreedyCharacterDecoder(nn.Module): + ''' Greedy CTC decoder - Argmax logits and remove duplicates. 
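For context on the decode step: train.py uses nn.CTCLoss(blank=28), and TextTransform maps A-Z to 0-25, the apostrophe to 26 and the space to 27, so index 28 is the CTC blank. Greedy CTC decoding takes the per-frame argmax, collapses consecutive repeats (torch.unique_consecutive here), and only then drops blanks (int_to_text skips index 28); doing it in the opposite order would merge genuinely repeated characters. A small worked example against stock PyTorch:

```python
import torch

BLANK = 28
id_to_char = {i: chr(65 + i) for i in range(26)}   # 0-25 -> A-Z
id_to_char.update({26: "'", 27: ' '})

# Frame-wise argmax ids for the word "HELLO": the blank between the two L's is what
# keeps them from being collapsed into a single L.
frames = torch.tensor([7, 7, 4, 11, 11, BLANK, 11, 14, 14, BLANK])

collapsed = torch.unique_consecutive(frames)                        # merge repeated frames
chars = [id_to_char[i] for i in collapsed.tolist() if i != BLANK]   # then drop blanks
print(''.join(chars))   # HELLO
```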
''' + def __init__(self): + super(GreedyCharacterDecoder, self).__init__() + + def forward(self, x): + indices = torch.argmax(x, dim=-1) + indices = torch.unique_consecutive(indices, dim=-1) + return indices.tolist() + + +class AvgMeter(object): + ''' + Keep running average for a metric + ''' + def __init__(self): + self.reset() + + def reset(self): + self.avg = None + self.sum = None + self.cnt = 0 + + def update(self, val, n=1): + if not self.sum: + self.sum = val * n + else: + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def view_spectrogram(sample): + ''' View spectrogram ''' + specgram = sample.transpose(1, 0) + import matplotlib.pyplot as plt + plt.figure() + p = plt.imshow(specgram.log2()[:,:].detach().numpy(), cmap='gray') + plt.show() + +# def add_model_noise(model, std=0.0001, gpu=True): +# ''' +# Add variational noise to model weights: https://ieeexplore.ieee.org/abstract/document/548170 +# STD may need some fine tuning... +# ''' +# # with torch.no_grad(): +# for param in model.parameters(): +# if gpu: +# param.add_(torch.randn(param.size()) * std) +# else: +# param.add_(torch.randn(param.size()) * std) + + +def load_checkpoint(encoder, decoder, optimizer, scheduler, checkpoint_path): + ''' Load model checkpoint ''' + if not os.path.exists(checkpoint_path): + raise 'Checkpoint does not exist' + checkpoint = torch.load(checkpoint_path) + scheduler.n_steps = checkpoint['scheduler_n_steps'] + scheduler.multiplier = checkpoint['scheduler_multiplier'] + scheduler.warmup_steps = checkpoint['scheduler_warmup_steps'] + encoder.load_state_dict(checkpoint['encoder_state_dict']) + decoder.load_state_dict(checkpoint['decoder_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + return checkpoint['epoch'], checkpoint['valid_loss'] + +def save_checkpoint(encoder, decoder, optimizer, scheduler, valid_loss, epoch, checkpoint_path): + ''' Save model checkpoint ''' + torch.save({ + 'epoch': epoch, + 'valid_loss': valid_loss, + 'scheduler_n_steps': scheduler.n_steps, + 'scheduler_multiplier': scheduler.multiplier, + 'scheduler_warmup_steps': scheduler.warmup_steps, + 'encoder_state_dict': encoder.state_dict(), + 'decoder_state_dict': decoder.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + }, checkpoint_path) diff --git a/official/nlp/conformer/conformer-pytorch/LICENSE b/official/nlp/conformer/conformer-pytorch/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/official/nlp/conformer/conformer-pytorch/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/official/nlp/conformer/conformer-pytorch/README.md b/official/nlp/conformer/conformer-pytorch/README.md new file mode 100644 index 0000000..5824258 --- /dev/null +++ b/official/nlp/conformer/conformer-pytorch/README.md @@ -0,0 +1,32 @@ +# Pytorch Conformer +Pytorch implementation of [conformer](https://arxiv.org/abs/2005.08100) model with training script for end-to-end speech recognition on the LibriSpeech dataset. + +## Usage + +### Train model from scratch: +``` +python train.py --data_dir=./data --train_set=train-clean-100 --test_set=test_clean --checkpoint_path=model_best.pt +``` +### Resume training from checkpoint +``` +python train.py --load_checkpoint --checkpoint_path=model_best.pt +``` +### Train with mixed precision: +``` +python train.py --use_amp +``` + +For a full list of command line arguments, run ```python train.py --help```. [Smart batching](https://mccormickml.com/2020/07/29/smart-batching-tutorial/) is used by default but may need to be disabled for larger datasets. For valid train_set and test_set values, see torchaudio's [LibriSpeech dataset](https://pytorch.org/audio/stable/datasets.html). The model parameters default to the Conformer (S) configuration. For the Conformer (M) and Conformer (L) models, refer to the table below: + + + +## Other Implementations +- https://github.com/sooftware/conformer +- https://github.com/lucidrains/conformer + +## TODO: +- Language Model (LM) implementation +- Multi-GPU support +- Support for full LibriSpeech960h train set +- Support for other decoders (ie: transformer decoder, etc.) + diff --git a/official/nlp/conformer/conformer-pytorch/model.py b/official/nlp/conformer/conformer-pytorch/model.py new file mode 100644 index 0000000..e5a7f7b --- /dev/null +++ b/official/nlp/conformer/conformer-pytorch/model.py @@ -0,0 +1,368 @@ +import math +import torch +from torch import nn +import torch.nn.functional as F + +class PositionalEncoder(nn.Module): + ''' + Generate positional encodings used in the relative multi-head attention module. + These encodings are the same as the original transformer model: https://arxiv.org/abs/1706.03762 + + Parameters: + max_len (int): Maximum sequence length (time dimension) + + Inputs: + len (int): Length of encodings to retrieve + + Outputs + Tensor (len, d_model): Positional encodings + ''' + def __init__(self, d_model, max_len=10000): + super(PositionalEncoder, self).__init__() + self.d_model = d_model + encodings = torch.zeros(max_len, d_model) + pos = torch.arange(0, max_len, dtype=torch.float) + inv_freq = 1 / (10000 ** (torch.arange(0.0, d_model, 2.0) / d_model)) + encodings[:, 0::2] = torch.sin(pos[:, None] * inv_freq) + encodings[:, 1::2] = torch.cos(pos[:, None] * inv_freq) + self.register_buffer('encodings', encodings) + + def forward(self, len): + return self.encodings[:len, :] + +class RelativeMultiHeadAttention(nn.Module): + ''' + Relative Multi-Head Self-Attention Module. + Method proposed in Transformer-XL paper: https://arxiv.org/abs/1901.02860 + + Parameters: + d_model (int): Dimension of the model + num_heads (int): Number of heads to split inputs into + dropout (float): Dropout probability + positional_encoder (nn.Module): PositionalEncoder module + + Inputs: + x (Tensor): (batch_size, time, d_model) + mask (Tensor): (batch_size, time, time) Optional mask to zero out attention score at certain indices + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the attention module. 
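The attention scores in this module follow the Transformer-XL decomposition: AC = (q + u)·k covers the content-content and content-bias terms, while BD = (q + v)·W_pos(pos_emb), re-indexed by rel_shift so that each entry lines up with a relative offset, covers the position-dependent terms; u and v are the trainable biases initialised in __init__, and the sum is scaled by sqrt(d_model). The sketch below is only a minimal shape check of that decomposition in stock PyTorch; rel_shift is omitted and all tensors are random placeholders:

```python
import math
import torch

batch, heads, seq, d_head = 2, 4, 6, 36            # e.g. d_model = 144 split over 4 heads
q = torch.randn(batch, seq, heads, d_head)         # queries, (B, T, H, Dh)
k = torch.randn(batch, heads, d_head, seq)         # keys permuted to (B, H, Dh, T)
pos = torch.randn(batch, heads, d_head, seq)       # projected positional embeddings, (B, H, Dh, T)
u = torch.randn(heads, d_head)                     # content bias (one vector per head)
v = torch.randn(heads, d_head)                     # positional bias (one vector per head)

AC = torch.matmul((q + u).transpose(1, 2), k)      # content-content + content-bias terms
BD = torch.matmul((q + v).transpose(1, 2), pos)    # position terms (rel_shift omitted here)
attn = (AC + BD) / math.sqrt(heads * d_head)       # scaled by sqrt(d_model), as in the module
assert attn.shape == (batch, heads, seq, seq)      # one score per (query, key) pair
```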
+ + ''' + def __init__(self, d_model=144, num_heads=4, dropout=0.1, positional_encoder=PositionalEncoder(144)): + super(RelativeMultiHeadAttention, self).__init__() + + #dimensions + assert d_model % num_heads == 0 + self.d_model = d_model + self.d_head = d_model // num_heads + self.num_heads = num_heads + + # Linear projection weights + self.W_q = nn.Linear(d_model, d_model) + self.W_k = nn.Linear(d_model, d_model) + self.W_v = nn.Linear(d_model, d_model) + self.W_pos = nn.Linear(d_model, d_model, bias=False) + self.W_out = nn.Linear(d_model, d_model) + + # Trainable bias parameters + self.u = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + self.v = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + torch.nn.init.xavier_uniform_(self.u) + torch.nn.init.xavier_uniform_(self.v) + + # etc + self.layer_norm = nn.LayerNorm(d_model, eps=6.1e-5) + self.positional_encoder = positional_encoder + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask=None): + batch_size, seq_length, _ = x.size() + + #layer norm and pos embeddings + x = self.layer_norm(x) + pos_emb = self.positional_encoder(seq_length) + pos_emb = pos_emb.repeat(batch_size, 1, 1) + + #Linear projections, split into heads + q = self.W_q(x).view(batch_size, seq_length, self.num_heads, self.d_head) + k = self.W_k(x).view(batch_size, seq_length, self.num_heads, self.d_head).permute(0, 2, 3, 1) # (batch_size, num_heads, d_head, time) + v = self.W_v(x).view(batch_size, seq_length, self.num_heads, self.d_head).permute(0, 2, 3, 1) # (batch_size, num_heads, d_head, time) + pos_emb = self.W_pos(pos_emb).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 3, 1) # (batch_size, num_heads, d_head, time) + + #Compute attention scores with relative position embeddings + AC = torch.matmul((q + self.u).transpose(1, 2), k) + BD = torch.matmul((q + self.v).transpose(1, 2), pos_emb) + BD = self.rel_shift(BD) + attn = (AC + BD) / math.sqrt(self.d_model) + + #Mask before softmax with large negative number + if mask is not None: + mask = mask.unsqueeze(1) + mask_value = -1e+30 if attn.dtype == torch.float32 else -1e+4 + attn.masked_fill_(mask, mask_value) + + #Softmax + attn = F.softmax(attn, -1) + + #Construct outputs from values + output = torch.matmul(attn, v.transpose(2, 3)).transpose(1, 2) # (batch_size, time, num_heads, d_head) + output = output.contiguous().view(batch_size, -1, self.d_model) # (batch_size, time, d_model) + + #Output projections and dropout + output = self.W_out(output) + return self.dropout(output) + + + def rel_shift(self, emb): + ''' + Pad and shift form relative positional encodings. + Taken from Transformer-XL implementation: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py + ''' + batch_size, num_heads, seq_length1, seq_length2 = emb.size() + zeros = emb.new_zeros(batch_size, num_heads, seq_length1, 1) + padded_emb = torch.cat([zeros, emb], dim=-1) + padded_emb = padded_emb.view(batch_size, num_heads, seq_length2 + 1, seq_length1) + shifted_emb = padded_emb[:, :, 1:].view_as(emb) + return shifted_emb + + +class ConvBlock(nn.Module): + ''' + Conformer convolutional block. 
+
+    Parameters:
+        d_model (int): Dimension of the model
+        kernel_size (int): Size of the kernel to use for the depthwise convolution
+        dropout (float): Dropout probability
+
+    Inputs:
+        x (Tensor): (batch_size, time, d_model)
+        mask: Unused
+
+    Outputs:
+        Tensor (batch_size, time, d_model): Output tensor from the convolution module
+
+    '''
+    def __init__(self, d_model=144, kernel_size=31, dropout=0.1):
+        super(ConvBlock, self).__init__()
+        self.layer_norm = nn.LayerNorm(d_model, eps=6.1e-5)
+        self.module = nn.Sequential(
+            nn.Conv1d(in_channels=d_model, out_channels=d_model * 2, kernel_size=1),  # first pointwise conv with 2x expansion
+            nn.GLU(dim=1),  # gating halves the channels back to d_model
+            nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=kernel_size, padding='same', groups=d_model),  # depthwise conv
+            nn.BatchNorm1d(d_model, eps=6.1e-5),
+            nn.SiLU(),  # swish activation
+            nn.Conv1d(in_channels=d_model, out_channels=d_model, kernel_size=1),  # second pointwise conv
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        x = self.layer_norm(x)
+        x = x.transpose(1, 2)  # (batch_size, d_model, seq_len)
+        x = self.module(x)
+        return x.transpose(1, 2)
+
+class FeedForwardBlock(nn.Module):
+    '''
+    Conformer feed-forward block.
+
+    Parameters:
+        d_model (int): Dimension of the model
+        expansion (int): Expansion factor for the first linear layer
+        dropout (float): Dropout probability
+
+    Inputs:
+        x (Tensor): (batch_size, time, d_model)
+        mask: Unused
+
+    Outputs:
+        Tensor (batch_size, time, d_model): Output tensor from the feed-forward module
+
+    '''
+    def __init__(self, d_model=144, expansion=4, dropout=0.1):
+        super(FeedForwardBlock, self).__init__()
+        self.module = nn.Sequential(
+            nn.LayerNorm(d_model, eps=6.1e-5),
+            nn.Linear(d_model, d_model * expansion),  # expand to d_model * expansion
+            nn.SiLU(),  # swish activation
+            nn.Dropout(dropout),
+            nn.Linear(d_model * expansion, d_model),  # project back to d_model
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        return self.module(x)
+
+class Conv2dSubsampling(nn.Module):
+    '''
+    2D convolutional subsampling.
+    Subsamples the time and frequency dimensions of the input spectrogram by roughly a factor of 4 and expands the channel dimension to d_model.
+
+    Parameters:
+        d_model (int): Dimension of the model
+
+    Inputs:
+        x (Tensor): Input spectrogram (batch_size, time, d_input)
+
+    Outputs:
+        Tensor (batch_size, subsampled_time, d_model * subsampled_freq): Output tensor from the convolutional subsampling module
+
+    '''
+    def __init__(self, d_model=144):
+        super(Conv2dSubsampling, self).__init__()
+        self.module = nn.Sequential(
+            nn.Conv2d(in_channels=1, out_channels=d_model, kernel_size=3, stride=2),
+            nn.ReLU(),
+            nn.Conv2d(in_channels=d_model, out_channels=d_model, kernel_size=3, stride=2),
+            nn.ReLU(),
+        )
+
+    def forward(self, x):
+        output = self.module(x.unsqueeze(1))  # (batch_size, 1, time, d_input) -> (batch_size, d_model, subsampled_time, subsampled_freq)
+        batch_size, d_model, subsampled_time, subsampled_freq = output.size()
+        output = output.permute(0, 2, 1, 3)
+        output = output.contiguous().view(batch_size, subsampled_time, d_model * subsampled_freq)
+        return output
+
+class ConformerBlock(nn.Module):
+    '''
+    Conformer Encoder Block.
+ + Parameters: + d_model (int): Dimension of the model + conv_kernel_size (int): Size of kernel to use for depthwise convolution + feed_forward_residual_factor (float): output_weight for feed-forward residual connections + feed_forward_expansion_factor (int): Expansion factor for feed-forward block + num_heads (int): Number of heads to use for multi-head attention + positional_encoder (nn.Module): PositionalEncoder module + dropout (float): Dropout probability + + Inputs: + x (Tensor): (batch_size, time, d_model) + mask (Tensor): (batch_size, time, time) Optional mask to zero out attention score at certain indices + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the conformer block. + + ''' + def __init__( + self, + d_model=144, + conv_kernel_size=31, + feed_forward_residual_factor=.5, + feed_forward_expansion_factor=4, + num_heads=4, + positional_encoder=PositionalEncoder(144), + dropout=0.1, + ): + super(ConformerBlock, self).__init__() + self.residual_factor = feed_forward_residual_factor + self.ff1 = FeedForwardBlock(d_model, feed_forward_expansion_factor, dropout) + self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout, positional_encoder) + self.conv_block = ConvBlock(d_model, conv_kernel_size, dropout) + self.ff2 = FeedForwardBlock(d_model, feed_forward_expansion_factor, dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=6.1e-5) + + def forward(self, x, mask=None): + x = x + (self.residual_factor * self.ff1(x)) + x = x + self.attention(x, mask=mask) + x = x + self.conv_block(x) + x = x + (self.residual_factor * self.ff2(x)) + return self.layer_norm(x) + + +class ConformerEncoder(nn.Module): + ''' + Conformer Encoder Module. + + Parameters: + d_input (int): Dimension of the input + d_model (int): Dimension of the model + num_layers (int): Number of conformer blocks to use in the encoder + conv_kernel_size (int): Size of kernel to use for depthwise convolution + feed_forward_residual_factor (float): output_weight for feed-forward residual connections + feed_forward_expansion_factor (int): Expansion factor for feed-forward block + num_heads (int): Number of heads to use for multi-head attention + dropout (float): Dropout probability + + Inputs: + x (Tensor): input spectrogram of dimension (batch_size, time, d_input) + mask (Tensor): (batch_size, time, time) Optional mask to zero out attention score at certain indices + + Outputs: + Tensor (batch_size, time, d_model): Output tensor from the conformer encoder + + + ''' + def __init__( + self, + d_input=80, + d_model=144, + num_layers=16, + conv_kernel_size=31, + feed_forward_residual_factor=.5, + feed_forward_expansion_factor=4, + num_heads=4, + dropout=.1, + ): + super(ConformerEncoder, self).__init__() + self.conv_subsample = Conv2dSubsampling(d_model=d_model) + self.linear_proj = nn.Linear(d_model * (((d_input - 1) // 2 - 1) // 2), d_model) # project subsamples to d_model + self.dropout = nn.Dropout(p=dropout) + + # define global positional encoder to limit model parameters + positional_encoder = PositionalEncoder(d_model) + self.layers = nn.ModuleList([ConformerBlock( + d_model=d_model, + conv_kernel_size=conv_kernel_size, + feed_forward_residual_factor=feed_forward_residual_factor, + feed_forward_expansion_factor=feed_forward_expansion_factor, + num_heads=num_heads, + positional_encoder=positional_encoder, + dropout=dropout, + ) for _ in range(num_layers)]) + + def forward(self, x, mask=None): + x = self.conv_subsample(x) + if mask is not None: + mask = mask[:, :-2:2, :-2:2] #account for 
subsampling + mask = mask[:, :-2:2, :-2:2] #account for subsampling + assert mask.shape[1] == x.shape[1], f'{mask.shape} {x.shape}' + + x = self.linear_proj(x) + x = self.dropout(x) + + for layer in self.layers: + x = layer(x, mask=mask) + + return x + + +class LSTMDecoder(nn.Module): + ''' + LSTM Decoder + + Parameters: + d_encoder (int): Output dimension of the encoder + d_decoder (int): Hidden dimension of the decoder + num_layers (int): Number of LSTM layers to use in the decoder + num_classes (int): Number of output classes to predict + + Inputs: + x (Tensor): (batch_size, time, d_encoder) + + Outputs: + Tensor (batch_size, time, num_classes): Class prediction logits + + ''' + def __init__(self, d_encoder=144, d_decoder=320, num_layers=1, num_classes=29): + super(LSTMDecoder, self).__init__() + self.lstm = nn.LSTM(input_size=d_encoder, hidden_size=d_decoder, num_layers=num_layers, batch_first=True) + self.linear = nn.Linear(d_decoder, num_classes) + + def forward(self, x): + x, _ = self.lstm(x) + logits = self.linear(x) + return logits diff --git a/official/nlp/conformer/conformer-pytorch/train.py b/official/nlp/conformer/conformer-pytorch/train.py new file mode 100644 index 0000000..3008af8 --- /dev/null +++ b/official/nlp/conformer/conformer-pytorch/train.py @@ -0,0 +1,252 @@ +import os +import gc +import argparse +import torchaudio +import torch +import torch.nn.functional as F + +from torch import nn +from torchmetrics.text.wer import WordErrorRate +from torch.utils.data import DataLoader +from torch.cuda.amp import autocast, GradScaler +from model import ConformerEncoder, LSTMDecoder +from utils import * + +parser = argparse.ArgumentParser("conformer") +parser.add_argument('--data_dir', type=str, default='./data', help='location to download data') +parser.add_argument('--checkpoint_path', type=str, default='model_best.pt', help='path to store/load checkpoints') +parser.add_argument('--load_checkpoint', action='store_true', default=False, help='resume training from checkpoint') +parser.add_argument('--train_set', type=str, default='train-clean-100', help='train dataset') +parser.add_argument('--test_set', type=str, default='test-clean', help='test dataset') +parser.add_argument('--batch_size', type=int, default=1, help='batch size') +parser.add_argument('--warmup_steps', type=float, default=10000, help='Multiply by sqrt(d_model) to get max_lr') +parser.add_argument('--peak_lr_ratio', type=int, default=0.05, help='Number of warmup steps for LR scheduler') +parser.add_argument('--gpu', type=int, default=0, help='gpu device id (optional)') +parser.add_argument('--epochs', type=int, default=50, help='num of training epochs') +parser.add_argument('--report_freq', type=int, default=100, help='training objective report frequency') +parser.add_argument('--layers', type=int, default=8, help='total number of layers') +parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model') +parser.add_argument('--use_amp', action='store_true', default=False, help='use mixed precision to train') +parser.add_argument('--attention_heads', type=int, default=4, help='number of heads to use for multi-head attention') +parser.add_argument('--d_input', type=int, default=80, help='dimension of the input (num filter banks)') +# parser.add_argument('--d_encoder', type=int, default=144, help='dimension of the encoder') +parser.add_argument('--d_encoder', type=int, default=20, help='dimension of the encoder') +# parser.add_argument('--d_decoder', type=int, default=320, 
help='dimension of the decoder') +parser.add_argument('--d_decoder', type=int, default=80, help='dimension of the decoder') +parser.add_argument('--encoder_layers', type=int, default=16, help='number of conformer blocks in the encoder') +parser.add_argument('--decoder_layers', type=int, default=1, help='number of decoder layers') +parser.add_argument('--conv_kernel_size', type=int, default=31, help='size of kernel for conformer convolution blocks') +parser.add_argument('--feed_forward_expansion_factor', type=int, default=4, help='expansion factor for conformer feed forward blocks') +parser.add_argument('--feed_forward_residual_factor', type=int, default=.5, help='residual factor for conformer feed forward blocks') +parser.add_argument('--dropout', type=float, default=.1, help='dropout factor for conformer model') +parser.add_argument('--weight_decay', type=float, default=1e-6, help='model weight decay (corresponds to L2 regularization)') +parser.add_argument('--variational_noise_std', type=float, default=.0001, help='std of noise added to model weights for regularization') +parser.add_argument('--num_workers', type=int, default=2, help='num_workers for the dataloader') +parser.add_argument('--smart_batch', type=bool, default=True, help='Use smart batching for faster training') +parser.add_argument('--accumulate_iters', type=int, default=1, help='Number of iterations to accumulate gradients') +args = parser.parse_args() + + +def main(): + + # Load Data + if not os.path.isdir(args.data_dir): + os.mkdir(args.data_dir) + train_data = torchaudio.datasets.LIBRISPEECH(root=args.data_dir, url=args.train_set) + test_data = torchaudio.datasets.LIBRISPEECH(args.data_dir, url=args.test_set) + + + if args.smart_batch: + print('Sorting training data for smart batching...') + sorted_train_inds = [ind for ind, _ in sorted(enumerate(train_data), key=lambda x: x[1][0].shape[1])] + sorted_test_inds = [ind for ind, _ in sorted(enumerate(test_data), key=lambda x: x[1][0].shape[1])] + train_loader = DataLoader(dataset=train_data, + pin_memory=True, + num_workers=args.num_workers, + batch_sampler=BatchSampler(sorted_train_inds, batch_size=args.batch_size), + collate_fn=lambda x: preprocess_example(x, 'train')) + + test_loader = DataLoader(dataset=test_data, + pin_memory=True, + num_workers=args.num_workers, + batch_sampler=BatchSampler(sorted_test_inds, batch_size=args.batch_size), + collate_fn=lambda x: preprocess_example(x, 'valid')) + else: + train_loader = DataLoader(dataset=train_data, + pin_memory=True, + num_workers=args.num_workers, + batch_size=args.batch_size, + shuffle=True, + collate_fn=lambda x: preprocess_example(x, 'train')) + + test_loader = DataLoader(dataset=test_data, + pin_memory=True, + num_workers=args.num_workers, + batch_size=args.batch_size, + shuffle=False, + collate_fn=lambda x: preprocess_example(x, 'valid')) + + + # Declare Models + + encoder = ConformerEncoder( + d_input=args.d_input, + d_model=args.d_encoder, + num_layers=args.encoder_layers, + conv_kernel_size=args.conv_kernel_size, + dropout=args.dropout, + feed_forward_residual_factor=args.feed_forward_residual_factor, + feed_forward_expansion_factor=args.feed_forward_expansion_factor, + num_heads=args.attention_heads) + + decoder = LSTMDecoder( + d_encoder=args.d_encoder, + d_decoder=args.d_decoder, + num_layers=args.decoder_layers) + char_decoder = GreedyCharacterDecoder().eval() + criterion = nn.CTCLoss(blank=28, zero_infinity=True) + optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), 
lr=5e-4, betas=(.9, .98), eps=1e-05 if args.use_amp else 1e-09, weight_decay=args.weight_decay) + scheduler = TransformerLrScheduler(optimizer, args.d_encoder, args.warmup_steps) + + # Print model size + model_size(encoder, 'Encoder') + model_size(decoder, 'Decoder') + + gc.collect() + + # # GPU Setup + # if torch.cuda.is_available(): + # print('Using GPU') + # gpu = True + # # torch.cuda.set_device(args.gpu) + # # criterion = criterion.cuda() + # # encoder = encoder.cuda() + # # decoder = decoder.cuda() + # # char_decoder = char_decoder.cuda() + # # torch.cuda.empty_cache() + # else: + # gpu = False + + gpu = False + + # Mixed Precision Setup + if args.use_amp: + print('Using Mixed Precision') + grad_scaler = GradScaler(enabled=args.use_amp) + + # Initialize Checkpoint + if args.load_checkpoint: + start_epoch, best_loss = load_checkpoint(encoder, decoder, optimizer, scheduler, args.checkpoint_path) + print(f'Resuming training from checkpoint starting at epoch {start_epoch}.') + else: + start_epoch = 0 + best_loss = float('inf') + + # Train Loop + optimizer.zero_grad() + for epoch in range(start_epoch, args.epochs): + torch.cuda.empty_cache() + + #variational noise for regularization + add_model_noise(encoder, std=args.variational_noise_std, gpu=gpu) + add_model_noise(decoder, std=args.variational_noise_std, gpu=gpu) + + # Train/Validation loops + wer, loss = train(encoder, decoder, char_decoder, optimizer, scheduler, criterion, grad_scaler, train_loader, args, gpu=gpu) + valid_wer, valid_loss = validate(encoder, decoder, char_decoder, criterion, test_loader, args, gpu=gpu) + print(f'Epoch {epoch} - Valid WER: {valid_wer}%, Valid Loss: {valid_loss}, Train WER: {wer}%, Train Loss: {loss}') + + # Save checkpoint + if valid_loss <= best_loss: + print('Validation loss improved, saving checkpoint.') + best_loss = valid_loss + save_checkpoint(encoder, decoder, optimizer, scheduler, valid_loss, epoch+1, args.checkpoint_path) + +def train(encoder, decoder, char_decoder, optimizer, scheduler, criterion, grad_scaler, train_loader, args, gpu=True): + ''' Run a single training epoch ''' + + wer = WordErrorRate() + error_rate = AvgMeter() + avg_loss = AvgMeter() + text_transform = TextTransform() + + encoder.train() + decoder.train() + for i, batch in enumerate(train_loader): + scheduler.step() + gc.collect() + spectrograms, labels, input_lengths, label_lengths, references, mask = batch + + # Move to GPU + if gpu: + spectrograms = spectrograms.cuda() + labels = labels.cuda() + input_lengths = torch.tensor(input_lengths).cuda() + label_lengths = torch.tensor(label_lengths).cuda() + mask = mask.cuda() + + # Update models + with autocast(enabled=args.use_amp): + outputs = encoder(spectrograms, mask) + outputs = decoder(outputs) + loss = criterion(F.log_softmax(outputs, dim=-1).transpose(0, 1), labels, input_lengths, label_lengths) + grad_scaler.scale(loss).backward() + if (i+1) % args.accumulate_iters == 0: + grad_scaler.step(optimizer) + grad_scaler.update() + optimizer.zero_grad() + avg_loss.update(loss.detach().item()) + + # Predict words, compute WER + inds = char_decoder(outputs.detach()) + predictions = [] + for sample in inds: + predictions.append(text_transform.int_to_text(sample)) + error_rate.update(wer(predictions, references) * 100) + + # Print metrics and predictions + if (i+1) % args.report_freq == 0: + print(f'Step {i+1} - Avg WER: {error_rate.avg}%, Avg Loss: {avg_loss.avg}') + print('Sample Predictions: ', predictions) + del spectrograms, labels, input_lengths, label_lengths, references, 
outputs, inds, predictions + return error_rate.avg, avg_loss.avg + +def validate(encoder, decoder, char_decoder, criterion, test_loader, args, gpu=True): + ''' Evaluate model on test dataset. ''' + + avg_loss = AvgMeter() + error_rate = AvgMeter() + wer = WordErrorRate() + text_transform = TextTransform() + + encoder.eval() + decoder.eval() + for i, batch in enumerate(test_loader): + gc.collect() + spectrograms, labels, input_lengths, label_lengths, references, mask = batch + + # Move to GPU + if gpu: + spectrograms = spectrograms.cuda() + labels = labels.cuda() + input_lengths = torch.tensor(input_lengths).cuda() + label_lengths = torch.tensor(label_lengths).cuda() + mask = mask.cuda() + + with torch.no_grad(): + with autocast(enabled=args.use_amp): + outputs = encoder(spectrograms, mask) + outputs = decoder(outputs) + loss = criterion(F.log_softmax(outputs, dim=-1).transpose(0, 1), labels, input_lengths, label_lengths) + avg_loss.update(loss.item()) + + inds = char_decoder(outputs.detach()) + predictions = [] + for sample in inds: + predictions.append(text_transform.int_to_text(sample)) + error_rate.update(wer(predictions, references) * 100) + return error_rate.avg, avg_loss.avg + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/official/nlp/conformer/conformer-pytorch/utils.py b/official/nlp/conformer/conformer-pytorch/utils.py new file mode 100644 index 0000000..3fa7f8e --- /dev/null +++ b/official/nlp/conformer/conformer-pytorch/utils.py @@ -0,0 +1,221 @@ +import torchaudio +import torch +import torch.nn as nn +import os +import random + +class TextTransform: + ''' Map characters to integers and vice versa ''' + def __init__(self): + self.char_map = {} + for i, char in enumerate(range(65, 91)): + self.char_map[chr(char)] = i + self.char_map["'"] = 26 + self.char_map[' '] = 27 + self.index_map = {} + for char, i in self.char_map.items(): + self.index_map[i] = char + + def text_to_int(self, text): + ''' Map text string to an integer sequence ''' + int_sequence = [] + for c in text: + ch = self.char_map[c] + int_sequence.append(ch) + return int_sequence + + def int_to_text(self, labels): + ''' Map integer sequence to text string ''' + string = [] + for i in labels: + if i == 28: # blank char + continue + else: + string.append(self.index_map[i]) + return ''.join(string) + + +def get_audio_transforms(): + + # 10 time masks with p=0.05 + # The actual conformer paper uses a variable time_mask_param based on the length of each utterance. + # For simplicity, we approximate it with just a fixed value. + time_masks = [torchaudio.transforms.TimeMasking(time_mask_param=15, p=0.05) for _ in range(10)] + train_audio_transform = nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80, hop_length=160), #80 filter banks, 25ms window size, 10ms hop + torchaudio.transforms.FrequencyMasking(freq_mask_param=27), + *time_masks, + ) + + valid_audio_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80, hop_length=160) + + return train_audio_transform, valid_audio_transform + +class BatchSampler(object): + ''' Sample contiguous, sorted indices. Leads to less padding and faster training. 
''' + def __init__(self, sorted_inds, batch_size): + self.sorted_inds = sorted_inds + self.batch_size = batch_size + + def __iter__(self): + inds = self.sorted_inds.copy() + while len(inds): + to_take = min(self.batch_size, len(inds)) + start_ind = random.randint(0, len(inds) - to_take) + batch_inds = inds[start_ind:start_ind + to_take] + del inds[start_ind:start_ind + to_take] + yield batch_inds + +def preprocess_example(data, data_type="train"): + ''' Process raw LibriSpeech examples ''' + text_transform = TextTransform() + train_audio_transform, valid_audio_transform = get_audio_transforms() + spectrograms = [] + labels = [] + references = [] + input_lengths = [] + label_lengths = [] + for (waveform, _, utterance, _, _, _) in data: + # Generate spectrogram for model input + if data_type == 'train': + spec = train_audio_transform(waveform).squeeze(0).transpose(0, 1) # (1, time, freq) + else: + spec = valid_audio_transform(waveform).squeeze(0).transpose(0, 1) # (1, time, freq) + spectrograms.append(spec) + + # Labels + references.append(utterance) # Actual Sentence + label = torch.Tensor(text_transform.text_to_int(utterance)) # Integer representation of sentence + labels.append(label) + + # Lengths (time) + input_lengths.append(((spec.shape[0] - 1) // 2 - 1) // 2) # account for subsampling of time dimension + label_lengths.append(len(label)) + + # Pad batch to length of longest sample + spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + # Padding mask (batch_size, time, time) + mask = torch.ones(spectrograms.shape[0], spectrograms.shape[1], spectrograms.shape[1]) + for i, l in enumerate(input_lengths): + mask[i, :, :l] = 0 + + return spectrograms, labels, input_lengths, label_lengths, references, mask.bool() + +class TransformerLrScheduler(): + ''' + Transformer LR scheduler from "Attention is all you need." https://arxiv.org/abs/1706.03762 + multiplier and warmup_steps taken from conformer paper: https://arxiv.org/abs/2005.08100 + ''' + def __init__(self, optimizer, d_model, warmup_steps, multiplier=5): + self._optimizer = optimizer + self.d_model = d_model + self.warmup_steps = warmup_steps + self.n_steps = 0 + self.multiplier = multiplier + + def step(self): + self.n_steps += 1 + lr = self._get_lr() + for param_group in self._optimizer.param_groups: + param_group['lr'] = lr + + def _get_lr(self): + return self.multiplier * (self.d_model ** -0.5) * min(self.n_steps ** (-0.5), self.n_steps * (self.warmup_steps ** (-1.5))) + + +def model_size(model, name): + ''' Print model size in num_params and MB''' + param_size = 0 + num_params = 0 + for param in model.parameters(): + num_params += param.nelement() + param_size += param.nelement() * param.element_size() + buffer_size = 0 + for buffer in model.buffers(): + num_params += buffer.nelement() + buffer_size += buffer.nelement() * buffer.element_size() + + size_all_mb = (param_size + buffer_size) / 1024**2 + print(f'{name} - num_params: {round(num_params / 1000000, 2)}M, size: {round(size_all_mb, 2)}MB') + + +class GreedyCharacterDecoder(nn.Module): + ''' Greedy CTC decoder - Argmax logits and remove duplicates. 
'''
+    def __init__(self):
+        super(GreedyCharacterDecoder, self).__init__()
+
+    def forward(self, x):
+        indices = torch.argmax(x, dim=-1)  # most likely class at each timestep
+        indices = torch.unique_consecutive(indices, dim=-1)  # collapse repeated predictions (CTC)
+        return indices.tolist()
+
+
+class AvgMeter(object):
+    '''
+    Keep a running average of a metric.
+    '''
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.avg = None
+        self.sum = None
+        self.cnt = 0
+
+    def update(self, val, n=1):
+        if self.sum is None:
+            self.sum = val * n
+        else:
+            self.sum += val * n
+        self.cnt += n
+        self.avg = self.sum / self.cnt
+
+
+def view_spectrogram(sample):
+    ''' Plot a spectrogram. '''
+    import matplotlib.pyplot as plt
+    specgram = sample.transpose(1, 0)
+    plt.figure()
+    plt.imshow(specgram.log2().detach().numpy(), cmap='gray')
+    plt.show()
+
+def add_model_noise(model, std=0.0001, gpu=True):
+    '''
+    Add variational noise to model weights: https://ieeexplore.ieee.org/abstract/document/548170
+    The std may need some fine-tuning.
+    '''
+    with torch.no_grad():
+        for param in model.parameters():
+            if gpu:
+                param.add_(torch.randn(param.size()).cuda() * std)
+            else:
+                param.add_(torch.randn(param.size()) * std)
+
+
+def load_checkpoint(encoder, decoder, optimizer, scheduler, checkpoint_path):
+    ''' Load a model checkpoint. '''
+    if not os.path.exists(checkpoint_path):
+        raise FileNotFoundError(f'Checkpoint {checkpoint_path} does not exist')
+    checkpoint = torch.load(checkpoint_path)
+    scheduler.n_steps = checkpoint['scheduler_n_steps']
+    scheduler.multiplier = checkpoint['scheduler_multiplier']
+    scheduler.warmup_steps = checkpoint['scheduler_warmup_steps']
+    encoder.load_state_dict(checkpoint['encoder_state_dict'])
+    decoder.load_state_dict(checkpoint['decoder_state_dict'])
+    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+    return checkpoint['epoch'], checkpoint['valid_loss']
+
+def save_checkpoint(encoder, decoder, optimizer, scheduler, valid_loss, epoch, checkpoint_path):
+    ''' Save a model checkpoint. '''
+    torch.save({
+        'epoch': epoch,
+        'valid_loss': valid_loss,
+        'scheduler_n_steps': scheduler.n_steps,
+        'scheduler_multiplier': scheduler.multiplier,
+        'scheduler_warmup_steps': scheduler.warmup_steps,
+        'encoder_state_dict': encoder.state_dict(),
+        'decoder_state_dict': decoder.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+    }, checkpoint_path)
-- 
2.34.1
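Since the patch wires these modules together only inside train.py, here is a minimal greedy-decoding sketch showing how the pieces are intended to compose at inference time. It assumes model.py and utils.py from this patch are importable; 'sample.flac' is a hypothetical 16 kHz mono recording, and without loading a trained checkpoint (via load_checkpoint, which also requires the optimizer and scheduler objects from train.py) the printed transcript will be meaningless.

```python
import torch
import torchaudio

from model import ConformerEncoder, LSTMDecoder
from utils import GreedyCharacterDecoder, TextTransform

# Build the Conformer (S)-sized encoder/decoder pair used elsewhere in this patch.
encoder = ConformerEncoder(d_input=80, d_model=144)
decoder = LSTMDecoder(d_encoder=144, d_decoder=320, num_classes=29)
char_decoder = GreedyCharacterDecoder()
text_transform = TextTransform()
encoder.eval()
decoder.eval()

# 'sample.flac' is a placeholder for any 16 kHz LibriSpeech-style utterance.
waveform, sample_rate = torchaudio.load('sample.flac')
mel = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80, hop_length=160)
spec = mel(waveform).squeeze(0).transpose(0, 1).unsqueeze(0)  # (1, time, 80)

with torch.no_grad():
    logits = decoder(encoder(spec))      # (1, subsampled_time, num_classes)
    inds = char_decoder(logits)          # argmax + collapse of repeated CTC predictions
    print(text_transform.int_to_text(inds[0]))  # blank tokens (index 28) are dropped here
```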