diff --git a/TTS/stt/models/base_stt.py b/TTS/stt/models/base_stt.py
new file mode 100644
index 00000000..ff286af6
--- /dev/null
+++ b/TTS/stt/models/base_stt.py
@@ -0,0 +1,233 @@
+import os
+from typing import Dict, List, Tuple
+
+import torch
+import torch.distributed as dist
+from coqpit import Coqpit
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from TTS.model import BaseModel
+from TTS.stt.datasets.dataset import STTDataset
+from TTS.stt.datasets.tokenizer import Tokenizer
+from TTS.tts.datasets import TTSDataset
+from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
+from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.text import make_symbols
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor
+
+# pylint: skip-file
+
+
+class BaseSTT(BaseModel):
+    """Abstract `stt` class. Every new `stt` model must inherit this.
+
+    It defines `stt` specific functions on top of `Model`.
+
+    Notes on input/output tensor shapes:
+        Any input or output tensor of the model must be shaped as
+
+        - 3D tensors `batch x time x channels`
+        - 2D tensors `batch x channels`
+        - 1D tensors `batch x 1`
+    """
+
+    def _set_model_args(self, config: Coqpit):
+        """Setup model args based on the config type.
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        `config.model_args`.
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance to avoid recursive imports
+        if "Config" in config.__class__.__name__:
+            if "vocabulary" in config and config.vocabulary is not None:
+                # loading from DeepSpeechConfig
+                self.config = config
+                self.args = config.model_args
+                self.args.n_tokens = len(self.config.vocabulary)
+            else:
+                # loading from a DeepSpeechConfig without a vocabulary
+                self.config = config
+                self.args = config.model_args
+        else:
+            self.args = config
+
+    @staticmethod
+    def get_characters(config: Coqpit) -> Tuple[List, Coqpit, int]:
+        # TODO: implement CharacterProcessor
+        if config.characters is not None:
+            symbols, phonemes = make_symbols(**config.characters)
+        else:
+            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
+
+            config.characters = parse_symbols()
+        model_characters = phonemes if config.use_phonemes else symbols
+        num_chars = len(model_characters) + getattr(config, "add_blank", False)
+        return model_characters, config, num_chars
+
+    def get_aux_input(self, **kwargs) -> Dict:
+        """Prepare and return `aux_input` used by `forward()`"""
+        ...
+
+    def format_batch(self, batch: Dict) -> Dict:
+        """Generic batch formatting for `TTSDataset`.
+
+        You must override this if you use a custom dataset.
+
+        Args:
+            batch (Dict): raw batch returned by the dataset `collate_fn`.
+
+        Returns:
+            Dict: formatted batch ready to be consumed by `train_step()` and `eval_step()`.
+        """
+        # setup input batch
+        text_input = batch["text"]
+        text_lengths = batch["text_lengths"]
+        speaker_names = batch["speaker_names"]
+        linear_input = batch["linear"]
+        mel_input = batch["mel"]
+        mel_lengths = batch["mel_lengths"]
+        stop_targets = batch["stop_targets"]
+        item_idx = batch["item_idxs"]
+        d_vectors = batch["d_vectors"]
+        speaker_ids = batch["speaker_ids"]
+        attn_mask = batch["attns"]
+        waveform = batch["waveform"]
+        pitch = batch["pitch"]
+        max_text_length = torch.max(text_lengths.float())
+        max_spec_length = torch.max(mel_lengths.float())
+
+        # compute durations from attention masks
+        durations = None
+        if attn_mask is not None:
+            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
+            for idx, am in enumerate(attn_mask):
+                # compute raw durations
+                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
+                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
+                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
+                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
+                dur[c_idxs] = counts
+                # smooth the durations and set any 0 duration to 1
+                # by cutting off from the largest duration indices.
+                extra_frames = dur.sum() - mel_lengths[idx]
+                largest_idxs = torch.argsort(-dur)[:extra_frames]
+                dur[largest_idxs] -= 1
+                assert (
+                    dur.sum() == mel_lengths[idx]
+                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
+                durations[idx, : text_lengths[idx]] = dur
+
+        # set stop targets wrt reduction factor
+        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
+        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
+        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
+
+        return {
+            "text_input": text_input,
+            "text_lengths": text_lengths,
+            "speaker_names": speaker_names,
+            "mel_input": mel_input,
+            "mel_lengths": mel_lengths,
+            "linear_input": linear_input,
+            "stop_targets": stop_targets,
+            "stop_target_lengths": stop_target_lengths,
+            "attn_mask": attn_mask,
+            "durations": durations,
+            "speaker_ids": speaker_ids,
+            "d_vectors": d_vectors,
+            "max_text_length": float(max_text_length),
+            "max_spec_length": float(max_spec_length),
+            "item_idx": item_idx,
+            "waveform": waveform,
+            "pitch": pitch,
+        }
+
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
+    ) -> "DataLoader":
+        if is_eval and not config.run_eval:
+            loader = None
+        else:
+            ap = assets["audio_processor"]
+            tokenizer = assets["tokenizer"]
+
+            # init dataset
+            dataset = STTDataset(
+                samples=data_items,
+                ap=ap,
+                tokenizer=tokenizer,
+                batch_group_size=config.batch_group_size,
+                sort_by_audio_len=config.sort_by_audio_len,
+                min_seq_len=config.min_seq_len,
+                max_seq_len=config.max_seq_len,
+                verbose=verbose,
+                feature_extractor=config.feature_extractor,
+            )
+
+            # halt DDP processes for the main process to finish computing the phoneme cache
+            if num_gpus > 1:
+                dist.barrier()
+
+            # sampler for DDP
+            sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
+
+            # init dataloader
+            loader = DataLoader(
+                dataset,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                shuffle=False,
+                collate_fn=dataset.collate_fn,
+                drop_last=False,
+                sampler=sampler,
+                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                pin_memory=False,
+            )
+        return loader
+
+    # def test_run(self, ap) -> Tuple[Dict, Dict]:
+    #     """Generic test run for `tts` models used by `Trainer`.
+
+    #     You can override this for a different behaviour.
+
+    #     Returns:
+    #         Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+    #     """
+    #     print(" | > Synthesizing test sentences.")
+    #     test_audios = {}
+    #     test_figures = {}
+    #     test_sentences = self.config.test_sentences
+    #     aux_inputs = self.get_aux_input()
+    #     for idx, sen in enumerate(test_sentences):
+    #         outputs_dict = synthesis(
+    #             self,
+    #             sen,
+    #             self.config,
+    #             "cuda" in str(next(self.parameters()).device),
+    #             ap,
+    #             speaker_id=aux_inputs["speaker_id"],
+    #             d_vector=aux_inputs["d_vector"],
+    #             style_wav=aux_inputs["style_wav"],
+    #             enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+    #             use_griffin_lim=True,
+    #             do_trim_silence=False,
+    #         )
+    #         test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
+    #         test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+    #             outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
+    #         )
+    #         test_figures["{}-alignment".format(idx)] = plot_alignment(
+    #             outputs_dict["outputs"]["alignments"], output_fig=False
+    #         )
+    #     return test_figures, test_audios
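
For reference, the sketch below (not part of the patch) shows how a trainer script might feed `BaseSTT.get_data_loader()`, whose only contract visible here is that `assets` carries an "audio_processor" and a "tokenizer". The config field names, the `Tokenizer`/`AudioProcessor` constructor arguments, and the sample format are assumptions for illustration only.

# Hypothetical usage sketch, not part of this patch. Any name below that is not
# defined in base_stt.py (config fields, Tokenizer/AudioProcessor constructor
# arguments, `train_samples`) is an assumption.
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.utils.audio import AudioProcessor


def build_train_loader(model, config, train_samples, num_gpus=1):
    # `assets` must provide the two keys read inside BaseSTT.get_data_loader()
    assets = {
        "audio_processor": AudioProcessor(**config.audio),  # assumed `config.audio` dict
        "tokenizer": Tokenizer(config.vocabulary),          # assumed Tokenizer signature
    }
    return model.get_data_loader(
        config=config,
        assets=assets,
        is_eval=False,             # training split; eval would use config.eval_batch_size
        data_items=train_samples,  # list of samples in the format STTDataset expects
        verbose=True,
        num_gpus=num_gpus,
        rank=0,
    )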