From 95ca2ef77356cc17793d02963760d50552aa87b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 16 Dec 2021 14:57:24 +0000 Subject: [PATCH] Implement init_speaker_encoder and change arg names --- TTS/utils/synthesizer.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index bd90dd8c..62540ae2 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,5 +1,5 @@ import time -from typing import List +from typing import List, Union import numpy as np import pysbd @@ -117,6 +117,7 @@ class Synthesizer(object): speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() + speaker_manager = self._init_speaker_encoder(speaker_manager) if language_manager is not None: self.tts_model = setup_tts_model( @@ -130,23 +131,47 @@ class Synthesizer(object): if use_cuda: self.tts_model.cuda() + def _is_use_speaker_embedding(self): + """Check if the speaker embedding is used in the model""" + # some models use model_args some don't + if hasattr(self.tts_config, "model_args"): + config = self.tts_config.model_args + else: + config = self.tts_config + return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True + + def _is_use_d_vector_file(self): + """Check if the d-vector file is used in the model""" + # some models use model_args some don't + if hasattr(self.tts_config, "model_args"): + config = self.tts_config.model_args + else: + config = self.tts_config + return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True + def _init_speaker_manager(self): """Initialize the SpeakerManager""" # setup if multi-speaker settings are in the global model config speaker_manager = None - if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True: + if self._is_use_speaker_embedding(): if self.tts_speakers_file: speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) if self.tts_config.get("speakers_file", None): speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file) - if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True: + if self._is_use_d_vector_file(): if self.tts_speakers_file: speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) if self.tts_config.get("d_vector_file", None): speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) return speaker_manager + def _init_speaker_encoder(self, speaker_manager): + """Initialize the SpeakerEncoder""" + if self.encoder_checkpoint is not None: + speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config) + return speaker_manager + def _init_language_manager(self): """Initialize the LanguageManager""" # setup if multi-lingual settings are in the global model config @@ -203,7 +228,12 @@ class Synthesizer(object): self.ap.save_wav(wav, path, self.output_sample_rate) def tts( - self, text: str, speaker_name: str = "", language_name: str = "", speaker_wav=None, style_wav=None + self, + text: str, + speaker_name: str = "", + language_name: str = "", + speaker_wav: Union[str, List[str]] = None, + style_wav=None, ) -> List[int]: """🐸 TTS magic. Run all the models and generate speech. @@ -211,7 +241,7 @@ class Synthesizer(object): text (str): input text. speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "". language_name (str, optional): language id for multi-language models. Defaults to "". - speaker_wav (): + speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None. style_wav ([type], optional): style waveform for GST. Defaults to None. Returns: