Fix length scale handling and default value

This commit is contained in:
Jindrich Matousek 2023-03-13 21:13:51 +01:00
parent fcfecf6310
commit 67edc4e40f
1 changed files with 2 additions and 2 deletions

View File

@@ -1120,7 +1120,7 @@ class Vits(BaseTTS):
"speaker_ids": None, "speaker_ids": None,
"language_ids": None, "language_ids": None,
"durations": None, "durations": None,
"length_scale": None, # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence) "length_scale": 1.0, # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence)
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length` "min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
}, },
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
@@ -1177,7 +1177,7 @@ class Vits(BaseTTS):
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask # w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
# JMa: length scale for the given sentence-like input # JMa: length scale for the given sentence-like input
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale length_scale = aux_input.get("length_scale", self.length_scale)
w *= length_scale w *= length_scale
# w = torch.exp(logw) * x_mask * length_scale # w = torch.exp(logw) * x_mask * length_scale
else: else: