from dataclasses import dataclass
from typing import Union

import jiwer
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as f

from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.stt.models.base_stt import BaseSTT
from TTS.tts.utils.helpers import sequence_mask


def compute_wer(logits, targets, logits_lenghts, target_lengths, tokenizer):
    """Greedy-decode a batch of model outputs and compute the word error rate.

    Args:
        logits (np.ndarray): per-frame scores; the argmax over the last axis is taken as the
            predicted token id. Presumably ``[B, T, n_tokens]`` — TODO confirm with callers.
        targets (np.ndarray): padded target token ids, one row per sample.
        logits_lenghts (np.ndarray): number of valid frames per sample in ``logits``.
        target_lengths (np.ndarray): number of valid tokens per sample in ``targets``.
        tokenizer (Tokenizer): maps token ids back to text.

    Returns:
        tuple: ``(wer, label_texts, pred_texts)``, where ``wer`` is ``jiwer.wer`` computed over
        the whole batch and the two lists hold the decoded reference / hypothesis strings.
    """
    pred_ids = np.argmax(logits, axis=-1)

    pred_texts = []
    for ids, length in zip(pred_ids, logits_lenghts):
        # Drop padded frames before decoding.
        tokens = tokenizer.decode(ids[:length])
        text = tokenizer.tokens_to_string(tokens)
        pred_texts.append("".join(text))

    label_texts = []
    for ids, length in zip(targets, target_lengths):
        # Truncate the token *ids* (not the decoded string) to the valid length, so padding is
        # removed correctly even when a token decodes to more than one character.
        label = "".join(tokenizer.decode(ids[:length]))
        label = label.replace(tokenizer.word_del_token, " ")
        label_texts.append(label)

    wer = jiwer.wer(label_texts, pred_texts)
    return wer, label_texts, pred_texts


@dataclass
class DeepSpeechArgs(Coqpit):
    """Architecture hyper-parameters of the DeepSpeech model."""

    n_tokens: int = 0  # The number of characters and +1 for CTC blank label
    # The number of context frames on each side of the current frame. Each window therefore
    # holds 2 * n_context + 1 frames, i.e. (2 * n_context + 1) * n_input input features.
    n_context: int = 9
    context_step: int = 1  # The number of frames taken between consecutive context windows.
    n_input: int = 26  # Number of MFCC features - TODO: Determine this programmatically from the sample rate
    n_hidden_1: int = 2048
    dropout_1: float = 0.1
    n_hidden_2: int = 2048
    dropout_2: float = 0.1
    n_hidden_3: int = 2048  # The number of units in the third layer, which feeds in to the LSTM
    dropout_3: float = 0.1
    n_hidden_4: int = 2048  # LSTM hidden size
    n_hidden_5: int = 2048
    dropout_5: float = 0.1
    layer_norm: bool = False  # TODO: add layer norm layers to the model (currently unused)
    relu_clip: float = 20.0  # Upper bound applied by every ClippedReLU


class ClippedReLU(nn.Module):
    """ReLU whose output is additionally clipped from above: ``min(max(x, 0), clip)``."""

    def __init__(self, clip=20.0):
        super().__init__()
        self.clip = clip

    def forward(self, x):
        return torch.clip(f.relu(x), max=self.clip)


class DeepSpeech(BaseSTT):
    """DeepSpeech-style acoustic model trained with CTC.

    Architecture: three clipped-ReLU feed-forward layers over stacked context windows, one
    unidirectional LSTM layer, one clipped-ReLU feed-forward layer, and a linear projection to
    ``n_tokens`` log-probabilities.
    """

    def __init__(self, config: Union["DeepSpeechConfig", DeepSpeechArgs]):
        super().__init__(config)
        # ``self.config`` / ``self.args`` are presumably set by ``BaseSTT.__init__``; a tokenizer
        # is only built when a full config (with a vocabulary) is available.
        if hasattr(self, "config"):
            self.tokenizer = Tokenizer(vocab_dict=self.config.vocabulary)

        # Input of layer 1 is one flattened context window of 2 * n_context + 1 frames.
        self.layer_1 = nn.Linear((self.args.n_context * 2 + 1) * self.args.n_input, self.args.n_hidden_1)
        self.dropout_1 = nn.Dropout(p=self.args.dropout_1)
        self.relu_1 = ClippedReLU(self.args.relu_clip)
        self.layer_2 = nn.Linear(self.args.n_hidden_1, self.args.n_hidden_2)
        self.dropout_2 = nn.Dropout(p=self.args.dropout_2)
        self.relu_2 = ClippedReLU(self.args.relu_clip)
        self.layer_3 = nn.Linear(self.args.n_hidden_2, self.args.n_hidden_3)
        self.dropout_3 = nn.Dropout(p=self.args.dropout_3)
        self.relu_3 = ClippedReLU(self.args.relu_clip)

        self.lstm = nn.LSTM(self.args.n_hidden_3, self.args.n_hidden_4, num_layers=1, batch_first=True)

        self.layer_5 = nn.Linear(self.args.n_hidden_4, self.args.n_hidden_5)
        self.dropout_5 = nn.Dropout(p=self.args.dropout_5)
        self.relu_5 = ClippedReLU(self.args.relu_clip)
        self.layer_6 = nn.Linear(self.args.n_hidden_5, self.args.n_tokens)

    @staticmethod
    def logits_to_text(logits, tokenizer):
        """Greedy (argmax) decode each sample of ``logits`` into a string, without length masking."""
        pred_ids = np.argmax(logits, axis=-1)
        return ["".join(tokenizer.decode(ids)) for ids in pred_ids]

    def reshape_input(self, x: torch.Tensor) -> torch.Tensor:
        """Stack neighbouring frames into flattened context windows.

        Args:
            x (Tensor): input features.

        Returns:
            Tensor: one flattened window of ``2 * n_context + 1`` frames per output step.

        Shapes:
            x : :math:`[B, T, C]` with ``C == n_input``
            Tensor: :math:`[B, T', (2 * n_context + 1) * C]` with
                ``T' = (T - (2 * n_context + 1)) // context_step + 1``
        """
        # TODO: might need to pad x before unfold for the last window
        # unfold -> [B, T', C, window]; permute -> [B, T', window, C] so frames stay contiguous.
        x = x.unfold(1, 2 * self.args.n_context + 1, self.args.context_step).permute(0, 1, 3, 2)
        x = x.reshape(x.shape[0], -1, self.args.n_input * 2 * self.args.n_context + self.args.n_input)
        return x

    def recompute_input_lengths(self, x):
        """Map frame counts to the number of context windows produced by ``reshape_input``.

        Returns a float tensor; ``format_batch`` casts it with ``.long()``, which truncates and
        therefore equals the floor for lengths of at least one full window.
        """
        return ((x - (2 * self.args.n_context + 1)) / self.args.context_step) + 1

    def forward(self, x, previous_states: Union[tuple[torch.Tensor, torch.Tensor], None] = None):
        """Run the network on context-window features.

        Args:
            x (Tensor): output of ``reshape_input``, ``[B, T', (2 * n_context + 1) * n_input]``.
            previous_states (tuple, optional): LSTM ``(h_0, c_0)``, each ``[1, B, n_hidden_4]``.
                ``None`` lets the LSTM start from zero states.

        Returns:
            dict: ``model_outputs`` - log-probabilities ``[B, T', n_tokens]``;
            ``model_states`` - LSTM ``(h_n, c_n)`` that can be passed back as ``previous_states``.
        """
        x = self.layer_1(x)
        x = self.dropout_1(x)
        x = self.relu_1(x)
        x = self.layer_2(x)
        x = self.dropout_2(x)
        x = self.relu_2(x)
        x = self.layer_3(x)
        x = self.dropout_3(x)
        x = self.relu_3(x)
        # nn.LSTM returns (output, (h_n, c_n)); the per-step output feeds the next layer and the
        # final states are what must be carried over to a subsequent call.
        x, (state_h, state_c) = self.lstm(x, previous_states)
        x = self.layer_5(x)
        x = self.dropout_5(x)
        x = self.relu_5(x)
        x = self.layer_6(x)
        x = x.log_softmax(dim=-1)
        return {"model_outputs": x, "model_states": (state_h, state_c)}

    @torch.no_grad()
    def inference(self, x, aux_inputs=None):
        """Gradient-free ``forward``; ``aux_inputs["previous_states"]`` optionally carries LSTM states."""
        if aux_inputs is None:
            aux_inputs = {}
        return self.forward(x, aux_inputs.get("previous_states"))

    def train_step(self, batch, criterion):
        """Compute the CTC loss and greedy-decoding WER for one formatted batch.

        Returns:
            tuple: ``(outputs, loss_dict)``; ``outputs`` also carries the decoded target and
            predicted texts for logging, ``loss_dict`` holds ``loss`` and ``wer``.
        """
        outputs = self.forward(batch["features"], None)
        if self.config.loss_masking:
            output_mask = sequence_mask(batch["feature_lengths"])
            # NOTE(review): this writes the pad token id into the log-probabilities of padded
            # frames. CTCLoss only reads the first ``feature_lengths`` frames and compute_wer
            # truncates to the same lengths, so this appears not to affect loss/WER — confirm intent.
            outputs["model_outputs"] = outputs["model_outputs"].masked_fill(
                ~output_mask.unsqueeze(-1), self.tokenizer.pad_token_id
            )

        # CTCLoss accepts concatenated (1-D) targets, so strip the padding from each row.
        target_mask = sequence_mask(batch["token_lengths"])
        target_lengths = target_mask.sum(-1)
        flattened_targets = batch["tokens"].masked_select(target_mask)

        # TODO: consider ``with autocast(enabled=False):`` to keep the criterion out of mixed precision.
        # cuDNN is disabled around the criterion, presumably to force the native CTC kernel
        # (cuDNN's CTC only supports a restricted set of inputs) — confirm.
        with torch.backends.cudnn.flags(enabled=False):
            loss = criterion(
                outputs["model_outputs"].transpose(0, 1),  # [B, T, C] -> [T, B, C]
                flattened_targets,
                batch["feature_lengths"],
                target_lengths,
            )

        wer, target_texts, pred_texts = compute_wer(
            outputs["model_outputs"].detach().cpu().numpy(),
            batch["tokens"].cpu().numpy(),
            batch["feature_lengths"].cpu().numpy(),
            batch["token_lengths"].cpu().numpy(),
            self.tokenizer,
        )

        # for visualization
        outputs["target_texts"] = target_texts
        outputs["pred_texts"] = pred_texts
        loss_dict = {"loss": loss, "wer": wer}
        return outputs, loss_dict

    def eval_step(self, batch, criterion):
        """Evaluation uses the same computation as training."""
        return self.train_step(batch, criterion)

    def train_log(self, batch, outputs, dashboard_logger, assets, step, prefix="Train - "):
        """Log the last batch's numbered predictions and targets as text to the dashboard."""
        pred_log_text = "\n".join([f"{idx} - {pred}" for idx, pred in enumerate(outputs["pred_texts"])])
        target_log_text = "\n".join([f"{idx} - {pred}" for idx, pred in enumerate(outputs["target_texts"])])
        dashboard_logger.add_text(f"{prefix} - Last Batch Predictions:", f"\n{pred_log_text}", step)
        dashboard_logger.add_text(f"{prefix} - Last Batch Target:", f"\n{target_log_text}", step)

    def eval_log(self, batch, outputs, dashboard_logger, assets, step):
        self.train_log(batch, outputs, dashboard_logger, assets, step, prefix="Eval - ")

    def get_criterion(self):
        """CTC loss whose blank label is the vocabulary's ``[PAD]`` token."""
        ctc_loss = nn.CTCLoss(blank=self.config.vocabulary["[PAD]"], reduction="mean")
        return ctc_loss

    def format_batch(self, batch):
        """Convert raw features into context windows and update their lengths accordingly."""
        batch["features"] = self.reshape_input(batch["features"])
        batch["feature_lengths"] = self.recompute_input_lengths(batch["feature_lengths"]).long()
        # Internal sanity check: the longest sample must fill the windowed time axis exactly.
        assert (
            batch["feature_lengths"].max() == batch["features"].shape[1]
        ), f"{batch['feature_lengths'].max()} vs {batch['features'].shape[1]}"
        return batch

    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
        """NOTE(review): stub — does not load any weights from ``checkpoint_path``."""
        return None


if __name__ == "__main__":
    args = DeepSpeechArgs(n_input=23, n_tokens=13)
    model = DeepSpeech(args)
    input_tensor = torch.rand(2, 57, 23)
    input_tensor = model.reshape_input(input_tensor)
    outputs = model.forward(input_tensor, None)
    print(outputs["model_outputs"].shape)