diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 966f7c0f..12ed7742 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence -from TTS.tts.layers.xtts.speaker_manager import SpeakerManager +from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec @@ -379,7 +379,7 @@ class Xtts(BaseTTS): return gpt_cond_latents, speaker_embedding - def synthesize(self, text, config, speaker_wav, language, **kwargs): + def synthesize(self, text, config, speaker_wav, language, speaker_id, **kwargs): """Synthesize speech with the given input text. Args: @@ -394,12 +394,6 @@ class Xtts(BaseTTS): `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` as latents used at inference. - """ - return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs) - - def inference_with_config(self, text, config, ref_audio_path, language, **kwargs): - """ - inference with config """ assert ( "zh-cn" if language == "zh" else language in self.config.languages @@ -411,13 +405,18 @@ class Xtts(BaseTTS): "repetition_penalty": config.repetition_penalty, "top_k": config.top_k, "top_p": config.top_p, + } + settings.update(kwargs) # allow overriding of preset settings with kwargs + if speaker_id is not None: + gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() + return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) + settings.update({ "gpt_cond_len": config.gpt_cond_len, "gpt_cond_chunk_len": config.gpt_cond_chunk_len, "max_ref_len": config.max_ref_len, "sound_norm_refs": config.sound_norm_refs, - } - settings.update(kwargs) # allow overriding of preset settings with kwargs - return self.full_inference(text, ref_audio_path, language, **settings) + }) + return self.full_inference(text, speaker_wav, language, **settings) @torch.inference_mode() def full_inference( @@ -753,8 +752,9 @@ class Xtts(BaseTTS): model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") - speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers.json") + speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth") + self.language_manager = LanguageManager(config) self.speaker_manager = None if os.path.exists(speaker_file_path): self.speaker_manager = SpeakerManager(speaker_file_path) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 781561f9..b98647c3 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -305,7 +305,7 @@ class Synthesizer(nn.Module): speaker_embedding = None speaker_id = None if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): - if speaker_name and isinstance(speaker_name, str): + if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts": if self.tts_config.use_d_vector_file: # get the average speaker embedding from the saved d_vectors. speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding( @@ -335,7 +335,9 @@ class Synthesizer(nn.Module): # handle multi-lingual language_id = None if self.tts_languages_file or ( - hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None + hasattr(self.tts_model, "language_manager") + and self.tts_model.language_manager is not None + and not self.tts_config.model == "xtts" ): if len(self.tts_model.language_manager.name_to_id) == 1: language_id = list(self.tts_model.language_manager.name_to_id.values())[0] @@ -366,6 +368,7 @@ class Synthesizer(nn.Module): if ( speaker_wav is not None and self.tts_model.speaker_manager is not None + and hasattr(self.tts_model.speaker_manager, "encoder_ap") and self.tts_model.speaker_manager.encoder_ap is not None ): speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)