mirror of https://github.com/coqui-ai/TTS.git
Support external durations and input-text (sentence) based length scale in VITS
Add aux_input to propagate user parameters to inference
parent cbdec704dc
commit abca11714e
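The change lets a caller pass a per-sentence `length_scale` (and, together with the parent commit, explicit per-token `durations`) through `aux_input` down to `Vits.inference`. A minimal usage sketch, assuming a trained VITS `model` and its `config` are already loaded and that the surrounding `synthesis(model, text, config, use_cuda, ...)` signature is otherwise unchanged:

from TTS.tts.utils.synthesis import synthesis

# Illustrative only: `model` and `config` are assumed to be loaded elsewhere.
return_dict = synthesis(
    model,
    "This sentence should come out about twenty percent slower.",
    config,
    use_cuda=False,
    aux_input={"length_scale": 1.2},  # >1.0 stretches durations, <1.0 compresses them
)
durations = return_dict["durations"]  # per-token durations, newly returned by this change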
TTS/tts/models/vits.py

@@ -1082,11 +1082,12 @@ class Vits(BaseTTS):
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
 
+    # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence)
     @torch.no_grad()
     def inference(
         self,
         x,
-        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
+        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
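At the model level, the only new key in the default `aux_input` is `length_scale`. A direct call would look roughly like the sketch below (the `model` instance and the shape of `x` are assumptions, not taken from the commit):

import torch

# Illustrative only: x is a [1, T] tensor of token IDs from the usual text front end.
x = torch.randint(0, 100, (1, 50))
outputs = model.inference(
    x,
    aux_input={
        "x_lengths": None,
        "d_vectors": None,
        "speaker_ids": None,
        "language_ids": None,
        "durations": None,       # leave None to use the duration predictor
        "length_scale": 0.9,     # speak roughly 10% faster than the model default
    },
)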
@@ -1134,7 +1135,9 @@ class Vits(BaseTTS):
             logw = self.duration_predictor(
                 x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
             )
-            w = torch.exp(logw) * x_mask * self.length_scale
+            # JMa: length scale for the given sentence-like input
+            length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
+            w = torch.exp(logw) * x_mask * length_scale
         else:
             assert durations.shape[-1] == x.shape[-1]
             w = durations.unsqueeze(0)
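This branch keeps the old behaviour when no per-call scale is given: `aux_input["length_scale"]` falls back to `self.length_scale` whenever it is `None` (or any other falsy value, including `0`). A self-contained sketch of the same logic, handy for checking the effect on predicted durations:

import torch

def scale_durations(logw, x_mask, default_scale, requested_scale=None):
    # Same fallback as the patched branch: a falsy requested_scale keeps the model default.
    length_scale = requested_scale if requested_scale else default_scale
    return torch.exp(logw) * x_mask * length_scale

logw = torch.zeros(1, 1, 4)   # exp(0) = 1 frame per token before scaling
x_mask = torch.ones(1, 1, 4)
print(scale_durations(logw, x_mask, default_scale=1.0))                       # all 1.0
print(scale_durations(logw, x_mask, default_scale=1.0, requested_scale=1.5))  # all 1.5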
TTS/tts/utils/synthesis.py

@@ -21,6 +21,7 @@ def compute_style_mel(style_wav, ap, cuda=False):
     return style_mel
 
 
+# JMa: add `aux_input` to enable extra input (length_scale, durations)
 def run_model_torch(
     model: nn.Module,
     inputs: torch.Tensor,
@@ -29,6 +30,7 @@ def run_model_torch(
     style_text: str = None,
     d_vector: torch.Tensor = None,
     language_id: torch.Tensor = None,
+    aux_input: Dict = {},
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.
 
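With the extra keyword argument, `run_model_torch` keeps working for existing callers because `aux_input` defaults to an empty dict. A hedged call example; only the keyword names visible in the hunks above are taken from the commit, while the prepared `model` and `text_inputs` are assumptions:

# Illustrative only: model and text_inputs are prepared by the usual synthesis pipeline.
outputs = run_model_torch(
    model,
    text_inputs,
    d_vector=None,
    language_id=None,
    aux_input={"length_scale": 1.1},  # forwarded into the model's own aux_input
)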
@@ -56,6 +58,11 @@ def run_model_torch(
             "style_mel": style_mel,
             "style_text": style_text,
             "language_ids": language_id,
+            # JMa: add `durations` and `length_scale` to `aux_input` to enable changing length (durations) per each input text (sentence)
+            # - `length_scale` changes length of the whole generated wav
+            # - `durations` sets up duration (in frames) for each input text ID
+            "durations": aux_input.get("durations", None),
+            "length_scale": aux_input.get("length_scale", None),
         },
     )
     return outputs
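Both new keys are read with `.get(..., None)`, so an empty `aux_input` reproduces the previous behaviour. When `durations` is supplied, the VITS hunk above asserts that its last dimension equals the number of input IDs; a sketch of building such a tensor (the exact shape is an assumption inferred from that assert):

import torch

# Illustrative only: force roughly 10 frames per input ID for a 50-token sentence.
num_tokens = 50
durations = torch.full((1, num_tokens), 10.0)  # last dim must match x.shape[-1]
aux_input = {"durations": durations}           # on this path, length_scale is not applied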
@@ -110,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     return wavs
 
 
+# JMa: add `aux_input` to enable extra input (length_scale, durations)
 def synthesis(
     model,
     text,
@@ -122,6 +130,7 @@ def synthesis(
     do_trim_silence=False,
     d_vector=None,
     language_id=None,
+    aux_input={},
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.
@@ -218,10 +227,14 @@ def synthesis(
         style_text,
         d_vector=d_vector,
         language_id=language_id,
+        # JMa: add `aux_input` to enable extra input (length_scale, durations)
+        aux_input=aux_input,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()
     alignments = outputs["alignments"]
+    # JMa: extract durations
+    durations = outputs.get("durations", None)
 
     # convert outputs to numpy
     # plot results
@@ -240,6 +253,8 @@ def synthesis(
         "alignments": alignments,
         "text_inputs": text_inputs,
         "outputs": outputs,
+        # JMa: return durations
+        "durations": durations,
     }
     return return_dict
 
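Because `synthesis` now both returns the predicted durations and accepts external ones, the two halves of this change compose into a two-pass workflow: synthesize once, keep or edit the returned durations, then resynthesize with them fixed. A rough sketch under the same assumptions as above (depending on the shape the model returns, a reshape may be needed before feeding the durations back):

# Pass 1: let the duration predictor decide, keep the per-token durations.
first = synthesis(model, "Hello there.", config, use_cuda=False)
durations = first["durations"]

# Pass 2: reuse (or edit) those durations, bypassing the duration predictor.
second = synthesis(
    model,
    "Hello there.",
    config,
    use_cuda=False,
    aux_input={"durations": durations},
)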