mirror of https://github.com/coqui-ai/TTS.git
Add durations as aux input for VITS (#1694)
* Add durations as aux input for VITS
* Make style
* Fix tts_tests
* Fix test_get_aux_input
parent 2cf89b88c9
commit c614f21982
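With this change, `Vits.inference` accepts precomputed per-token durations through `aux_input["durations"]` and skips the duration predictor when they are given. A minimal usage sketch, not part of the commit itself: `model` is assumed to be an already-loaded Vits instance and `token_ids` a `[1, T_en]` tensor of input token ids, and the output keys are assumed to follow the existing inference API.

    import torch

    # Frames per input token; float because the values later go through torch.ceil().
    durations = torch.tensor([[3.0, 5.0, 2.0, 4.0]])  # shape [1, T_en]
    outputs = model.inference(
        token_ids,  # shape [1, T_en]; its length must match durations.shape[-1]
        aux_input={
            "x_lengths": None,
            "d_vectors": None,
            "speaker_ids": None,
            "language_ids": None,
            "durations": durations,  # new key introduced by this commit
        },
    )
    wav = outputs["model_outputs"]  # synthesized waveform, [B, 1, T_wav]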
@@ -786,7 +786,7 @@ class Vits(BaseTTS):
             print(" > Text Encoder was reinit.")
 
     def get_aux_input(self, aux_input: Dict):
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     def _freeze_layers(self):
@@ -817,7 +817,7 @@ class Vits(BaseTTS):
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g, lid = None, None, None
+        sid, g, lid, durations = None, None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
@@ -832,7 +832,10 @@ class Vits(BaseTTS):
             if lid.ndim == 0:
                 lid = lid.unsqueeze_(0)
 
-        return sid, g, lid
+        if "durations" in aux_input and aux_input["durations"] is not None:
+            durations = aux_input["durations"]
+
+        return sid, g, lid, durations
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -946,7 +949,7 @@ class Vits(BaseTTS):
             - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, _ = self._set_cond_input(aux_input)
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
@@ -1028,7 +1031,9 @@ class Vits(BaseTTS):
 
     @torch.no_grad()
     def inference(
-        self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
+        self,
+        x,
+        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1048,7 +1053,7 @@ class Vits(BaseTTS):
             - m_p: :math:`[B, C, T_dec]`
             - logs_p: :math:`[B, C, T_dec]`
         """
-        sid, g, lid = self._set_cond_input(aux_input)
+        sid, g, lid, durations = self._set_cond_input(aux_input)
         x_lengths = self._set_x_lengths(x, aux_input)
 
         # speaker embedding
@@ -1062,21 +1067,25 @@ class Vits(BaseTTS):
 
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
 
-        if self.args.use_sdp:
-            logw = self.duration_predictor(
-                x,
-                x_mask,
-                g=g if self.args.condition_dp_on_speaker else None,
-                reverse=True,
-                noise_scale=self.inference_noise_scale_dp,
-                lang_emb=lang_emb,
-            )
+        if durations is None:
+            if self.args.use_sdp:
+                logw = self.duration_predictor(
+                    x,
+                    x_mask,
+                    g=g if self.args.condition_dp_on_speaker else None,
+                    reverse=True,
+                    noise_scale=self.inference_noise_scale_dp,
+                    lang_emb=lang_emb,
+                )
+            else:
+                logw = self.duration_predictor(
+                    x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
+                )
+            w = torch.exp(logw) * x_mask * self.length_scale
         else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
-            )
+            assert durations.shape[-1] == x.shape[-1]
+            w = durations.unsqueeze(0)
 
-        w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]
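Reading the new `else` branch above, the contract for user-supplied durations follows directly: values are frame counts per encoder token and must be floats (they pass through `torch.ceil`), the last dimension must equal the number of encoder tokens, and after `unsqueeze(0)` the tensor is summed over dims 1 and 2, so a `[1, T_en]` input is the expected shape. Note that `self.length_scale` and the SDP noise scale are not applied in this branch; the given durations are used as-is. A small standalone check, with hypothetical values:

    import torch

    durations = torch.tensor([[3.0, 5.0, 2.0, 4.0]])  # hypothetical frame counts, [1, T_en]
    w = durations.unsqueeze(0)                        # [1, 1, T_en], replaces the predictor output
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    print(y_lengths)  # tensor([14]) -> 14 decoder frames for this input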