Add durations as aux input for VITS (#1694)

* Add durations as aux input for VITS

* Make style

* Fix tts_tests

* Fix test_get_aux_input
This commit is contained in:
WeberJulian 2022-07-12 14:25:21 +02:00 committed by GitHub
parent 2cf89b88c9
commit c614f21982
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 19 deletions

View File

@ -786,7 +786,7 @@ class Vits(BaseTTS):
print(" > Text Encoder was reinit.")
def get_aux_input(self, aux_input: Dict):
sid, g, lid = self._set_cond_input(aux_input)
sid, g, lid, _ = self._set_cond_input(aux_input)
return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
def _freeze_layers(self):
@ -817,7 +817,7 @@ class Vits(BaseTTS):
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid = None, None, None
sid, g, lid, durations = None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
@ -832,7 +832,10 @@ class Vits(BaseTTS):
if lid.ndim == 0:
lid = lid.unsqueeze_(0)
return sid, g, lid
if "durations" in aux_input and aux_input["durations"] is not None:
durations = aux_input["durations"]
return sid, g, lid, durations
def _set_speaker_input(self, aux_input: Dict):
d_vectors = aux_input.get("d_vectors", None)
@ -946,7 +949,7 @@ class Vits(BaseTTS):
- syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
"""
outputs = {}
sid, g, lid = self._set_cond_input(aux_input)
sid, g, lid, _ = self._set_cond_input(aux_input)
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@ -1028,7 +1031,9 @@ class Vits(BaseTTS):
@torch.no_grad()
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
self,
x,
aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
): # pylint: disable=dangerous-default-value
"""
Note:
@ -1048,7 +1053,7 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]`
"""
sid, g, lid = self._set_cond_input(aux_input)
sid, g, lid, durations = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input)
# speaker embedding
@ -1062,21 +1067,25 @@ class Vits(BaseTTS):
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
if self.args.use_sdp:
logw = self.duration_predictor(
x,
x_mask,
g=g if self.args.condition_dp_on_speaker else None,
reverse=True,
noise_scale=self.inference_noise_scale_dp,
lang_emb=lang_emb,
)
if durations is None:
if self.args.use_sdp:
logw = self.duration_predictor(
x,
x_mask,
g=g if self.args.condition_dp_on_speaker else None,
reverse=True,
noise_scale=self.inference_noise_scale_dp,
lang_emb=lang_emb,
)
else:
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
w = torch.exp(logw) * x_mask * self.length_scale
else:
logw = self.duration_predictor(
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
)
assert durations.shape[-1] == x.shape[-1]
w = durations.unsqueeze(0)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec]