From 175ca063884a145fbf2094b6a9dd339f7953b1f4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 9 May 2022 14:28:37 -0300 Subject: [PATCH] Add reinit text encoder and duration predictor parameter (#1562) * Add reinit encoder and duration predictor option * Add .data to prevent any overlooked autograd hook --- TTS/tts/models/vits.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 34e9fbcf..af995bda 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -41,6 +41,23 @@ hann_window = {} mel_basis = {} +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + def load_audio(file_path): """Load the audio file normalized in [-1, 1] @@ -528,6 +545,8 @@ class VitsArgs(Coqpit): freeze_waveform_decoder: bool = False encoder_sample_rate: int = None interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False class Vits(BaseTTS): @@ -744,6 +763,28 @@ class Vits(BaseTTS): orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate ) # pylint: disable=W0201 + def on_init_end(self, trainer): # pylint: disable=W0613 + """Reinit layes if needed""" + if self.args.reinit_DP: + before_dict = get_module_weights_sum(self.duration_predictor) + # Applies weights_reset recursively to every submodule of the duration predictor + self.duration_predictor.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.duration_predictor) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") + print(" > Duration Predictor was reinit.") + + if self.args.reinit_text_encoder: + before_dict = get_module_weights_sum(self.text_encoder) + # Applies weights_reset recursively to every submodule of the duration predictor + self.text_encoder.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.text_encoder) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") + print(" > Text Encoder was reinit.") + def get_aux_input(self, aux_input: Dict): sid, g, lid = self._set_cond_input(aux_input) return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}