Add support for phone (char) based length scale

Remove length_scale from default aux_input
This commit is contained in:
Jindrich Matousek 2023-08-06 13:17:53 +02:00
parent d3661d7d26
commit 874143bf04
2 changed files with 20 additions and 4 deletions

View File

@ -1123,7 +1123,6 @@ class Vits(BaseTTS):
"speaker_ids": None,
"language_ids": None,
"durations": None,
"length_scale": 1.0, # JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
},
): # pylint: disable=dangerous-default-value
@ -1136,6 +1135,8 @@ class Vits(BaseTTS):
- x_lengths: :math:`[B]`
- d_vectors: :math:`[B, C]`
- speaker_ids: :math:`[B]`
- durations: :math:`[B, T_seq]`
- length_scale: :math:`[B, T_seq]` or :math:`[B]`
Return Shapes:
- model_outputs: :math:`[B, 1, T_wav]`
@ -1177,13 +1178,27 @@ class Vits(BaseTTS):
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
# JMa: set minimum duration if required
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
# JMa: length scale for the given sentence-like input
# ORIG: w = torch.exp(logw) * x_mask * self.length_scale
# If `length_scale` is in `aux_input`, it resets the default value given by `self.length_scale`,
# otherwise the default `self.length_scale` from `config.json` is used.
length_scale = aux_input.get("length_scale", self.length_scale)
w *= length_scale
# w = torch.exp(logw) * x_mask * length_scale
# JMa: `length_scale` is used to scale durations relative to the predicted values; it should be either:
# - a float (or int) => the duration of the output speech is scaled linearly
# - a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the length of the input symbols)
# => each input symbol (phone or char) is scaled by the corresponding value in the torch vector
if isinstance(length_scale, float) or isinstance(length_scale, int):
w *= length_scale
else:
assert length_scale.shape[-1] == w.shape[-1]
w *= length_scale.unsqueeze(0)
else:
# To force absolute durations (in frames), "durations" has to be in `aux_input`.
# The durations should be a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the length of the input symbols)
# => each input symbol (phone or char) will have the duration given by the corresponding value (number of frames) in the torch vector
assert durations.shape[-1] == x.shape[-1]
w = durations.unsqueeze(0)

View File

@ -215,6 +215,7 @@ def synthesis(
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(
model,