mirror of https://github.com/coqui-ai/TTS.git
Add durations as aux input for VITS (#1694)
* Add durations as aux input for VITS * Make style * Fix tts_tests * Fix test_get_aux_input
This commit is contained in:
parent
2cf89b88c9
commit
c614f21982
|
@ -786,7 +786,7 @@ class Vits(BaseTTS):
|
||||||
print(" > Text Encoder was reinit.")
|
print(" > Text Encoder was reinit.")
|
||||||
|
|
||||||
def get_aux_input(self, aux_input: Dict):
|
def get_aux_input(self, aux_input: Dict):
|
||||||
sid, g, lid = self._set_cond_input(aux_input)
|
sid, g, lid, _ = self._set_cond_input(aux_input)
|
||||||
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
|
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
|
||||||
|
|
||||||
def _freeze_layers(self):
|
def _freeze_layers(self):
|
||||||
|
@ -817,7 +817,7 @@ class Vits(BaseTTS):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_cond_input(aux_input: Dict):
|
def _set_cond_input(aux_input: Dict):
|
||||||
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||||
sid, g, lid = None, None, None
|
sid, g, lid, durations = None, None, None, None
|
||||||
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||||
sid = aux_input["speaker_ids"]
|
sid = aux_input["speaker_ids"]
|
||||||
if sid.ndim == 0:
|
if sid.ndim == 0:
|
||||||
|
@ -832,7 +832,10 @@ class Vits(BaseTTS):
|
||||||
if lid.ndim == 0:
|
if lid.ndim == 0:
|
||||||
lid = lid.unsqueeze_(0)
|
lid = lid.unsqueeze_(0)
|
||||||
|
|
||||||
return sid, g, lid
|
if "durations" in aux_input and aux_input["durations"] is not None:
|
||||||
|
durations = aux_input["durations"]
|
||||||
|
|
||||||
|
return sid, g, lid, durations
|
||||||
|
|
||||||
def _set_speaker_input(self, aux_input: Dict):
|
def _set_speaker_input(self, aux_input: Dict):
|
||||||
d_vectors = aux_input.get("d_vectors", None)
|
d_vectors = aux_input.get("d_vectors", None)
|
||||||
|
@ -946,7 +949,7 @@ class Vits(BaseTTS):
|
||||||
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
|
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
|
||||||
"""
|
"""
|
||||||
outputs = {}
|
outputs = {}
|
||||||
sid, g, lid = self._set_cond_input(aux_input)
|
sid, g, lid, _ = self._set_cond_input(aux_input)
|
||||||
# speaker embedding
|
# speaker embedding
|
||||||
if self.args.use_speaker_embedding and sid is not None:
|
if self.args.use_speaker_embedding and sid is not None:
|
||||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
@ -1028,7 +1031,9 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(
|
def inference(
|
||||||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
|
self,
|
||||||
|
x,
|
||||||
|
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
|
||||||
): # pylint: disable=dangerous-default-value
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
Note:
|
Note:
|
||||||
|
@ -1048,7 +1053,7 @@ class Vits(BaseTTS):
|
||||||
- m_p: :math:`[B, C, T_dec]`
|
- m_p: :math:`[B, C, T_dec]`
|
||||||
- logs_p: :math:`[B, C, T_dec]`
|
- logs_p: :math:`[B, C, T_dec]`
|
||||||
"""
|
"""
|
||||||
sid, g, lid = self._set_cond_input(aux_input)
|
sid, g, lid, durations = self._set_cond_input(aux_input)
|
||||||
x_lengths = self._set_x_lengths(x, aux_input)
|
x_lengths = self._set_x_lengths(x, aux_input)
|
||||||
|
|
||||||
# speaker embedding
|
# speaker embedding
|
||||||
|
@ -1062,21 +1067,25 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
|
||||||
|
|
||||||
if self.args.use_sdp:
|
if durations is None:
|
||||||
logw = self.duration_predictor(
|
if self.args.use_sdp:
|
||||||
x,
|
logw = self.duration_predictor(
|
||||||
x_mask,
|
x,
|
||||||
g=g if self.args.condition_dp_on_speaker else None,
|
x_mask,
|
||||||
reverse=True,
|
g=g if self.args.condition_dp_on_speaker else None,
|
||||||
noise_scale=self.inference_noise_scale_dp,
|
reverse=True,
|
||||||
lang_emb=lang_emb,
|
noise_scale=self.inference_noise_scale_dp,
|
||||||
)
|
lang_emb=lang_emb,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logw = self.duration_predictor(
|
||||||
|
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
||||||
|
)
|
||||||
|
w = torch.exp(logw) * x_mask * self.length_scale
|
||||||
else:
|
else:
|
||||||
logw = self.duration_predictor(
|
assert durations.shape[-1] == x.shape[-1]
|
||||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
w = durations.unsqueeze(0)
|
||||||
)
|
|
||||||
|
|
||||||
w = torch.exp(logw) * x_mask * self.length_scale
|
|
||||||
w_ceil = torch.ceil(w)
|
w_ceil = torch.ceil(w)
|
||||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||||
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec]
|
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec]
|
||||||
|
|
Loading…
Reference in New Issue