diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index b9c8ac78..5877e266 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -764,7 +764,7 @@ class Vits(BaseTTS): orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate ) # pylint: disable=W0201 - def on_init_end(self, trainer): # pylint: disable=W0613 + def on_init_end(self, trainer): # pylint: disable=W0613 """Reinit layes if needed""" if self.args.reinit_DP: before_dict = get_module_weights_sum(self.duration_predictor) @@ -785,7 +785,7 @@ class Vits(BaseTTS): if value == before_dict[key]: raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") print(" > Text Encoder was reinit.") - + def get_aux_input(self, aux_input: Dict): sid, g, lid = self._set_cond_input(aux_input) return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}