From 874143bf04ceaaa367ff6c4675821ad51aa19ab8 Mon Sep 17 00:00:00 2001
From: Jindrich Matousek
Date: Sun, 6 Aug 2023 13:17:53 +0200
Subject: [PATCH] Add support for phone (char) based length scale

Remove length_scale from default aux_input
---
 TTS/tts/models/vits.py     | 23 +++++++++++++++++++----
 TTS/tts/utils/synthesis.py |  1 +
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 29a0646b..c001412a 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1123,7 +1123,6 @@ class Vits(BaseTTS):
             "speaker_ids": None,
             "language_ids": None,
             "durations": None,
-            "length_scale": 1.0,  # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence)
             "min_input_length": 0  # JMa: set minimum length if predicted length is lower than `min_input_length`
         },
     ):  # pylint: disable=dangerous-default-value
@@ -1136,6 +1135,8 @@
             - x_lengths: :math:`[B]`
             - d_vectors: :math:`[B, C]`
             - speaker_ids: :math:`[B]`
+            - durations: :math:`[B, T_seq]`
+            - length_scale: :math:`[B, T_seq]` or :math:`[B]`
 
         Return Shapes:
             - model_outputs: :math:`[B, 1, T_wav]`
@@ -1177,13 +1178,27 @@
                 x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
             )
             # JMa: set minimum duration if required
-            # w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
             w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input.get("min_input_length", 0) else torch.exp(logw) * x_mask
+            # JMa: length scale for the given sentence-like input
+            # ORIG: w = torch.exp(logw) * x_mask * self.length_scale
+            # If `length_scale` is in `aux_input`, it overrides the default value given by `self.length_scale`;
+            # otherwise, the default `self.length_scale` from `config.json` is used.
             length_scale = aux_input.get("length_scale", self.length_scale)
-            w *= length_scale
-            # w = torch.exp(logw) * x_mask * length_scale
+            # JMa: `length_scale` scales durations relative to the predicted values; it should be either:
+            # - a float (or int) => the duration of the output speech is scaled linearly, or
+            # - a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the number of input symbols)
+            #   => each input symbol (phone or char) is scaled by the corresponding value in the vector
+            if isinstance(length_scale, (float, int)):
+                w *= length_scale
+            else:
+                assert length_scale.shape[-1] == w.shape[-1]
+                w *= length_scale.unsqueeze(0)
         else:
+            # JMa: To force absolute durations (in frames), `durations` has to be in `aux_input`.
+            # The durations should be a torch vector `[B, T_seq]` (`B` is the batch size, `T_seq` is the number of input symbols)
+            # => each input symbol (phone or char) gets the duration (number of frames) given by the corresponding value in the vector.
             assert durations.shape[-1] == x.shape[-1]
             w = durations.unsqueeze(0)
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 59897452..ccdecaf6 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -215,6 +215,7 @@ def synthesis(
     text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
     text_inputs = text_inputs.unsqueeze(0)
+    # synthesize voice
     outputs = run_model_torch(
         model,
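
Usage sketch: the hunks above turn `length_scale` and `durations` into per-symbol
controls read from `aux_input`. A minimal example of driving them, assuming a
loaded single-speaker `Vits` instance bound to `model`; `token_ids` and the
shapes below are placeholders, not provided by this patch:

    import torch

    # Dummy symbol IDs standing in for a real tokenized sentence: [B=1, T_seq=42].
    token_ids = torch.randint(0, 100, (1, 42))

    # Scalar: every predicted duration is scaled uniformly (1.2 => ~20% slower).
    outputs = model.inference(token_ids, aux_input={"length_scale": 1.2})

    # Per-symbol tensor [B, T_seq]: stretch only the first ten phones/chars.
    per_symbol = torch.ones(1, token_ids.shape[-1])
    per_symbol[:, :10] = 1.5
    outputs = model.inference(token_ids, aux_input={"length_scale": per_symbol})

    # Absolute durations (in frames) bypass the duration predictor entirely.
    frames = torch.full((1, token_ids.shape[-1]), 8.0)
    outputs = model.inference(token_ids, aux_input={"durations": frames})

Since `length_scale` is no longer in the default `aux_input`, callers that omit
it fall back to `self.length_scale` from `config.json`, as noted in the comments
above.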