From fcfecf63105df073d9d5ce71feff10195c7991e0 Mon Sep 17 00:00:00 2001 From: Jindrich Matousek Date: Thu, 9 Mar 2023 16:32:29 +0100 Subject: [PATCH] Fix usage of `aux_input["min_input_length"]` when running `test_run()` during training --- TTS/tts/models/vits.py | 2 +- TTS/tts/utils/synthesis.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 28f91a73..a6dbf2e0 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1175,7 +1175,7 @@ class Vits(BaseTTS): ) # JMa: set minimum duration if required # w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask - w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask + w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask # JMa: length scale for the given sentence-like input length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale w *= length_scale diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 30e1dcb0..59897452 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -30,7 +30,7 @@ def run_model_torch( style_text: str = None, d_vector: torch.Tensor = None, language_id: torch.Tensor = None, - aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0}, + aux_input: Dict = {}, ) -> Dict: """Run a torch model for inference. It does not support batch inference. @@ -49,9 +49,9 @@ def run_model_torch( _func = model.module.inference else: _func = model.inference - # JMa: propagate `durations``, `length_scale``, and `min_input_length` to `aux_input` - # to enable changing length (durations) per each input text (sentence) and to set - # minimum allowed length of each input char/phoneme + # JMa: propagate other inputs like `durations``, `length_scale``, and `min_input_length` + # to `aux_input` to enable changing length (durations) per each input text (sentence) + # and to set minimum allowed length of each input char/phoneme # - `length_scale` changes length of the whole generated wav # - `durations` sets up duration (in frames) for each input text ID # - minimum allowed length (in frames) per input ID (char/phoneme) during inference @@ -114,7 +114,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap): return wavs -# JMa: add `aux_input` to enable extra input (length_scale, durations) +# JMa: add `aux_input` to enable extra input (like length_scale, durations) def synthesis( model, text,