From 54c6bb2a8cce664ce2492ecb946c61f830e78965 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 3 Feb 2022 15:37:13 +0100 Subject: [PATCH] Fix add speaker VITS --- TTS/tts/models/vits.py | 53 ++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index a69b02ba..f02090cf 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -799,21 +799,7 @@ class Vits(BaseTTS): o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) return o_hat, y_mask, (z, z_p, z_hat) - def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: - """Perform a single training step. Run the model forward pass and compute losses. - - Args: - batch (Dict): Input tensors. - criterion (nn.Module): Loss layer designed for the model. - optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. - - Returns: - Tuple[Dict, Dict]: Model ouputs and computed losses. - """ - # pylint: disable=attribute-defined-outside-init - if optimizer_idx not in [0, 1]: - raise ValueError(" [!] Unexpected `optimizer_idx`.") - + def _freeze_layers(self): if self.args.freeze_encoder: for param in self.text_encoder.parameters(): param.requires_grad = False @@ -838,6 +824,24 @@ class Vits(BaseTTS): for param in self.waveform_decoder.parameters(): param.requires_grad = False + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Perform a single training step. Run the model forward pass and compute losses. + + Args: + batch (Dict): Input tensors. + criterion (nn.Module): Loss layer designed for the model. + optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. + + Returns: + Tuple[Dict, Dict]: Model ouputs and computed losses. + """ + + # pylint: disable=attribute-defined-outside-init + if optimizer_idx not in [0, 1]: + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + self._freeze_layers() + if optimizer_idx == 0: text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -848,6 +852,9 @@ class Vits(BaseTTS): language_ids = batch["language_ids"] waveform = batch["waveform"] + # if (waveform > 1).sum() > 0 or (waveform < -1).sum() > 0: + # breakpoint() + # generator pass outputs = self.forward( text_input, @@ -859,8 +866,6 @@ class Vits(BaseTTS): ) # cache tensors for the discriminator - self.y_disc_cache = None - self.wav_seg_disc_cache = None self.y_disc_cache = outputs["model_outputs"] self.wav_seg_disc_cache = outputs["waveform_seg"] @@ -888,6 +893,9 @@ class Vits(BaseTTS): syn_spk_emb=outputs["syn_spk_emb"], ) + # if loss_dict["loss_feat"].isnan().sum() > 0 or loss_dict["loss_feat"].isinf().sum() > 0: + # breakpoint() + elif optimizer_idx == 1: # discriminator pass outputs = {} @@ -984,7 +992,11 @@ class Vits(BaseTTS): test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) - return test_figures, test_audios + return {"figures": test_figures, "audios": test_audios} + + def test_log(self, outputs: dict, logger: "Logger", assets: dict, steps:int) -> None: + logger.test_audios(steps, outputs['audios'], self.ap.sample_rate) + logger.test_figures(steps, outputs['figures']) def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. @@ -1056,9 +1068,10 @@ class Vits(BaseTTS): state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} # handle fine-tuning from a checkpoint with additional speakers if state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: - print(" > Loading checkpoint with additional speakers.") + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") emb_g = state["model"]["emb_g.weight"] - new_row = torch.zeros(1, emb_g.shape[1]) + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) emb_g = torch.cat([emb_g, new_row], axis=0) state["model"]["emb_g.weight"] = emb_g