Support external durations and input-text (sentence) based length scale in VITS

Add aux_input to propagate user parameters to inference
Jindrich Matousek 2023-02-23 18:51:56 +01:00
parent cbdec704dc
commit abca11714e
2 changed files with 20 additions and 2 deletions
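In short, a caller can now pass a per-sentence `length_scale` (or explicit per-token `durations`) through `synthesis` down to `Vits.inference`. A minimal sketch of the call path this commit enables, assuming a loaded VITS model and config; the text, speaker id, and scale factor below are illustrative:

outputs = synthesis(
    model,                            # a loaded Vits instance
    "A test sentence.",
    config,
    use_cuda=False,
    speaker_id=0,                     # illustrative; depends on the model
    aux_input={"length_scale": 1.2},  # ~20% slower than the model default
)
durations = outputs["durations"]      # per-input-ID durations, or None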

TTS/tts/models/vits.py

@@ -1082,11 +1082,12 @@ class Vits(BaseTTS):
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
# JMa: add `length_scale` to `aux_input` to enable changing length (duration) per input text (sentence)
@torch.no_grad()
def inference(
self,
x,
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
): # pylint: disable=dangerous-default-value
"""
Note:
@@ -1134,7 +1135,9 @@
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
w = torch.exp(logw) * x_mask * self.length_scale
# JMa: per-sentence length scale; fall back to the model-wide default
length_scale = aux_input.get("length_scale") or self.length_scale
w = torch.exp(logw) * x_mask * length_scale
else:
assert durations.shape[-1] == x.shape[-1]
w = durations.unsqueeze(0)
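For intuition: the predicted log-durations are exponentiated and multiplied by the scale, so the override stretches or compresses every token uniformly. A tiny self-contained illustration with made-up values:

import torch

logw = torch.zeros(1, 1, 5)   # predicted log-durations for 5 tokens
x_mask = torch.ones(1, 1, 5)  # no padding
length_scale = 1.5            # per-sentence override (>1.0 slows speech down)
w = torch.exp(logw) * x_mask * length_scale
print(w)                      # each token now lasts 1.5 frames (before ceil/upsampling)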

TTS/tts/utils/synthesis.py

@@ -21,6 +21,7 @@ def compute_style_mel(style_wav, ap, cuda=False):
return style_mel
# JMa: add `aux_input` to enable extra input (length_scale, durations)
def run_model_torch(
model: nn.Module,
inputs: torch.Tensor,
@@ -29,6 +30,7 @@ def run_model_torch(
style_text: str = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
aux_input: Dict = {},
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@@ -56,6 +58,11 @@
"style_mel": style_mel,
"style_text": style_text,
"language_ids": language_id,
# JMa: add `durations` and `length_scale` to `aux_input` to enable changing length (durations) per input text (sentence)
# - `length_scale` changes length of the whole generated wav
# - `durations` sets up duration (in frames) for each input text ID
"durations": aux_input.get("durations", None),
"length_scale": aux_input.get("length_scale", None),
},
)
return outputs
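With the extra keys wired through, explicit per-token durations can also be injected at this level. A hedged sketch; the duration values are hypothetical, `text_inputs` is assumed to come from the usual text-to-ID pipeline, and the tensor shape follows the assert in `Vits.inference` above (last dimension must equal the token count):

import torch

# text_inputs: [1, T] tensor of token IDs prepared elsewhere; here T == 5
durations = torch.tensor([[3.0, 5.0, 4.0, 6.0, 2.0]])  # frames per input ID
outputs = run_model_torch(
    model,
    text_inputs,
    aux_input={"durations": durations},
)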
@@ -110,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
return wavs
# JMa: add `aux_input` to enable extra input (length_scale, durations)
def synthesis(
model,
text,
@@ -122,6 +130,7 @@
do_trim_silence=False,
d_vector=None,
language_id=None,
aux_input={},
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model.
@@ -218,10 +227,14 @@
style_text,
d_vector=d_vector,
language_id=language_id,
# JMa: add `aux_input` to enable extra input (length_scale, durations)
aux_input=aux_input,
)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]
# JMa: extract durations
durations = outputs.get("durations", None)
# convert outputs to numpy
# plot results
@@ -240,6 +253,8 @@
"alignments": alignments,
"text_inputs": text_inputs,
"outputs": outputs,
# JMa: return durations
"durations": durations,
}
return return_dict
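Returning the durations also makes a copy-synthesis style round trip possible: read the durations from one run, edit them, and feed them back. A sketch under the assumption that the model populates `outputs["durations"]` with a tensor of frames per input ID (the exact shape is model-dependent; `[1, T]` is assumed here):

first = synthesis(model, text, config, use_cuda=False)
dur = first["durations"]                  # assumed shape: [1, T], or None
if dur is not None:
    second = synthesis(
        model, text, config, use_cuda=False,
        aux_input={"durations": dur * 2.0},  # e.g. double every token's length
    )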