From 760f045aaa179e6cc3feae41ce4188df3e50be6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 4 Apr 2022 09:45:46 +0200 Subject: [PATCH] Rename vars in VITS --- TTS/tts/models/vits.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 2c1c2bc6..cb83f7ca 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -29,6 +29,7 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment +from TTS.utils.generic_utils import count_parameters from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -125,6 +126,7 @@ def wav_to_spec(y, n_fft, hop_length, win_length, center=False): pad_mode="reflect", normalized=False, onesided=True, + return_complex=True, ) spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) @@ -521,7 +523,7 @@ class VitsArgs(Coqpit): inference_noise_scale_dp: float = 1.0 max_inference_len: int = None init_discriminator: bool = True - use_spectral_norm_disriminator: bool = False + use_spectral_norm_discriminator: bool = False use_speaker_embedding: bool = False num_speakers: int = 0 speakers_file: str = None @@ -857,21 +859,21 @@ class Vits(BaseTTS): logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp2 + logp3 + logp1 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + mas_attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] # duration predictor - attn_durations = attn.sum(3) + mas_attn_durations = mas_attn.sum(3) if self.args.use_sdp: loss_duration = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, - attn_durations, + mas_attn_durations, g=g.detach() if self.args.detach_dp_input and g is not None else g, lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, ) loss_duration = loss_duration / torch.sum(x_mask) else: - attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + attn_log_durations = torch.log(mas_attn_durations + 1e-6) * x_mask log_durations = self.duration_predictor( x.detach() if self.args.detach_dp_input else x, x_mask, @@ -880,7 +882,7 @@ class Vits(BaseTTS): ) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) outputs["loss_duration"] = loss_duration - return outputs, attn + return outputs, mas_attn def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): spec_segment_size = self.spec_segment_size @@ -965,11 +967,11 @@ class Vits(BaseTTS): z_p = self.flow(z, y_mask, g=g) # duration predictor - outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + outputs, mas_attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) # expand prior - m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) - logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + m_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [mas_attn, logs_p]) # select a random feature segment for the waveform decoder z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) @@ -1005,7 +1007,7 @@ class Vits(BaseTTS): outputs.update( { "model_outputs": o, - "alignments": attn.squeeze(1), + "alignments": mas_attn.squeeze(1), "m_p": m_p, "logs_p": logs_p, "z": z, @@ -1269,7 +1271,8 @@ class Vits(BaseTTS): raise ValueError(" [!] Unexpected `optimizer_idx`.") - def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + @staticmethod + def _log(ap, outputs, name_prefix="train"): y_hat = outputs[1]["model_outputs"] y = outputs[1]["waveform_seg"] figures = plot_results(y_hat, y, ap, name_prefix) @@ -1302,7 +1305,7 @@ class Vits(BaseTTS): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - figures, audios = self._log(self.ap, batch, outputs, "train") + figures, audios = self._log(self.ap, outputs, "train") logger.train_figures(steps, figures) logger.train_audios(steps, audios, self.ap.sample_rate) @@ -1311,7 +1314,7 @@ class Vits(BaseTTS): return self.train_step(batch, criterion, optimizer_idx) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._log(self.ap, batch, outputs, "eval") + figures, audios = self._log(self.ap, outputs, "eval") logger.eval_figures(steps, figures) logger.eval_audios(steps, audios, self.ap.sample_rate) @@ -1542,9 +1545,7 @@ class Vits(BaseTTS): Returns: List: optimizers. """ - # select generator parameters optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) - gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) optimizer1 = get_optimizer( self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters