coqui-tts/TTS/stt/models/base_stt.py

234 lines
8.8 KiB
Python

import os
from typing import Dict, List, Tuple
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.stt.datasets.dataset import STTDataset
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.tts.datasets import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
# pylint: skip-file
class BaseSTT(BaseModel):
    """Abstract `stt` class. Every new `stt` model must inherit this.

    It defines `stt` specific functions on top of `BaseModel`.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type.

        If the config is for training with a name like "*Config", then the model args are embedded in
        ``config.model_args``; ``self.config`` is set to the config and ``self.args`` to its model args.
        When such a config also defines a non-``None`` ``vocabulary``, ``self.args.n_tokens`` is set to
        the vocabulary size.

        If the config is for the model with a name like "*Args", then it is assigned to ``self.args``
        directly and ``self.config`` is left untouched.

        Args:
            config (Coqpit): Either a training config (``*Config``) or model args (``*Args``).
        """
        # don't use isinstance not to import recursively
        if "Config" in config.__class__.__name__:
            # loading from a training config (e.g. DeepSpeechConfig)
            self.config = config
            self.args = config.model_args
            if "vocabulary" in config and config.vocabulary is not None:
                # the number of output tokens is derived from the vocabulary
                self.args.n_tokens = len(self.config.vocabulary)
        else:
            # loading from model args (e.g. DeepSpeechArgs)
            self.args = config

    @staticmethod
    def get_characters(config: Coqpit) -> Tuple[List, Coqpit, int]:
        """Resolve the character set used by the model.

        If ``config.characters`` is set, the symbol sets are built from it with ``make_symbols``.
        Otherwise the default symbol sets are used, and ``config.characters`` is filled **in place**
        with ``parse_symbols()``.

        Args:
            config (Coqpit): Config providing ``characters``, ``use_phonemes`` and optionally
                ``add_blank``.

        Returns:
            Tuple[List, Coqpit, int]: The selected symbols (phonemes if ``config.use_phonemes`` else
            characters), the (possibly updated) config, and the number of symbols, counting one extra
            when ``config.add_blank`` is truthy.
        """
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters)
        else:
            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols

            config.characters = parse_symbols()
        model_characters = phonemes if config.use_phonemes else symbols
        # bool add_blank counts as 0/1 here
        num_chars = len(model_characters) + getattr(config, "add_blank", False)
        return model_characters, config, num_chars

    def get_aux_input(self, **kwargs) -> Dict:
        """Prepare and return `aux_input` used by `forward()`.

        This base implementation is a stub and returns ``None``; subclasses are expected to override it.
        """
        ...

    def format_batch(self, batch: Dict) -> Dict:
        """Generic batch formatting for `TTSDataset`.

        You must override this if you use a custom dataset.

        NOTE(review): this reads TTS-specific keys (``linear``, ``stop_targets``, ``attns``, ``pitch``, ...)
        and ``self.config.r``, while ``get_data_loader`` builds an ``STTDataset``. Confirm that STT models
        override this method or that ``STTDataset.collate_fn`` emits these keys.

        Args:
            batch (Dict): Collated batch as produced by the dataset's ``collate_fn``.

        Returns:
            Dict: The input tensors under model-facing keys, plus the derived ``durations``,
            ``stop_target_lengths``, ``max_text_length`` and ``max_spec_length``.
        """
        # setup input batch
        text_input = batch["text"]
        text_lengths = batch["text_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        item_idx = batch["item_idxs"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        attn_mask = batch["attns"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            # attn_mask is indexed below as [batch][:, text, mel], so dim 2 is the text axis
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # for every mel frame, take the text index with the largest attention value
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                # number of frames assigned to each text index
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                # tokens that received no frame start with a minimum duration of 1
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indices so the total
                # matches the spectrogram length again.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # set stop targets wrt reduction factor: frames are grouped along time
        # (r per group, assuming the padded length is a multiple of r) and a
        # group is a stop target if any of its frames is.
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).float()
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
            "pitch": pitch,
        }

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        """Build the train or eval ``DataLoader`` over an ``STTDataset``.

        Args:
            config (Coqpit): Training config (batch sizes, worker counts, dataset filtering options,
                ``feature_extractor``, ``run_eval``).
            assets (Dict): Must provide ``"audio_processor"`` and ``"tokenizer"``.
            is_eval (bool): Build the evaluation loader instead of the training loader.
            data_items (List): Samples passed to the dataset.
            verbose (bool): Forwarded to the dataset.
            num_gpus (int): When > 1, processes are synchronized and a ``DistributedSampler`` is used.
            rank (int, optional): Currently unused; kept for interface compatibility. The
                ``DistributedSampler`` takes its rank from the initialized process group.

        Returns:
            DataLoader: The loader, or ``None`` when ``is_eval`` is set and ``config.run_eval`` is false.
        """
        if is_eval and not config.run_eval:
            return None

        ap = assets["audio_processor"]
        tokenizer = assets["tokenizer"]

        # init dataset
        dataset = STTDataset(
            samples=data_items,
            ap=ap,
            tokenizer=tokenizer,
            batch_group_size=config.batch_group_size,
            sort_by_audio_len=config.sort_by_audio_len,
            min_seq_len=config.min_seq_len,
            max_seq_len=config.max_seq_len,
            verbose=verbose,
            feature_extractor=config.feature_extractor,
        )

        # keep DDP processes in sync once every rank has built its dataset
        if num_gpus > 1:
            dist.barrier()

        # sampler for DDP; shuffling is delegated to the sampler, hence shuffle=False below
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None

        # init dataloader
        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_eval else config.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    # NOTE(review): TTS-style `test_run` kept commented out for reference. It calls the TTS
    # `synthesis` helper, which presumably does not apply as-is to an STT model — TODO port or remove.
    # def test_run(self, ap) -> Tuple[Dict, Dict]:
    #     """Generic test run for `tts` models used by `Trainer`.
    #     You can override this for a different behaviour.
    #     Returns:
    #         Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
    #     """
    #     print(" | > Synthesizing test sentences.")
    #     test_audios = {}
    #     test_figures = {}
    #     test_sentences = self.config.test_sentences
    #     aux_inputs = self.get_aux_input()
    #     for idx, sen in enumerate(test_sentences):
    #         outputs_dict = synthesis(
    #             self,
    #             sen,
    #             self.config,
    #             "cuda" in str(next(self.parameters()).device),
    #             ap,
    #             speaker_id=aux_inputs["speaker_id"],
    #             d_vector=aux_inputs["d_vector"],
    #             style_wav=aux_inputs["style_wav"],
    #             enable_eos_bos_chars=self.config.enable_eos_bos_chars,
    #             use_griffin_lim=True,
    #             do_trim_silence=False,
    #         )
    #         test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
    #         test_figures["{}-prediction".format(idx)] = plot_spectrogram(
    #             outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
    #         )
    #         test_figures["{}-alignment".format(idx)] = plot_alignment(
    #             outputs_dict["outputs"]["alignments"], output_fig=False
    #         )
    #     return test_figures, test_audios