From d2323f0d98c2976118c4ff94283b917ae607a32c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 30 Sep 2021 14:39:35 +0000
Subject: [PATCH] Implement DeepSpeech

---
 TTS/stt/models/deep_speech.py | 214 ++++++++++++++++++++++++++++++++++
 1 file changed, 214 insertions(+)
 create mode 100644 TTS/stt/models/deep_speech.py

diff --git a/TTS/stt/models/deep_speech.py b/TTS/stt/models/deep_speech.py
new file mode 100644
index 00000000..8f1560a9
--- /dev/null
+++ b/TTS/stt/models/deep_speech.py
@@ -0,0 +1,214 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import jiwer
+import numpy as np
+import torch
+from coqpit import Coqpit
+from torch import nn
+from torch.cuda.amp.autocast_mode import autocast
+from torch.nn import functional as F
+
+from TTS.stt.datasets.tokenizer import Tokenizer
+from TTS.stt.models.base_stt import BaseSTT
+from TTS.tts.utils.helpers import sequence_mask
+
+
+def compute_wer(logits, targets, logits_lengths, target_lengths, tokenizer):
+    pred_ids = np.argmax(logits, axis=-1)
+
+    pred_texts = []
+    for idx, pi in enumerate(pred_ids):
+        pi = pi[: logits_lengths[idx]]
+        tokens = tokenizer.decode(pi)
+        text = tokenizer.tokens_to_string(tokens)
+        pred_texts.append("".join(text))
+
+    label_texts = []
+    for idx, ti in enumerate(targets):
+        label = "".join(tokenizer.decode(ti))
+        label = label[: target_lengths[idx]]
+        label = label.replace(tokenizer.word_del_token, " ")
+        label_texts.append(label)
+
+    wer = jiwer.wer(label_texts, pred_texts)
+    return wer, label_texts, pred_texts
+
+
+@dataclass
+class DeepSpeechArgs(Coqpit):
+    n_tokens: int = 0  # Number of characters plus one for the CTC blank label.
+    n_context: int = 9  # Frames of context on each side - the full window spans 2 * n_context + 1 frames.
+    context_step: int = 1  # The number of steps taken between consecutive context windows.
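+    # Worked example with the defaults below: n_context=9 and n_input=26 give
+    # a window of 2 * 9 + 1 = 19 frames, so layer_1 takes
+    # (2 * n_context + 1) * n_input = 19 * 26 = 494 input features.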
+    n_input: int = 26  # Number of MFCC features - TODO: Determine this programmatically from the sample rate.
+    n_hidden_1: int = 2048
+    dropout_1: float = 0.1
+    n_hidden_2: int = 2048
+    dropout_2: float = 0.1
+    n_hidden_3: int = 2048  # The number of units in the third layer, which feeds into the LSTM.
+    dropout_3: float = 0.1
+    n_hidden_4: int = 2048
+    n_hidden_5: int = 2048
+    dropout_5: float = 0.1
+    layer_norm: bool = False  # TODO: add layer norm layers to the model
+    relu_clip: float = 20.0
+
+
+class ClippedReLU(nn.Module):
+    def __init__(self, clip=20.0):
+        super().__init__()
+        self.clip = clip
+
+    def forward(self, x):
+        return torch.clip(F.relu(x), max=self.clip)
+
+
+class DeepSpeech(BaseSTT):
+    def __init__(self, config: Union["DeepSpeechConfig", DeepSpeechArgs]):
+        super().__init__(config)
+
+        if hasattr(self, "config"):
+            self.tokenizer = Tokenizer(vocab_dict=self.config.vocabulary)
+
+        self.layer_1 = nn.Linear((self.args.n_context * 2 + 1) * self.args.n_input, self.args.n_hidden_1)
+        self.dropout_1 = nn.Dropout(p=self.args.dropout_1)
+        self.relu_1 = ClippedReLU(self.args.relu_clip)
+        self.layer_2 = nn.Linear(self.args.n_hidden_1, self.args.n_hidden_2)
+        self.dropout_2 = nn.Dropout(p=self.args.dropout_2)
+        self.relu_2 = ClippedReLU(self.args.relu_clip)
+        self.layer_3 = nn.Linear(self.args.n_hidden_2, self.args.n_hidden_3)
+        self.dropout_3 = nn.Dropout(p=self.args.dropout_3)
+        self.relu_3 = ClippedReLU(self.args.relu_clip)
+        self.lstm = nn.LSTM(self.args.n_hidden_3, self.args.n_hidden_4, num_layers=1, batch_first=True)
+        self.layer_5 = nn.Linear(self.args.n_hidden_4, self.args.n_hidden_5)
+        self.dropout_5 = nn.Dropout(p=self.args.dropout_5)
+        self.relu_5 = ClippedReLU(self.args.relu_clip)
+        self.layer_6 = nn.Linear(self.args.n_hidden_5, self.args.n_tokens)
+
+    @staticmethod
+    def logits_to_text(logits, tokenizer):
+        pred_ids = np.argmax(logits, axis=-1)
+
+        pred_texts = []
+        for pi in pred_ids:
+            pred_texts.append("".join(tokenizer.decode(pi)))
+        return pred_texts
+
+    def reshape_input(self, x: torch.Tensor) -> torch.Tensor:
+        """Format the input tensor into overlapping context windows.
+
+        Args:
+            x (Tensor): input tensor.
+
+        Returns:
+            Tensor: formatted tensor whose last dimension is the flattened context window.
+
+        Shapes:
+            x : :math:`[B, T, C]`
+            Tensor: :math:`[B, T/context_step, C*2*n_context+C]`
+        """
+        # TODO: might need to pad x before unfold for the last window
+        x = x.unfold(1, 2 * self.args.n_context + 1, self.args.context_step).permute(0, 1, 3, 2)
+        x = x.reshape(x.shape[0], -1, self.args.n_input * 2 * self.args.n_context + self.args.n_input)
+        return x
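+
+    # Shape sketch for reshape_input with the defaults (n_context=9, context_step=1):
+    #   [2, 57, 26] --unfold(1, 19, 1)--> [2, 39, 26, 19]   # 39 = (57 - 19) // 1 + 1
+    #   --permute(0, 1, 3, 2)--> [2, 39, 19, 26] --reshape--> [2, 39, 494]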
+
+    def recompute_input_lengths(self, x):
+        # Number of unfold windows: floor((T - window) / step) + 1
+        return (x - (2 * self.args.n_context + 1)) // self.args.context_step + 1
+
+    def forward(self, x, previous_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
+        x = self.layer_1(x)
+        x = self.dropout_1(x)
+        x = self.relu_1(x)
+        x = self.layer_2(x)
+        x = self.dropout_2(x)
+        x = self.relu_2(x)
+        x = self.layer_3(x)
+        x = self.dropout_3(x)
+        x = self.relu_3(x)
+        # nn.LSTM returns (output, (h_n, c_n)) - keep the sequence output for the
+        # next layer and return the states for stateful inference.
+        x, (new_state_h, new_state_c) = self.lstm(x, previous_states)
+        x = self.layer_5(x)
+        x = self.dropout_5(x)
+        x = self.relu_5(x)
+        x = self.layer_6(x)
+        x = x.log_softmax(dim=-1)
+        return {"model_outputs": x, "model_states": (new_state_h, new_state_c)}
+
+    @torch.no_grad()
+    def inference(self, x, aux_inputs=None):
+        previous_states = aux_inputs["previous_states"] if aux_inputs is not None else None
+        return self.forward(x, previous_states)
+
+    def train_step(self, batch, criterion):
+        outputs = self.forward(batch["features"], None)
+
+        if self.config.loss_masking:
+            output_mask = sequence_mask(batch["feature_lengths"])
+            outputs["model_outputs"] = outputs["model_outputs"].masked_fill(
+                ~output_mask.unsqueeze(-1), self.tokenizer.pad_token_id
+            )
+            target_mask = sequence_mask(batch["token_lengths"])
+            target_lengths = target_mask.sum(-1)
+            flattened_targets = batch["tokens"].masked_select(target_mask)
+
+        # with autocast(enabled=False):  # avoid mixed_precision in criterion
+        with torch.backends.cudnn.flags(enabled=False):
+            loss = criterion(
+                outputs["model_outputs"].transpose(0, 1),  # [B, T, C] -> [T, B, C]
+                flattened_targets,
+                batch["feature_lengths"],
+                target_lengths,
+            )
+
+        wer, target_texts, pred_texts = compute_wer(
+            outputs["model_outputs"].detach().cpu().numpy(),
+            batch["tokens"].cpu().numpy(),
+            batch["feature_lengths"].cpu().numpy(),
+            batch["token_lengths"].cpu().numpy(),
+            self.tokenizer,
+        )
+
+        # for visualization
+        outputs["target_texts"] = target_texts
+        outputs["pred_texts"] = pred_texts
+
+        loss_dict = {"loss": loss, "wer": wer}
+        return outputs, loss_dict
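+
+    # Note on the criterion call above: nn.CTCLoss expects log-probabilities
+    # shaped [T, B, C] (hence the transpose), targets flattened into a 1D
+    # tensor, plus per-sample input and target lengths. cuDNN is disabled
+    # around the loss, likely to avoid the stricter input/target-length
+    # constraints of the cuDNN CTC kernel.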
+
+    def eval_step(self, batch, criterion):
+        return self.train_step(batch, criterion)
+
+    def train_log(self, batch, outputs, dashboard_logger, assets, step, prefix="Train - "):
+        pred_log_text = "\n".join([f"{idx} - {pred}" for idx, pred in enumerate(outputs["pred_texts"])])
+        target_log_text = "\n".join([f"{idx} - {target}" for idx, target in enumerate(outputs["target_texts"])])
+
+        dashboard_logger.add_text(f"{prefix}Last Batch Predictions:", f"\n{pred_log_text}\n", step)
+        dashboard_logger.add_text(f"{prefix}Last Batch Targets:", f"\n{target_log_text}\n", step)
", step) + + def eval_log(self, batch, outputs, dashboard_logger, assets, step): + self.train_log(batch, outputs, dashboard_logger, assets, step, prefix="Eval - ") + + def get_criterion(self): + ctc_loss = nn.CTCLoss(blank=self.config.vocabulary["[PAD]"], reduction="mean") + return ctc_loss + + def format_batch(self, batch): + batch["features"] = self.reshape_input(batch["features"]) + batch["feature_lengths"] = self.recompute_input_lengths(batch["feature_lengths"]).long() + assert ( + batch["feature_lengths"].max() == batch["features"].shape[1] + ), f"{batch['feature_lengths'].max()} vs {batch['features'].shape[1]}" + return batch + + def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: + return None + + +if __name__ == "__main__": + + args = DeepSpeechArgs(n_input=23, n_tokens=13) + model = DeepSpeech(args) + input_tensor = torch.rand(2, 57, 23) + input_tensor = model.reshape_input(input_tensor) + outputs = model.forward(input_tensor, None) + print(outputs["model_outputs"].shape)