mirror of https://github.com/coqui-ai/TTS.git
Fix usage of `aux_input["min_input_length"]` when running `test_run()` during training
parent 687789558e
commit fcfecf6310
TTS/tts/models/vits.py
@@ -1175,7 +1175,7 @@ class Vits(BaseTTS):
         )
         # JMa: set minimum duration if required
         # w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
-        w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
+        w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
         # JMa: length scale for the given sentence-like input
         length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
         w *= length_scale
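The guard change above swaps the bare `aux_input["min_input_length"]` lookup for `aux_input.get("min_input_length", 0)`, so an `aux_input` that lacks the key (as when `test_run()` builds one during training) no longer raises a KeyError. A minimal sketch of the two guards, outside the library code:

# Minimal sketch (not the library code) of why the guard changes.
aux_input = {}  # e.g. an aux_input built without "min_input_length"

# Old guard: a plain lookup fails as soon as the key is missing.
#   aux_input["min_input_length"]   -> KeyError: 'min_input_length'

# New guard: .get() falls back to 0, which is falsy, so inference simply
# skips the minimum-length clamping instead of crashing.
min_len = aux_input.get("min_input_length", 0)
print(bool(min_len))  # False -> keep torch.exp(logw) * x_mask unchanged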
TTS/tts/utils/synthesis.py
@@ -30,7 +30,7 @@ def run_model_torch(
     style_text: str = None,
     d_vector: torch.Tensor = None,
     language_id: torch.Tensor = None,
-    aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0},
+    aux_input: Dict = {},
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.

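With the signature default reduced to an empty dict, per-key defaults now have to be supplied wherever the keys are read. The helper below is purely illustrative (it is not part of this commit); it sketches one way caller-supplied values could be merged over the old per-key defaults:

# Hypothetical helper, not in the patch: merge caller values over the
# per-key defaults that used to live in the run_model_torch signature.
def merge_aux_input(aux_input=None):
    defaults = {"durations": None, "length_scale": None, "min_input_length": 0}
    return {**defaults, **(aux_input or {})}

print(merge_aux_input(None))                    # all defaults
print(merge_aux_input({"length_scale": 1.2}))   # only length_scale overridden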
@@ -49,9 +49,9 @@ def run_model_torch(
         _func = model.module.inference
     else:
         _func = model.inference
-    # JMa: propagate `durations``, `length_scale``, and `min_input_length` to `aux_input`
-    # to enable changing length (durations) per each input text (sentence) and to set
-    # minimum allowed length of each input char/phoneme
+    # JMa: propagate other inputs like `durations``, `length_scale``, and `min_input_length`
+    # to `aux_input` to enable changing length (durations) per each input text (sentence)
+    # and to set minimum allowed length of each input char/phoneme
     # - `length_scale` changes length of the whole generated wav
     # - `durations` sets up duration (in frames) for each input text ID
     # - minimum allowed length (in frames) per input ID (char/phoneme) during inference
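The comments above describe three optional keys. An illustrative `aux_input` that exercises all of them (the numbers are made up, not taken from the commit):

# Illustrative values only.
aux_input = {
    "length_scale": 1.2,     # stretch the whole generated wav by 20 %
    "durations": None,       # None -> let the model predict per-ID durations
    "min_input_length": 2,   # clamp each input char/phoneme to >= 2 frames
}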
@@ -114,7 +114,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     return wavs


-# JMa: add `aux_input` to enable extra input (length_scale, durations)
+# JMa: add `aux_input` to enable extra input (like length_scale, durations)
 def synthesis(
     model,
     text,
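Assuming the patched `synthesis()` forwards its `aux_input` argument down to `run_model_torch()`, a call site could look like the sketch below; apart from the `aux_input` keyword, the argument names (`model`, the text string, `config`, `use_cuda`) are placeholders, not the exact signature.

# Hypothetical call site; only the aux_input keyword comes from this patch.
outputs = synthesis(
    model,
    "A short test sentence.",
    config,
    use_cuda=False,
    aux_input={"length_scale": 1.1, "min_input_length": 2},
)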