Support external durations and input-text (sentence) based length scale in VITS

Add aux_input to propagate user parameters to inference
Jindrich Matousek 2023-02-23 18:51:56 +01:00
parent cbdec704dc
commit abca11714e
2 changed files with 20 additions and 2 deletions

TTS/tts/models/vits.py

@@ -1082,11 +1082,12 @@ class Vits(BaseTTS):
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
 
+    # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per input text (sentence)
     @torch.no_grad()
     def inference(
         self,
         x,
-        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
+        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1134,7 +1135,9 @@ class Vits(BaseTTS):
             logw = self.duration_predictor(
                 x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
             )
-            w = torch.exp(logw) * x_mask * self.length_scale
+            # JMa: length scale for the given sentence-like input
+            length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
+            w = torch.exp(logw) * x_mask * length_scale
         else:
             assert durations.shape[-1] == x.shape[-1]
             w = durations.unsqueeze(0)
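
With these two hunks, the speaking rate of each synthesized sentence can be controlled per call: a scalar `length_scale` rescales the predicted durations, while an explicit `durations` tensor bypasses the duration predictor entirely. A minimal usage sketch follows; the setup lines are illustrative (an untrained model built from the default `VitsConfig`), and the [1, 1, T] shape of `outputs["durations"]` is an assumption. Note that a caller-supplied dict replaces the default `aux_input` wholesale and the new code indexes `aux_input["length_scale"]` directly, so it is safest to pass the full key set.

import torch

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# Illustrative setup: in practice, restore a trained checkpoint instead,
# e.g. model.load_checkpoint(config, "best_model.pth", eval=True)
config = VitsConfig()
model = Vits.init_from_config(config)
model.eval()

x = torch.randint(0, 10, (1, 30))  # [1, T] token IDs for one sentence

aux_input = {
    "x_lengths": None,
    "d_vectors": None,
    "speaker_ids": None,
    "language_ids": None,
    "durations": None,    # no external durations: use the duration predictor
    "length_scale": 1.2,  # this sentence only: 20% slower; None -> model default
}
outputs = model.inference(x, aux_input=aux_input)

# Reuse (or edit) the realized per-token durations and pin them on a second pass.
aux_input["durations"] = outputs["durations"].squeeze(0)  # [1, 1, T] -> [1, T]
outputs_fixed = model.inference(x, aux_input=aux_input)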

TTS/tts/utils/synthesis.py

@@ -21,6 +21,7 @@ def compute_style_mel(style_wav, ap, cuda=False):
     return style_mel
 
 
+# JMa: add `aux_input` to enable extra input (length_scale, durations)
 def run_model_torch(
     model: nn.Module,
     inputs: torch.Tensor,
@@ -29,6 +30,7 @@ def run_model_torch(
     style_text: str = None,
     d_vector: torch.Tensor = None,
     language_id: torch.Tensor = None,
+    aux_input: Dict = {},
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.
 
@@ -56,6 +58,11 @@ def run_model_torch(
             "style_mel": style_mel,
             "style_text": style_text,
             "language_ids": language_id,
+            # JMa: add `durations` and `length_scale` to `aux_input` to enable changing length (durations) per input text (sentence)
+            # - `length_scale` changes the length of the whole generated wav
+            # - `durations` sets the duration (in frames) for each input text ID
+            "durations": aux_input.get("durations", None),
+            "length_scale": aux_input.get("length_scale", None),
         },
     )
     return outputs
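
Because `run_model_torch` reads the new keys with `.get()`, a partial `aux_input` is fine at this level: missing keys default to None, and the model falls back to its built-in behaviour. A hypothetical call sketch, assuming `model` is a Vits instance and `text_inputs` a [1, T] token-ID tensor as prepared by `synthesis()`:

outputs = run_model_torch(
    model,
    text_inputs,
    aux_input={
        "length_scale": 0.9,  # speak this sentence ~10% faster
        # "durations": dur,   # alternatively: a [1, T] tensor of frames per token
    },
)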
@@ -110,6 +117,7 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
     return wavs
 
 
+# JMa: add `aux_input` to enable extra input (length_scale, durations)
 def synthesis(
     model,
     text,
@@ -122,6 +130,7 @@ def synthesis(
     do_trim_silence=False,
     d_vector=None,
     language_id=None,
+    aux_input={},
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.
@@ -218,10 +227,14 @@ def synthesis(
         style_text,
         d_vector=d_vector,
         language_id=language_id,
+        # JMa: add `aux_input` to enable extra input (length_scale, durations)
+        aux_input=aux_input,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()
     alignments = outputs["alignments"]
+    # JMa: extract durations
+    durations = outputs.get("durations", None)
 
     # convert outputs to numpy
     # plot results
@@ -240,6 +253,8 @@ def synthesis(
         "alignments": alignments,
         "text_inputs": text_inputs,
         "outputs": outputs,
+        # JMa: return durations
+        "durations": durations,
     }
     return return_dict
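
Taken together, the changes surface per-sentence duration control at the top-level `synthesis()` API and return the realized durations, enabling a synthesize-edit-resynthesize loop. A sketch under the same assumptions as above (a trained `model` plus its config `CONFIG`; the `result["wav"]` key and the assumed [1, 1, T] shape of the returned durations come from the surrounding code, not from this diff):

# First pass: stretch the sentence by 10% and capture the realized durations.
result = synthesis(
    model,
    "Hello world.",
    CONFIG,
    use_cuda=False,
    aux_input={"length_scale": 1.1},
)
wav = result["wav"]
dur = result["durations"]  # None for models that do not expose durations

# Second pass: feed the (possibly edited) durations back to pin the timing.
result_fixed = synthesis(
    model,
    "Hello world.",
    CONFIG,
    use_cuda=False,
    aux_input={"durations": dur.squeeze(0)},  # [1, T]; must match the token count
)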