From 2829027d8b4b72233d1948525b7fcdacd1fa23e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 21 Jan 2022 15:33:15 +0000
Subject: [PATCH] Refactor VITS model

---
 TTS/tts/models/vits.py | 109 +++++++++++++++++++++++++----------------
 1 file changed, 68 insertions(+), 41 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 4612c02b..222bbca5 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -38,7 +38,7 @@ class VitsArgs(Coqpit):
             Number of characters in the vocabulary. Defaults to 100.
 
         out_channels (int):
-            Number of output channels. Defaults to 513.
+            Number of output channels of the decoder. Defaults to 513.
 
         spec_segment_size (int):
             Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
@@ -363,6 +363,8 @@ class Vits(BaseTTS):
                 language_emb_dim=self.embedded_language_dim,
             )
 
+        upsample_rate = math.prod(self.args.upsample_rates_decoder)
+        assert upsample_rate == self.config.audio.hop_length, f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {self.config.audio.hop_length}"
         self.waveform_decoder = HifiganGenerator(
             self.args.hidden_channels,
             1,
@@ -531,6 +533,54 @@ class Vits(BaseTTS):
             "language_name": language_name,
         }
 
+    def _set_speaker_input(self, aux_input: Dict):
+        d_vectors = aux_input.get("d_vectors", None)
+        speaker_ids = aux_input.get("speaker_ids", None)
+
+        if d_vectors is not None and speaker_ids is not None:
+            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
+
+        if speaker_ids is not None and not hasattr(self, "emb_g"):
+            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
+
+        g = speaker_ids if speaker_ids is not None else d_vectors
+        return g
+
+    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
+        # find the alignment path
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * logs_p)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
+            logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp2 + logp3 + logp1 + logp4
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
+
+        # duration predictor
+        attn_durations = attn.sum(3)
+        if self.args.use_sdp:
+            loss_duration = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                attn_durations,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+            )
+            loss_duration = loss_duration / torch.sum(x_mask)
+        else:
+            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+            log_durations = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
+            )
+            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
+        outputs["loss_duration"] = loss_duration
+        return outputs, attn
+
     def forward(
         self,
         x: torch.tensor,
@@ -596,54 +646,27 @@ class Vits(BaseTTS):
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
 
-        # find the alignment path
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        with torch.no_grad():
-            o_scale = torch.exp(-2 * logs_p)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
-            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp = logp2 + logp3 + logp1 + logp4
-            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-        # duration predictor
-        attn_durations = attn.sum(3)
-        g_dp = None
-        if self.args.condition_dp_on_speaker:
-            g_dp = g.detach() if self.args.detach_dp_input and g is not None else g
-        if self.args.use_sdp:
-            loss_duration = self.duration_predictor(
-                x.detach() if self.args.detach_dp_input else x,
-                x_mask,
-                attn_durations,
-                g=g_dp,
-                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
-            )
-            loss_duration = loss_duration / torch.sum(x_mask)
-        else:
-            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
-            log_durations = self.duration_predictor(
-                x.detach() if self.args.detach_dp_input else x,
-                x_mask,
-                g=g_dp,
-                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
-            )
-            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
-        outputs["loss_duration"] = loss_duration
+        if self.args.use_mas:
+            outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
+        elif self.args.use_aligner_network:
+            outputs, attn = self.forward_aligner(outputs, m_p, z_p, x_mask, y_mask, g=g, lang_emb=lang_emb)
+            outputs["x_lens"] = x_lengths
+            outputs["y_lens"] = y_lengths
 
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
 
         # select a random feature segment for the waveform decoder
-        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size)
+        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
 
         o = self.waveform_decoder(z_slice, g=g)
 
         wav_seg = segment(
             waveform,
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
+            pad_short=True
         )
 
         if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
@@ -665,11 +688,11 @@ class Vits(BaseTTS):
         outputs.update(
             {
                 "model_outputs": o,
-                "alignments": attn.squeeze(1),
-                "z": z,
-                "z_p": z_p,
+                "alignments" : attn.squeeze(1),
                 "m_p": m_p,
                 "logs_p": logs_p,
+                "z": z,
+                "z_p": z_p,
                 "m_q": m_q,
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
@@ -919,14 +942,18 @@ class Vits(BaseTTS):
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
""" - self._log(self.ap, batch, outputs, "train") + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - return self._log(self.ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: