From 687789558e151c07b022b20506dd4a4ec00401c0 Mon Sep 17 00:00:00 2001
From: Jindrich Matousek
Date: Fri, 3 Mar 2023 20:41:19 +0100
Subject: [PATCH] Enable ensuring minimum length per token

---
 TTS/tts/models/vits.py     | 48 +++++++++++++++++++++++++++++++++++---
 TTS/tts/utils/synthesis.py | 31 +++++++++++-------------
 2 files changed, 59 insertions(+), 20 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index ad8e65eb..28f91a73 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1081,13 +1081,48 @@ class Vits(BaseTTS):
         if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
+
+    # JMa: enforce a minimum duration whenever the predicted duration is lower than the threshold.
+    # Workaround to avoid short durations that cause some chars/phonemes to be reduced.
+    # @staticmethod
+    # def _set_min_inference_length(d, threshold):
+    #     d_mask = d < threshold
+    #     d[d_mask] = threshold
+    #     return d
+
+    def _set_min_inference_length(self, x, durs, threshold):
+        punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank]
+        # Get the list of tokens from input IDs
+        tokens = x.squeeze().tolist()
+        # Track the current token (c) and the next token (n)
+        n = self.tokenizer.characters.id_to_char(tokens[0])
+        # for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])):
+        for ix, idx in enumerate(tokens[1:]):
+            # c = self.tokenizer.characters.id_to_char(id_c)
+            c = n
+            n = self.tokenizer.characters.id_to_char(idx)
+            if c in punctlike:
+                # Skip thresholding for punctuation
+                continue
+            # Add the duration of the following punctuation-like token if possible
+            d = durs[:, :, ix] + durs[:, :, ix + 1] if n in punctlike else durs[:, :, ix]
+            # Threshold the duration if it is lower than the threshold
+            if d < threshold:
+                durs[:, :, ix] = threshold
+        return durs
 
-    # JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
     @torch.no_grad()
     def inference(
         self,
         x,
-        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
+        aux_input={"x_lengths": None,
+                   "d_vectors": None,
+                   "speaker_ids": None,
+                   "language_ids": None,
+                   "durations": None,
+                   "length_scale": None,   # JMa: add `length_scale` to `aux_input` to enable changing the length (duration) per input text (sentence)
+                   "min_input_length": 0,  # JMa: enforce a minimum length per token if the predicted length is lower than `min_input_length`
+        },
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1107,6 +1142,9 @@ class Vits(BaseTTS):
             - m_p: :math:`[B, C, T_dec]`
             - logs_p: :math:`[B, C, T_dec]`
         """
+        # JMa: save the raw input token IDs
+        x_input = x
+
         sid, g, lid, durations = self._set_cond_input(aux_input)
         x_lengths = self._set_x_lengths(x, aux_input)
 
@@ -1135,9 +1173,13 @@ class Vits(BaseTTS):
                 logw = self.duration_predictor(
                     x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
                 )
+            # JMa: enforce the minimum duration if required
+            # w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
+            w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
             # JMa: length scale for the given sentence-like input
             length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
-            w = torch.exp(logw) * x_mask * length_scale
+            w *= length_scale
+            # w = torch.exp(logw) * x_mask * length_scale
         else:
             assert durations.shape[-1] == x.shape[-1]
             w = durations.unsqueeze(0)
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 63a348d0..30e1dcb0 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -30,7 +30,7 @@ def run_model_torch(
     style_text: str = None,
     d_vector: torch.Tensor = None,
     language_id: torch.Tensor = None,
-    aux_input: Dict = {},
+    aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0},
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.
 
@@ -49,22 +49,19 @@ def run_model_torch(
     if hasattr(model, "module"):
         _func = model.module.inference
     else:
         _func = model.inference
-    outputs = _func(
-        inputs,
-        aux_input={
-            "x_lengths": input_lengths,
-            "speaker_ids": speaker_id,
-            "d_vectors": d_vector,
-            "style_mel": style_mel,
-            "style_text": style_text,
-            "language_ids": language_id,
-            # JMa: add `durations`` and `length_scale`` to `aux_input` to enable changing length (durations) per each input text (sentence)
-            # - `length_scale` changes length of the whole generated wav
-            # - `durations` sets up duration (in frames) for each input text ID
-            "durations": aux_input.get("durations", None),
-            "length_scale": aux_input.get("length_scale", None),
-        },
-    )
+    # JMa: propagate `durations`, `length_scale`, and `min_input_length` via `aux_input`
+    # to enable changing the length (durations) per input text (sentence) and to set
+    # a minimum allowed length for each input char/phoneme:
+    # - `length_scale` changes the length of the whole generated wav
+    # - `durations` sets the duration (in frames) for each input text ID
+    # - `min_input_length` sets the minimum allowed length (in frames) per input ID (char/phoneme) during inference
+    aux_input["x_lengths"] = input_lengths
+    aux_input["speaker_ids"] = speaker_id
+    aux_input["d_vectors"] = d_vector
+    aux_input["style_mel"] = style_mel
+    aux_input["style_text"] = style_text
+    aux_input["language_ids"] = language_id
+    outputs = _func(inputs, aux_input)
     return outputs
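
Usage sketch (not part of the patch): a minimal example of how the new
`min_input_length` and `length_scale` knobs could be exercised through the
patched `run_model_torch`, assuming a trained single-speaker Vits model.
`model`, `text`, and `device` are placeholders for objects created elsewhere;
only the `aux_input` keys come from this patch.

    import torch
    from TTS.tts.utils.synthesis import run_model_torch

    # Assumed context: `model` is a trained Vits instance (with its tokenizer),
    # `text` is the sentence to synthesize, `device` is the model's device.
    token_ids = model.tokenizer.text_to_ids(text)
    inputs = torch.LongTensor(token_ids).unsqueeze(0).to(device)

    outputs = run_model_torch(
        model,
        inputs,
        aux_input={
            "durations": None,      # let the duration predictor run
            "length_scale": 1.2,    # stretch the whole utterance by 20%
            "min_input_length": 3,  # give every non-punctuation token at least 3 frames
        },
    )
    waveform = outputs["model_outputs"]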