diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 8f510f20..365ab8bd 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -6,7 +6,7 @@ import pysbd import torch from TTS.config import load_config -from TTS.tts.models import setup_model +from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.utils.speakers import SpeakerManager # pylint: disable=unused-wildcard-import @@ -14,7 +14,8 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis, trim_silence from TTS.tts.utils.text import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor -from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_generator +from TTS.vocoder.models import setup_model as setup_vocoder_model +from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input class Synthesizer(object): @@ -98,7 +99,7 @@ class Synthesizer(object): self.speaker_manager = SpeakerManager( encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config ) - self.speaker_manager.load_d_vectors_file(self.tts_config.get("external_speaker_embedding_file", speaker_file)) + self.speaker_manager.load_d_vectors_file(self.tts_config.get("d_vector_file", speaker_file)) self.num_speakers = self.speaker_manager.num_speakers self.d_vector_dim = self.speaker_manager.d_vector_dim @@ -127,16 +128,11 @@ class Synthesizer(object): if self.tts_config.use_speaker_embedding is True: self.tts_speakers_file = ( - self.tts_speakers_file if self.tts_speakers_file else self.tts_config["external_speaker_embedding_file"] + self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"] ) - self._load_speakers(self.tts_speakers_file) + self.tts_config["d_vector_file"] = self.tts_speakers_file - self.tts_model = setup_model( - self.input_size, - num_speakers=self.num_speakers, - c=self.tts_config, - d_vector_dim=self.d_vector_dim, - ) + self.tts_model = setup_tts_model(config=self.tts_config) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() @@ -151,7 +147,7 @@ class Synthesizer(object): """ self.vocoder_config = load_config(model_config) self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) - self.vocoder_model = setup_generator(self.vocoder_config) + self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) if use_cuda: self.vocoder_model.cuda()