From abca11714e6f2996fe037cf1ebc4e45e081e338e Mon Sep 17 00:00:00 2001
From: Jindrich Matousek
Date: Thu, 23 Feb 2023 18:51:56 +0100
Subject: [PATCH] Support external durations and input text (sentence) based
 length scale in VITS

Add aux_input to propagate user parameters to inference
---
 TTS/tts/models/vits.py     |  7 +++++--
 TTS/tts/utils/synthesis.py | 15 +++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 14c76add..c52fde2e 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1082,11 +1082,12 @@ class Vits(BaseTTS):
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
 
+    # JMa: add `length_scale` to `aux_input` to allow changing the length (duration) of each input text (sentence)
     @torch.no_grad()
     def inference(
         self,
         x,
-        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
+        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1134,7 +1135,9 @@ class Vits(BaseTTS):
             logw = self.duration_predictor(
                 x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
             )
-            w = torch.exp(logw) * x_mask * self.length_scale
+            # JMa: length scale for the given sentence-like input
+            length_scale = aux_input.get("length_scale") or self.length_scale
+            w = torch.exp(logw) * x_mask * length_scale
         else:
             assert durations.shape[-1] == x.shape[-1]
             w = durations.unsqueeze(0)
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 039816db..63a348d0 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -21,6 +21,7 @@ def compute_style_mel(style_wav, ap, cuda=False):
     return style_mel
 
 
+# JMa: add `aux_input` to enable extra inputs (length_scale, durations)
 def run_model_torch(
     model: nn.Module,
     inputs: torch.Tensor,
@@ -29,6 +30,7 @@
     style_text: str = None,
     d_vector: torch.Tensor = None,
     language_id: torch.Tensor = None,
+    aux_input: Dict = {},
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.
 
@@ -56,6 +58,11 @@
             "style_mel": style_mel,
             "style_text": style_text,
             "language_ids": language_id,
+            # JMa: add `durations` and `length_scale` to `aux_input` to allow changing the length (durations) of each input text (sentence)
+            # - `length_scale` scales the length of the whole generated wav
+            # - `durations` sets the duration (in frames) for each input text ID
+            "durations": aux_input.get("durations", None),
+            "length_scale": aux_input.get("length_scale", None),
         },
     )
     return outputs
@@ -110,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     return wavs
 
 
+# JMa: add `aux_input` to enable extra inputs (length_scale, durations)
 def synthesis(
     model,
     text,
@@ -122,6 +130,7 @@ def synthesis(
     do_trim_silence=False,
     d_vector=None,
     language_id=None,
+    aux_input={},
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just
     compute output features to be passed to the vocoder model.
@@ -218,10 +227,14 @@ def synthesis(
         style_text,
         d_vector=d_vector,
         language_id=language_id,
+        # JMa: add `aux_input` to enable extra inputs (length_scale, durations)
+        aux_input=aux_input,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()
     alignments = outputs["alignments"]
+    # JMa: extract durations
+    durations = outputs.get("durations", None)
     # convert outputs to numpy
     # plot results
     wav = None
@@ -240,6 +253,8 @@ def synthesis(
         "alignments": alignments,
         "text_inputs": text_inputs,
         "outputs": outputs,
+        # JMa: return durations
+        "durations": durations,
     }
     return return_dict
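
---

A minimal smoke test for the new `length_scale` path (reviewer sketch, not
part of the patch). It assumes a trained VITS checkpoint and the stock
`synthesis(model, text, config, use_cuda, ...)` signature; `config.json` and
`model.pth` are placeholder paths, and 1.2 is an arbitrary example value:

    from TTS.config import load_config
    from TTS.tts.models.vits import Vits
    from TTS.tts.utils.synthesis import synthesis

    config = load_config("config.json")                    # placeholder path
    model = Vits.init_from_config(config)
    model.load_checkpoint(config, "model.pth", eval=True)  # placeholder path

    # length_scale > 1.0 slows the whole sentence down, < 1.0 speeds it up.
    out = synthesis(
        model,
        "Hello world.",
        config,
        use_cuda=False,
        aux_input={"length_scale": 1.2},
    )
    wav = out["wav"]
    # Newly returned by the patched synthesis(); None if the model's
    # inference outputs do not expose durations.
    durations = out["durations"]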
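The external-durations path can then be exercised by round-tripping: feed an
edited copy of the durations from a first call back into a second one
(continuing from the sketch above). This assumes the model's inference
outputs return `durations` shaped `[1, 1, T_in]`, one value per input ID;
after the `squeeze` below that matches what `Vits.inference` expects, given
its `durations.shape[-1] == x.shape[-1]` assert and `unsqueeze(0)`. Untested
sketch:

    # First pass: let the duration predictor run.
    out = synthesis(model, "Hello world.", config, use_cuda=False)
    durs = out["durations"]  # assumed [1, 1, T_in]; None if unsupported

    if durs is not None:
        durs = durs.squeeze(0).clone()  # -> [1, T_in], what inference() takes
        durs[0, 0] *= 2                 # e.g. double the first token's length
        # Second pass: bypass the predictor with the edited durations.
        out2 = synthesis(
            model,
            "Hello world.",
            config,
            use_cuda=False,
            aux_input={"durations": durs},
        )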