Enable ensuring minimum length per token

This commit is contained in:
Jindrich Matousek 2023-03-03 20:41:19 +01:00
parent 8a29d57ff0
commit 687789558e
2 changed files with 59 additions and 20 deletions

View File

@ -1081,13 +1081,48 @@ class Vits(BaseTTS):
if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
# JMa: set minimum duration if predicted duration is lower than threshold
# Workaround to avoid short durations that cause some chars/phonemes to be reduced
# @staticmethod
# def _set_min_inference_length(d, threshold):
# d_mask = d < threshold
# d[d_mask] = threshold
# return d
def _set_min_inference_length(self, x, durs, threshold):
punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank]
# Get list of tokens from IDs
tokens = x.squeeze().tolist()
# Check current and next token
n = self.tokenizer.characters.id_to_char(tokens[0])
# for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])):
for ix, idx in enumerate(tokens[1:]):
# c = self.tokenizer.characters.id_to_char(id_c)
c = n
n = self.tokenizer.characters.id_to_char(idx)
if c in punctlike:
# Skip thresholding for punctuation
continue
# Add duration from next punctuation if possible
d = durs[:,:,ix] + durs[:,:,ix+1] if n in punctlike else durs[:,:,ix]
# Threshold duration if duration lower than threshold
if d < threshold:
durs[:,:,ix] = threshold
return durs
# JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
@torch.no_grad()
def inference(
self,
x,
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
aux_input={"x_lengths": None,
"d_vectors": None,
"speaker_ids": None,
"language_ids": None,
"durations": None,
"length_scale": None, # JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
},
): # pylint: disable=dangerous-default-value
"""
Note:
@ -1107,6 +1142,9 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]`
"""
# JMa: Save input
x_input = x
sid, g, lid, durations = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input)
@ -1135,9 +1173,13 @@ class Vits(BaseTTS):
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
# JMa: set minimum duration if required
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
# JMa: length scale for the given sentence-like input
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
w = torch.exp(logw) * x_mask * length_scale
w *= length_scale
# w = torch.exp(logw) * x_mask * length_scale
else:
assert durations.shape[-1] == x.shape[-1]
w = durations.unsqueeze(0)

View File

@ -30,7 +30,7 @@ def run_model_torch(
style_text: str = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
aux_input: Dict = {},
aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0},
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@ -49,22 +49,19 @@ def run_model_torch(
_func = model.module.inference
else:
_func = model.inference
outputs = _func(
inputs,
aux_input={
"x_lengths": input_lengths,
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
"style_text": style_text,
"language_ids": language_id,
# JMa: add `durations`` and `length_scale`` to `aux_input` to enable changing length (durations) per each input text (sentence)
# - `length_scale` changes length of the whole generated wav
# - `durations` sets up duration (in frames) for each input text ID
"durations": aux_input.get("durations", None),
"length_scale": aux_input.get("length_scale", None),
},
)
# JMa: propagate `durations``, `length_scale``, and `min_input_length` to `aux_input`
# to enable changing length (durations) per each input text (sentence) and to set
# minimum allowed length of each input char/phoneme
# - `length_scale` changes length of the whole generated wav
# - `durations` sets up duration (in frames) for each input text ID
# - minimum allowed length (in frames) per input ID (char/phoneme) during inference
aux_input["x_lengths"] = input_lengths
aux_input["speaker_ids"] = speaker_id
aux_input["d_vectors"] = d_vector
aux_input["style_mel"] = style_mel
aux_input["style_text"] = style_text
aux_input["language_ids"] = language_id
outputs = _func(inputs, aux_input)
return outputs