diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index ad6d95f7..8c364300 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss -from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model +from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model from TTS.speaker_encoder.utils.training import init_training from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.tts.datasets import load_tts_samples @@ -151,7 +151,7 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_eval ap = AudioProcessor(**c.audio) - model = setup_model(c) + model = setup_speaker_encoder_model(c) optimizer = RAdam(model.parameters(), lr=c.lr) diff --git a/TTS/server/server.py b/TTS/server/server.py index 2c6bebfd..f2512582 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -100,7 +100,15 @@ if args.vocoder_path is not None: # load models synthesizer = Synthesizer( - model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint="", + encoder_config="", + use_cuda=args.use_cuda, ) use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1 diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 905f50d7..db54027d 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -170,7 +170,7 @@ class Synthesizer(object): def _init_speaker_encoder(self, speaker_manager): """Initialize the SpeakerEncoder""" - if self.encoder_checkpoint is not None: + if self.encoder_checkpoint: speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config) return speaker_manager diff --git a/tests/aux_tests/test_speaker_manager.py b/tests/aux_tests/test_speaker_manager.py index baa50749..b56c5258 100644 --- a/tests/aux_tests/test_speaker_manager.py +++ b/tests/aux_tests/test_speaker_manager.py @@ -6,7 +6,7 @@ import torch from tests import get_tests_input_path from TTS.config import load_config -from TTS.speaker_encoder.utils.generic_utils import setup_model +from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.speaker_encoder.utils.io import save_checkpoint from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.audio import AudioProcessor @@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase): config.audio.resample = True # create a dummy speaker encoder - model = setup_model(config) + model = setup_speaker_encoder_model(config) save_checkpoint(model, None, None, get_tests_input_path(), 0) # load audio processor and speaker encoder