mirror of https://github.com/coqui-ai/TTS.git
Enable ensuring minimum length per token
This commit is contained in:
parent
8a29d57ff0
commit
687789558e
|
@ -1082,12 +1082,47 @@ class Vits(BaseTTS):
|
||||||
return aux_input["x_lengths"]
|
return aux_input["x_lengths"]
|
||||||
return torch.tensor(x.shape[1:2]).to(x.device)
|
return torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
|
||||||
# JMa: add `length_scale` to `aux_input` to enable changing the length (duration) for each input text (sentence)
|
# JMa: set minimum duration if predicted duration is lower than threshold
|
||||||
|
# Workaround to avoid short durations that cause some chars/phonemes to be reduced
|
||||||
|
# @staticmethod
|
||||||
|
# def _set_min_inference_length(d, threshold):
|
||||||
|
# d_mask = d < threshold
|
||||||
|
# d[d_mask] = threshold
|
||||||
|
# return d
|
||||||
|
|
||||||
|
def _set_min_inference_length(self, x, durs, threshold):
    """Raise too-short predicted durations up to a minimum value.

    Workaround for very short predicted durations that make some
    chars/phonemes get audibly reduced. Every token that is neither a
    punctuation mark nor the blank symbol is checked: if its duration
    (plus the duration of the directly following punctuation/blank token,
    when there is one) is below ``threshold``, its duration is set to
    ``threshold``. Punctuation and blank tokens are never modified.

    Args:
        x: Tensor of input token IDs. It is flattened, so it must describe
            a single sequence (NOTE(review): batched input is not
            supported here -- confirm callers always pass one sequence).
        durs: Duration tensor indexed as ``durs[:, :, token_index]``.
            Modified in place.
        threshold: Minimum allowed duration for a non-punctuation token.

    Returns:
        The same ``durs`` tensor, after in-place thresholding.
    """
    characters = self.tokenizer.characters
    # Build the lookup once; a set gives O(1) membership tests in the loop.
    punctlike = set(self.config.characters.punctuations)
    punctlike.add(self.config.characters.blank)

    # reshape(-1) (instead of squeeze()) keeps a list even for one token;
    # squeeze().tolist() would return a bare int and break indexing.
    tokens = [characters.id_to_char(idx) for idx in x.reshape(-1).tolist()]
    last_ix = len(tokens) - 1

    for ix, current in enumerate(tokens):
        if current in punctlike:
            # Skip thresholding for punctuation
            continue
        dur = durs[:, :, ix]
        # Add duration from the next punctuation if possible. The final
        # token has no successor, so it is judged on its own duration
        # (previously the final token was never checked at all).
        if ix < last_ix and tokens[ix + 1] in punctlike:
            dur = dur + durs[:, :, ix + 1]
        # Threshold duration if duration lower than threshold
        if dur < threshold:
            durs[:, :, ix] = threshold
    return durs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(
|
def inference(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None, "length_scale": None},
|
aux_input={"x_lengths": None,
|
||||||
|
"d_vectors": None,
|
||||||
|
"speaker_ids": None,
|
||||||
|
"language_ids": None,
|
||||||
|
"durations": None,
|
||||||
|
"length_scale": None, # JMa: add `length_scale`` to `aux_input` to enable changing length (duration) per each input text (sentence)
|
||||||
|
"min_input_length": 0 # JMa: set minimum length if predicted length is lower than `min_input_length`
|
||||||
|
},
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
Note:
|
Note:
|
||||||
|
@ -1107,6 +1142,9 @@ class Vits(BaseTTS):
|
||||||
- m_p: :math:`[B, C, T_dec]`
|
- m_p: :math:`[B, C, T_dec]`
|
||||||
- logs_p: :math:`[B, C, T_dec]`
|
- logs_p: :math:`[B, C, T_dec]`
|
||||||
"""
|
"""
|
||||||
|
# JMa: Save input
|
||||||
|
x_input = x
|
||||||
|
|
||||||
sid, g, lid, durations = self._set_cond_input(aux_input)
|
sid, g, lid, durations = self._set_cond_input(aux_input)
|
||||||
x_lengths = self._set_x_lengths(x, aux_input)
|
x_lengths = self._set_x_lengths(x, aux_input)
|
||||||
|
|
||||||
|
@ -1135,9 +1173,13 @@ class Vits(BaseTTS):
|
||||||
logw = self.duration_predictor(
|
logw = self.duration_predictor(
|
||||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
||||||
)
|
)
|
||||||
|
# JMa: set minimum duration if required
|
||||||
|
# w = self._set_min_inference_length(torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
|
||||||
|
w = self._set_min_inference_length(x_input, torch.exp(logw) * x_mask, aux_input["min_input_length"]) if aux_input["min_input_length"] else torch.exp(logw) * x_mask
|
||||||
# JMa: length scale for the given sentence-like input
|
# JMa: length scale for the given sentence-like input
|
||||||
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
|
length_scale = aux_input["length_scale"] if aux_input["length_scale"] else self.length_scale
|
||||||
w = torch.exp(logw) * x_mask * length_scale
|
w *= length_scale
|
||||||
|
# w = torch.exp(logw) * x_mask * length_scale
|
||||||
else:
|
else:
|
||||||
assert durations.shape[-1] == x.shape[-1]
|
assert durations.shape[-1] == x.shape[-1]
|
||||||
w = durations.unsqueeze(0)
|
w = durations.unsqueeze(0)
|
||||||
|
|
|
@ -30,7 +30,7 @@ def run_model_torch(
|
||||||
style_text: str = None,
|
style_text: str = None,
|
||||||
d_vector: torch.Tensor = None,
|
d_vector: torch.Tensor = None,
|
||||||
language_id: torch.Tensor = None,
|
language_id: torch.Tensor = None,
|
||||||
aux_input: Dict = {},
|
aux_input: Dict = {"durations": None, "length_scale": None, "min_input_length": 0},
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Run a torch model for inference. It does not support batch inference.
|
"""Run a torch model for inference. It does not support batch inference.
|
||||||
|
|
||||||
|
@ -49,22 +49,19 @@ def run_model_torch(
|
||||||
_func = model.module.inference
|
_func = model.module.inference
|
||||||
else:
|
else:
|
||||||
_func = model.inference
|
_func = model.inference
|
||||||
outputs = _func(
|
# JMa: propagate `durations`, `length_scale`, and `min_input_length` to `aux_input`
|
||||||
inputs,
|
# to enable changing length (durations) per each input text (sentence) and to set
|
||||||
aux_input={
|
# minimum allowed length of each input char/phoneme
|
||||||
"x_lengths": input_lengths,
|
|
||||||
"speaker_ids": speaker_id,
|
|
||||||
"d_vectors": d_vector,
|
|
||||||
"style_mel": style_mel,
|
|
||||||
"style_text": style_text,
|
|
||||||
"language_ids": language_id,
|
|
||||||
# JMa: add `durations`` and `length_scale`` to `aux_input` to enable changing length (durations) per each input text (sentence)
|
|
||||||
# - `length_scale` changes length of the whole generated wav
|
# - `length_scale` changes length of the whole generated wav
|
||||||
# - `durations` sets up duration (in frames) for each input text ID
|
# - `durations` sets up duration (in frames) for each input text ID
|
||||||
"durations": aux_input.get("durations", None),
|
# - minimum allowed length (in frames) per input ID (char/phoneme) during inference
|
||||||
"length_scale": aux_input.get("length_scale", None),
|
aux_input["x_lengths"] = input_lengths
|
||||||
},
|
aux_input["speaker_ids"] = speaker_id
|
||||||
)
|
aux_input["d_vectors"] = d_vector
|
||||||
|
aux_input["style_mel"] = style_mel
|
||||||
|
aux_input["style_text"] = style_text
|
||||||
|
aux_input["language_ids"] = language_id
|
||||||
|
outputs = _func(inputs, aux_input)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue