mirror of https://github.com/coqui-ai/TTS.git
Implement BaseSTT
This commit is contained in:
parent d2323f0d98
commit 0c7a2eb948

@@ -0,0 +1,233 @@
import os
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.model import BaseModel
from TTS.stt.datasets.dataset import STTDataset
from TTS.stt.datasets.tokenizer import Tokenizer
from TTS.tts.datasets import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor

# pylint: skip-file


class BaseSTT(BaseModel):
    """Abstract `stt` class. Every new `stt` model must inherit this.

    It defines `stt` specific functions on top of `Model`.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    def _set_model_args(self, config: Coqpit):
        """Set up model args based on the config type.

        If the config is for training with a name like "*Config", then the model args are embedded in
        `config.model_args`.

        If the config is for the model with a name like "*Args", then we assign it directly.
        """
        # don't use isinstance to avoid recursive imports
        if "Config" in config.__class__.__name__:
            if "vocabulary" in config and config.vocabulary is not None:
                # loading from DeepSpeechConfig
                self.config = config
                self.args = config.model_args
                self.args.n_tokens = len(self.config.vocabulary)
            else:
                # loading from DeepSpeechArgs
                self.config = config
                self.args = config.model_args
        else:
            self.args = config
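
    # Usage sketch (illustrative only; `DeepSpeechConfig` and `DeepSpeechArgs` are the
    # classes referenced in the comments above and are not defined in this file):
    #
    #   model._set_model_args(DeepSpeechConfig(...))  # "*Config": args taken from config.model_args
    #   model._set_model_args(DeepSpeechArgs(...))    # "*Args": args assigned directly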

    @staticmethod
    def get_characters(config: Coqpit) -> Tuple:
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters)
        else:
            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols

            config.characters = parse_symbols()
        model_characters = phonemes if config.use_phonemes else symbols
        num_chars = len(model_characters) + getattr(config, "add_blank", False)
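        # `add_blank` is a boolean, so the line above adds one extra symbol id
        # when a blank token is enabled.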
        return model_characters, config, num_chars

    def get_aux_input(self, **kwargs) -> Dict:
        """Prepare and return `aux_input` used by `forward()`"""
        ...

    def format_batch(self, batch: Dict) -> Dict:
        """Generic batch formatting for `TTSDataset`.

        You must override this if you use a custom dataset.

        Args:
            batch (Dict): raw batch as returned by the dataset's `collate_fn`.

        Returns:
            Dict: formatted batch ready to be consumed by the model.
        """
        # setup input batch
        text_input = batch["text"]
        text_lengths = batch["text_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        item_idx = batch["item_idxs"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        attn_mask = batch["attns"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # compute raw durations
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indices.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur
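
        # Worked example (illustrative, not part of the original commit): with
        # text_length=3, mel_length=5 and per-frame attention argmax [0, 0, 1, 2, 2],
        # the raw durations are [2, 1, 2]; they already sum to 5, so no frames are trimmed.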

        # set stop targets wrt reduction factor
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
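        # For example, with r=2 and a 6-frame mel target, the per-frame stop flags are
        # grouped into 3 reduction steps and stop_target_lengths becomes ceil(6 / 2) = 3.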

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
            "pitch": pitch,
        }

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            ap = assets["audio_processor"]
            tokenizer = assets["tokenizer"]

            # init dataset
            dataset = STTDataset(
                samples=data_items,
                ap=ap,
                tokenizer=tokenizer,
                batch_group_size=config.batch_group_size,
                sort_by_audio_len=config.sort_by_audio_len,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
                verbose=verbose,
                feature_extractor=config.feature_extractor,
            )

            # halt DDP processes for the main process to finish computing the phoneme cache
            if num_gpus > 1:
                dist.barrier()

            # sampler for DDP
            sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None

            # init dataloader
            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader
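
    # Usage sketch (illustrative; `config`, `train_samples`, `ap` and `tokenizer` are
    # hypothetical objects built elsewhere from a training config):
    #
    #   assets = {"audio_processor": ap, "tokenizer": tokenizer}
    #   loader = model.get_data_loader(
    #       config, assets, is_eval=False, data_items=train_samples, verbose=True, num_gpus=1
    #   )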

    # def test_run(self, ap) -> Tuple[Dict, Dict]:
    #     """Generic test run for `tts` models used by `Trainer`.

    #     You can override this for a different behaviour.

    #     Returns:
    #         Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
    #     """
    #     print(" | > Synthesizing test sentences.")
    #     test_audios = {}
    #     test_figures = {}
    #     test_sentences = self.config.test_sentences
    #     aux_inputs = self.get_aux_input()
    #     for idx, sen in enumerate(test_sentences):
    #         outputs_dict = synthesis(
    #             self,
    #             sen,
    #             self.config,
    #             "cuda" in str(next(self.parameters()).device),
    #             ap,
    #             speaker_id=aux_inputs["speaker_id"],
    #             d_vector=aux_inputs["d_vector"],
    #             style_wav=aux_inputs["style_wav"],
    #             enable_eos_bos_chars=self.config.enable_eos_bos_chars,
    #             use_griffin_lim=True,
    #             do_trim_silence=False,
    #         )
    #         test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
    #         test_figures["{}-prediction".format(idx)] = plot_spectrogram(
    #             outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
    #         )
    #         test_figures["{}-alignment".format(idx)] = plot_alignment(
    #             outputs_dict["outputs"]["alignments"], output_fig=False
    #         )
    #     return test_figures, test_audios