From 001da8afc8285bb8a936d331d8f925a3b16f1641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:38:29 +0000 Subject: [PATCH] Update Vits for the new model API --- TTS/tts/models/vits.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 30dc7ec4..b5551268 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -278,7 +278,12 @@ class Vits(BaseTTS): # pylint: disable=dangerous-default-value def __init__( - self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, language_manager: LanguageManager = None + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, ): super().__init__(config, ap, tokenizer, speaker_manager) @@ -287,8 +292,6 @@ class Vits(BaseTTS): self.speaker_manager = speaker_manager self.language_manager = language_manager - self.args = args - self.init_multispeaker(config) self.init_multilingual(config) @@ -309,6 +312,7 @@ class Vits(BaseTTS): self.args.num_layers_text_encoder, self.args.kernel_size_text_encoder, self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, ) self.posterior_encoder = PosteriorEncoder( @@ -884,7 +888,7 @@ class Vits(BaseTTS): return self._log(self.ap, batch, outputs, "eval") @torch.no_grad() - def test_run(self) -> Tuple[Dict, Dict]: + def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -904,7 +908,7 @@ class Vits(BaseTTS): aux_inputs["text"], self.config, "cuda" in str(next(self.parameters()).device), - ap, + self.ap, speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], style_wav=aux_inputs["style_wav"], @@ -1007,7 +1011,8 @@ class Vits(BaseTTS): ap = AudioProcessor.init_from_config(config) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) - return Vits(new_config, ap, tokenizer, speaker_manager) + language_manager = LanguageManager.init_from_config(config) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) class VitsCharacters(BaseCharacters):