Enable ensuring minimum length per token

This commit is contained in:
Jindrich Matousek 2023-03-03 20:41:19 +01:00
parent 8a29d57ff0
commit 687789558e
2 changed files with 59 additions and 20 deletions

View File

@ -1082,12 +1082,47 @@ class Vits(BaseTTS):
return aux_input["x_lengths"] return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device) return torch.tensor(x.shape[1:2]).to(x.device)
# JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence) # JMa: set minimum duration if predicted duration is lower than threshold
# Workaround to avoid short durations that cause some chars/phonemes to be reduced
# @staticmethod
# def _set_min_inference_length(d, threshold):
# d_mask = d < threshold
# d[d_mask] = threshold
# return d
def _set_min_inference_length(self, x, durs, threshold):
punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank]
# Get list of tokens from IDs
tokens = x.squeeze().tolist()
# Check current and next token
n = self.tokenizer.characters.id_to_char(tokens[0])
# for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])):
for ix, idx in enumerate(tokens[1:]):
# c = self.tokenizer.characters.id_to_char(id_c)
c = n
n = self.tokenizer.characters.id_to_char(idx)
if c in punctlike:
# Skip thresholding for punctuation
continue
# Add duration from next punctuation if possible
d = durs[:,:,ix] + durs[:,:,ix+1] if n in punctlike else durs[:,:,ix]
# Threshold duration if duration lower than threshold
if d < threshold:
durs[:,:,ix] = threshold
return durs
@torch.no_grad() @torch.no_grad()
def inference( def inference(
self, self,
x, x,
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None}, aux_input={"x_lengths": None,
"d_vectors": None,
"speaker_ids": None,
"language_ids": None,
"durations": None,
"length_scale": None, # JMa: add `length_scale` to `aux_input` to enable changing length (duration) per each input text (sentence)
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
},
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
""" """
Note: Note:
@ -1107,6 +1142,9 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]` - m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]` - logs_p: :math:`[B, C, T_dec]`
""" """
# JMa: Save input
x_input = x
sid, g, lid, durations = self._set_cond_input(aux_input) sid, g, lid, durations = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input) x_lengths = self._set_x_lengths(x, aux_input)
@ -1135,9 +1173,13 @@ class Vits(BaseTTS):
logw = self.duration_predictor( logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
) )
# JMa: set minimum duration if required
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
# JMa: length scale for the given sentence-like input # JMa: length scale for the given sentence-like input
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
w = torch.exp(logw) * x_mask * length_scale w *= length_scale
# w = torch.exp(logw) * x_mask * length_scale
else: else:
assert durations.shape[-1] == x.shape[-1] assert durations.shape[-1] == x.shape[-1]
w = durations.unsqueeze(0) w = durations.unsqueeze(0)

View File

@ -30,7 +30,7 @@ def run_model_torch(
style_text: str = None, style_text: str = None,
d_vector: torch.Tensor = None, d_vector: torch.Tensor = None,
language_id: torch.Tensor = None, language_id: torch.Tensor = None,
aux_input: Dict = {}, aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0},
) -> Dict: ) -> Dict:
"""Run a torch model for inference. It does not support batch inference. """Run a torch model for inference. It does not support batch inference.
@ -49,22 +49,19 @@ def run_model_torch(
_func = model.module.inference _func = model.module.inference
else: else:
_func = model.inference _func = model.inference
outputs = _func( # JMa: propagate `durations`, `length_scale`, and `min_input_length` to `aux_input`
inputs, # to enable changing length (durations) per each input text (sentence) and to set
aux_input={ # minimum allowed length of each input char/phoneme
"x_lengths": input_lengths, # - `length_scale` changes length of the whole generated wav
"speaker_ids": speaker_id, # - `durations` sets up duration (in frames) for each input text ID
"d_vectors": d_vector, # - minimum allowed length (in frames) per input ID (char/phoneme) during inference
"style_mel": style_mel, aux_input["x_lengths"] = input_lengths
"style_text": style_text, aux_input["speaker_ids"] = speaker_id
"language_ids": language_id, aux_input["d_vectors"] = d_vector
# JMa: add `durations`` and `length_scale`` to `aux_input` to enable changing length (durations) per each input text (sentence) aux_input["style_mel"] = style_mel
# - `length_scale` changes length of the whole generated wav aux_input["style_text"] = style_text
# - `durations` sets up duration (in frames) for each input text ID aux_input["language_ids"] = language_id
"durations": aux_input.get("durations", None), outputs = _func(inputs, aux_input)
"length_scale": aux_input.get("length_scale", None),
},
)
return outputs return outputs