Implement DeepSpeech

This commit is contained in:
Eren Gölge 2021-09-30 14:39:35 +00:00
parent 89cbfbc829
commit d2323f0d98
1 changed files with 214 additions and 0 deletions

View File

@ -0,0 +1,214 @@
from dataclasses import dataclass
from typing import Union
import jiwer
import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as f
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.stt.models.base_stt import BaseSTT
from TTS.tts.utils.helpers import sequence_mask
def compute_wer(logits, targets, logits_lenghts, target_lengths, tokenizer):
pred_ids = np.argmax(logits, axis=-1)
pred_texts = []
for idx, pi in enumerate(pred_ids):
pi = pi[: logits_lenghts[idx]]
tokens = tokenizer.decode(pi)
text = tokenizer.tokens_to_string(tokens)
pred_texts.append("".join(text))
label_texts = []
for idx, pi in enumerate(targets):
label = "".join(tokenizer.decode(pi))
label = label[: target_lengths[idx]]
label = label.replace(tokenizer.word_del_token, " ")
label_texts.append(label)
wer = jiwer.wer(label_texts, pred_texts)
return wer, label_texts, pred_texts
@dataclass
class DeepSpeechArgs(Coqpit):
n_tokens: int = 0 # The number of characters and +1 for CTC blank label
n_context: int = (
9 # The number of frames in the context - translates to the windows size of 2 * n_context + n_input.
)
context_step: int = 1 # The number of steps take between each context window.
n_input: int = 26 # Number of MFCC features - TODO: Determine this programmatically from the sample rate
n_hidden_1: int = 2048
dropout_1: float = 0.1
n_hidden_2: int = 2048
dropout_2: float = 0.1
n_hidden_3: int = 2048 # The number of units in the third layer, which feeds in to the LSTM
dropout_3: float = 0.1
n_hidden_4: int = 2048
n_hidden_5: int = 2048
dropout_5: float = 0.1
layer_norm: bool = False # TODO: add layer norm layers to the model
relu_clip: float = 20.0
class ClippedReLU(nn.Module):
def __init__(self, clip=20.0):
super().__init__()
self.clip = clip
def forward(self, x):
return torch.clip(f.relu(x), max=self.clip)
class DeepSpeech(BaseSTT):
def __init__(self, config: Union["DeepSpeechConfig", DeepSpeechArgs]):
super().__init__(config)
if hasattr(self, "config"):
self.tokenizer = Tokenizer(vocab_dict=self.config.vocabulary)
self.layer_1 = nn.Linear((self.args.n_context * 2 + 1) * self.args.n_input, self.args.n_hidden_1)
self.dropout_1 = nn.Dropout(p=self.args.dropout_1)
self.relu_1 = ClippedReLU(self.args.relu_clip)
self.layer_2 = nn.Linear(self.args.n_hidden_1, self.args.n_hidden_2)
self.dropout_2 = nn.Dropout(p=self.args.dropout_2)
self.relu_2 = ClippedReLU(self.args.relu_clip)
self.layer_3 = nn.Linear(self.args.n_hidden_2, self.args.n_hidden_3)
self.dropout_3 = nn.Dropout(p=self.args.dropout_3)
self.relu_3 = ClippedReLU(self.args.relu_clip)
self.lstm = nn.LSTM(self.args.n_hidden_3, self.args.n_hidden_4, num_layers=1, batch_first=True)
self.layer_5 = nn.Linear(self.args.n_hidden_4, self.args.n_hidden_5)
self.dropout_5 = nn.Dropout(p=self.args.dropout_5)
self.relu_5 = ClippedReLU(self.args.relu_clip)
self.layer_6 = nn.Linear(self.args.n_hidden_5, self.args.n_tokens)
@staticmethod
def logits_to_text(logits, tokenizer):
pred_ids = np.argmax(logits, axis=-1)
pred_texts = []
for pi in pred_ids:
pred_texts.append("".join(tokenizer.decode(pi)))
return pred_texts
def reshape_input(self, x: torch.Tensor) -> torch.Tensor:
"""Format the input tensor to define the context windows.
Args:
x (Tensor): input tensor.
Returns:
Tensor: formatted tensor with the last dimension is the context window.
Shapes:
x : :math:`[B, T, C]`
Tensor: :math:`[B, T/context_step, C*2*n_context+C]`
"""
# TODO: might need to pad x before unfold for the last window
x = x.unfold(1, 2 * self.args.n_context + 1, self.args.context_step).permute(0, 1, 3, 2)
x = x.reshape(x.shape[0], -1, self.args.n_input * 2 * self.args.n_context + self.args.n_input)
return x
def recompute_input_lengths(self, x):
return ((x - (2 * self.args.n_context + 1)) / self.args.context_step) + 1
def forward(self, x, previous_states: tuple[torch.Tensor, torch.Tensor]):
x = self.layer_1(x)
x = self.dropout_1(x)
x = self.relu_1(x)
x = self.layer_2(x)
x = self.dropout_2(x)
x = self.relu_2(x)
x = self.layer_3(x)
x = self.dropout_3(x)
x = self.relu_3(x)
new_state_h, new_state_c = self.lstm(x, previous_states)
x = self.layer_5(new_state_h)
x = self.dropout_5(x)
x = self.relu_5(x)
x = self.layer_6(x)
x = x.log_softmax(dim=-1)
return {"model_outputs": x, "model_states": (new_state_h, new_state_c)}
@torch.no_grad()
def inference(self, x, aux_inputs={"previous_states": None}):
return self.forward(x, aux_inputs["previous_states"])
def train_step(self, batch, criterion):
outputs = self.forward(batch["features"], None)
if self.config.loss_masking:
output_mask = sequence_mask(batch["feature_lengths"])
outputs["model_outputs"] = outputs["model_outputs"].masked_fill(
~output_mask.unsqueeze(-1), self.tokenizer.pad_token_id
)
target_mask = sequence_mask(batch["token_lengths"])
target_lengths = target_mask.sum(-1)
flattened_targets = batch["tokens"].masked_select(target_mask)
# with autocast(enabled=False): # avoid mixed_precision in criterion
with torch.backends.cudnn.flags(enabled=False):
loss = criterion(
outputs["model_outputs"].transpose(0, 1), # [B, T, C] -> [T, B, C]
flattened_targets,
batch["feature_lengths"],
target_lengths,
)
wer, target_texts, pred_texts = compute_wer(
outputs["model_outputs"].detach().cpu().numpy(),
batch["tokens"].cpu().numpy(),
batch["feature_lengths"].cpu().numpy(),
batch["token_lengths"].cpu().numpy(),
self.tokenizer,
)
# for visualization
outputs["target_texts"] = target_texts
outputs["pred_texts"] = pred_texts
loss_dict = {"loss": loss, "wer": wer}
return outputs, loss_dict
def eval_step(self, batch, criterion):
return self.train_step(batch, criterion)
def train_log(self, batch, outputs, dashboard_logger, assets, step, prefix="Train - "):
pred_log_text = "\n".join([f"{idx} - {pred}" for idx, pred in enumerate(outputs["pred_texts"])])
target_log_text = "\n".join([f"{idx} - {pred}" for idx, pred in enumerate(outputs["target_texts"])])
dashboard_logger.add_text(f"{prefix} - Last Batch Predictions:", f"<pre>{pred_log_text}</pre>", step)
dashboard_logger.add_text(f"{prefix} - Last Batch Target:", f"<pre>{target_log_text}</pre>", step)
def eval_log(self, batch, outputs, dashboard_logger, assets, step):
self.train_log(batch, outputs, dashboard_logger, assets, step, prefix="Eval - ")
def get_criterion(self):
ctc_loss = nn.CTCLoss(blank=self.config.vocabulary["[PAD]"], reduction="mean")
return ctc_loss
def format_batch(self, batch):
batch["features"] = self.reshape_input(batch["features"])
batch["feature_lengths"] = self.recompute_input_lengths(batch["feature_lengths"]).long()
assert (
batch["feature_lengths"].max() == batch["features"].shape[1]
), f"{batch['feature_lengths'].max()} vs {batch['features'].shape[1]}"
return batch
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
return None
if __name__ == "__main__":
args = DeepSpeechArgs(n_input=23, n_tokens=13)
model = DeepSpeech(args)
input_tensor = torch.rand(2, 57, 23)
input_tensor = model.reshape_input(input_tensor)
outputs = model.forward(input_tensor, None)
print(outputs["model_outputs"].shape)