From 67edc4e40f8c4f2a4d1ce8870c2a96ee1876a458 Mon Sep 17 00:00:00 2001 From: Jindrich Matousek Date: Mon, 13 Mar 2023 21:13:51 +0100 Subject: [PATCH] Fix length scale handling and default value --- TTS/tts/models/vits.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index a6dbf2e0..edc6b904 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1120,7 +1120,7 @@ class Vits(BaseTTS): "speaker_ids": None, "language_ids": None, "durations": None, - "length_scale": None, # JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence) + "length_scale": 1.0, # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence) "min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length` }, ): # pylint: disable=dangerous-default-value @@ -1177,7 +1177,7 @@ class Vits(BaseTTS): # w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask # JMa: length scale for the given sentence-like input - length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale + length_scale = aux_input.get("length_scale", self.length_scale) w *= length_scale # w = torch.exp(logw) * x_mask * length_scale else: