mirror of https://github.com/coqui-ai/TTS.git
Enable ensuring minimum length per token
This commit is contained in:
parent
8a29d57ff0
commit
687789558e
|
@ -1082,12 +1082,47 @@ class Vits(BaseTTS):
|
|||
return aux_input["x_lengths"]
|
||||
return torch.tensor(x.shape[1:2]).to(x.device)
|
||||
|
||||
# JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
|
||||
# JMa: set minimum duration if predicted duration is lower than threshold
|
||||
# Workaround to avoid short durations that cause some chars/phonemes to be reduced
|
||||
# @staticmethod
|
||||
# def _set_min_inference_length(d, threshold):
|
||||
# d_mask = d < threshold
|
||||
# d[d_mask] = threshold
|
||||
# return d
|
||||
|
||||
def _set_min_inference_length(self, x, durs, threshold):
|
||||
punctlike = list(self.config.characters.punctuations) + [self.config.characters.blank]
|
||||
# Get list of tokens from IDs
|
||||
tokens = x.squeeze().tolist()
|
||||
# Check current and next token
|
||||
n = self.tokenizer.characters.id_to_char(tokens[0])
|
||||
# for ix, (c, n) in enumerate(zip(tokens[:-1], tokens[1:])):
|
||||
for ix, idx in enumerate(tokens[1:]):
|
||||
# c = self.tokenizer.characters.id_to_char(id_c)
|
||||
c = n
|
||||
n = self.tokenizer.characters.id_to_char(idx)
|
||||
if c in punctlike:
|
||||
# Skip thresholding for punctuation
|
||||
continue
|
||||
# Add duration from next punctuation if possible
|
||||
d = durs[:,:,ix] + durs[:,:,ix+1] if n in punctlike else durs[:,:,ix]
|
||||
# Threshold duration if duration lower than threshold
|
||||
if d < threshold:
|
||||
durs[:,:,ix] = threshold
|
||||
return durs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
self,
|
||||
x,
|
||||
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
|
||||
aux_input={"x_lengths": None,
|
||||
"d_vectors": None,
|
||||
"speaker_ids": None,
|
||||
"language_ids": None,
|
||||
"durations": None,
|
||||
"length_scale": None, # JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
|
||||
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
|
||||
},
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Note:
|
||||
|
@ -1107,6 +1142,9 @@ class Vits(BaseTTS):
|
|||
- m_p: :math:`[B, C, T_dec]`
|
||||
- logs_p: :math:`[B, C, T_dec]`
|
||||
"""
|
||||
# JMa: Save input
|
||||
x_input = x
|
||||
|
||||
sid, g, lid, durations = self._set_cond_input(aux_input)
|
||||
x_lengths = self._set_x_lengths(x, aux_input)
|
||||
|
||||
|
@ -1135,9 +1173,13 @@ class Vits(BaseTTS):
|
|||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
||||
)
|
||||
# JMa: set minimum duration if required
|
||||
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
|
||||
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
|
||||
# JMa: length scale for the given sentence-like input
|
||||
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
|
||||
w = torch.exp(logw) * x_mask * length_scale
|
||||
w *= length_scale
|
||||
# w = torch.exp(logw) * x_mask * length_scale
|
||||
else:
|
||||
assert durations.shape[-1] == x.shape[-1]
|
||||
w = durations.unsqueeze(0)
|
||||
|
|
|
@ -30,7 +30,7 @@ def run_model_torch(
|
|||
style_text: str = None,
|
||||
d_vector: torch.Tensor = None,
|
||||
language_id: torch.Tensor = None,
|
||||
aux_input: Dict = {},
|
||||
aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0},
|
||||
) -> Dict:
|
||||
"""Run a torch model for inference. It does not support batch inference.
|
||||
|
||||
|
@ -49,22 +49,19 @@ def run_model_torch(
|
|||
_func = model.module.inference
|
||||
else:
|
||||
_func = model.inference
|
||||
outputs = _func(
|
||||
inputs,
|
||||
aux_input={
|
||||
"x_lengths": input_lengths,
|
||||
"speaker_ids": speaker_id,
|
||||
"d_vectors": d_vector,
|
||||
"style_mel": style_mel,
|
||||
"style_text": style_text,
|
||||
"language_ids": language_id,
|
||||
# JMa: add `durations`` and `length_scale`` to `aux_input` to enable changing length (durations) per each input text (sentence)
|
||||
# - `length_scale` changes length of the whole generated wav
|
||||
# - `durations` sets up duration (in frames) for each input text ID
|
||||
"durations": aux_input.get("durations", None),
|
||||
"length_scale": aux_input.get("length_scale", None),
|
||||
},
|
||||
)
|
||||
# JMa: propagate `durations``, `length_scale``, and `min_input_length` to `aux_input`
|
||||
# to enable changing length (durations) per each input text (sentence) and to set
|
||||
# minimum allowed length of each input char/phoneme
|
||||
# - `length_scale` changes length of the whole generated wav
|
||||
# - `durations` sets up duration (in frames) for each input text ID
|
||||
# - minimum allowed length (in frames) per input ID (char/phoneme) during inference
|
||||
aux_input["x_lengths"] = input_lengths
|
||||
aux_input["speaker_ids"] = speaker_id
|
||||
aux_input["d_vectors"] = d_vector
|
||||
aux_input["style_mel"] = style_mel
|
||||
aux_input["style_text"] = style_text
|
||||
aux_input["language_ids"] = language_id
|
||||
outputs = _func(inputs, aux_input)
|
||||
return outputs
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue