mirror of https://github.com/coqui-ai/TTS.git
Fix length scale handling and default value
This commit is contained in:
parent
fcfecf6310
commit
67edc4e40f
|
@ -1120,7 +1120,7 @@ class Vits(BaseTTS):
|
||||||
"speaker_ids": None,
|
"speaker_ids": None,
|
||||||
"language_ids": None,
|
"language_ids": None,
|
||||||
"durations": None,
|
"durations": None,
|
||||||
"length_scale": None, # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence)
|
"length_scale": 1.0, # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence)
|
||||||
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
|
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
|
||||||
},
|
},
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
|
@ -1177,7 +1177,7 @@ class Vits(BaseTTS):
|
||||||
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
|
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
|
||||||
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
|
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
|
||||||
# JMa: length scale for the given sentence-like input
|
# JMa: length scale for the given sentence-like input
|
||||||
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
|
length_scale = aux_input.get("length_scale", self.length_scale)
|
||||||
w *= length_scale
|
w *= length_scale
|
||||||
# w = torch.exp(logw) * x_mask * length_scale
|
# w = torch.exp(logw) * x_mask * length_scale
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue