diff --git a/TTS/api.py b/TTS/api.py
index 7abc188e..c58ab76f 100644
--- a/TTS/api.py
+++ b/TTS/api.py
@@ -168,9 +168,7 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.model_name = model_name
 
-        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-            model_name
-        )
+        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
 
         # init synthesizer
         # None values are fetch from the model
@@ -283,6 +281,7 @@ class TTS(nn.Module):
             style_text=None,
             reference_speaker_name=None,
             split_sentences=split_sentences,
+            speed=1.0,
             **kwargs,
         )
         return wav
@@ -337,6 +336,7 @@ class TTS(nn.Module):
             language=language,
             speaker_wav=speaker_wav,
             split_sentences=split_sentences,
+            speed=1.0,
             **kwargs,
         )
         self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 8e9d6bd3..7ca6ff85 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -274,7 +274,7 @@ class Xtts(BaseTTS):
             for i in range(0, audio.shape[1], 22050 * chunk_length):
                 audio_chunk = audio[:, i : i + 22050 * chunk_length]
 
-                # if the chunk is too short ignore it 
+                # if the chunk is too short ignore it
                 if audio_chunk.size(-1) < 22050 * 0.33:
                     continue
 
@@ -379,7 +379,7 @@ class Xtts(BaseTTS):
 
         return gpt_cond_latents, speaker_embedding
 
-    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
+    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed=1.0, **kwargs):
         """Synthesize speech with the given input text.
 
         Args:
@@ -410,13 +410,15 @@ class Xtts(BaseTTS):
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
             return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
-        settings.update({
-            "gpt_cond_len": config.gpt_cond_len,
-            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
-            "max_ref_len": config.max_ref_len,
-            "sound_norm_refs": config.sound_norm_refs,
-        })
-        return self.full_inference(text, speaker_wav, language, **settings)
+        settings.update(
+            {
+                "gpt_cond_len": config.gpt_cond_len,
+                "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
+                "max_ref_len": config.max_ref_len,
+                "sound_norm_refs": config.sound_norm_refs,
+            }
+        )
+        return self.full_inference(text, speaker_wav, language, speed, **settings)
 
     @torch.inference_mode()
     def full_inference(
@@ -424,6 +426,7 @@ class Xtts(BaseTTS):
         text,
         ref_audio_path,
         language,
+        speed,
         # GPT inference
         temperature=0.75,
         length_penalty=1.0,
@@ -484,6 +487,7 @@ class Xtts(BaseTTS):
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
+        self.speed = speed
 
         return self.inference(
             text,
@@ -518,6 +522,7 @@ class Xtts(BaseTTS):
         enable_text_splitting=False,
         **hf_generate_kwargs,
     ):
+        speed = self.speed
         language = language.split("-")[0]  # remove the country code
         length_scale = 1.0 / max(speed, 0.05)
         gpt_cond_latent = gpt_cond_latent.to(self.device)
@@ -756,13 +761,11 @@ class Xtts(BaseTTS):
 
         model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
         vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
-
-        if speaker_file_path is None and checkpoint_dir is not None:
-            speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")
+        speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
 
         self.language_manager = LanguageManager(config)
         self.speaker_manager = None
-        if speaker_file_path is not None and os.path.exists(speaker_file_path):
+        if os.path.exists(speaker_file_path):
             self.speaker_manager = SpeakerManager(speaker_file_path)
 
         if os.path.exists(vocab_path):
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index b98647c3..824b69b5 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -265,6 +265,7 @@ class Synthesizer(nn.Module):
         reference_wav=None,
         reference_speaker_name=None,
         split_sentences: bool = True,
+        speed=1.0,
         **kwargs,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
@@ -335,7 +336,7 @@ class Synthesizer(nn.Module):
         # handle multi-lingual
         language_id = None
         if self.tts_languages_file or (
-            hasattr(self.tts_model, "language_manager") 
+            hasattr(self.tts_model, "language_manager")
            and self.tts_model.language_manager is not None
             and not self.tts_config.model == "xtts"
         ):
@@ -391,6 +392,7 @@ class Synthesizer(nn.Module):
                         d_vector=speaker_embedding,
                         speaker_wav=speaker_wav,
                         language=language_name,
+                        speed=1.0,
                         **kwargs,
                     )
                 else:
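For reference, a minimal usage sketch of the speed control this patch threads through. This is a hedged example, not part of the patch: the config/checkpoint paths, reference WAV, and sample text are placeholders. Since `TTS/api.py` and `Synthesizer` pass a fixed `speed=1.0` downstream in this revision, variable speed is only reachable by driving the `Xtts` model directly:

```python
# Hypothetical demo of the speed path added in this patch; all paths below
# are placeholders.
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

# speed is forwarded positionally into full_inference(), stored as
# self.speed, and read back in inference(), where
# length_scale = 1.0 / max(speed, 0.05); speed=1.5 gives roughly 1.5x
# faster speech, speed=0.8 slows it down.
outputs = model.synthesize(
    "The quick brown fox jumps over the lazy dog.",
    config,
    speaker_wav="/path/to/reference.wav",
    language="en",
    speed=1.5,
)
wav = outputs["wav"]  # synthesized waveform returned by the model
```

Note the state-based handoff this diff introduces: `inference()` reads the rate from `self.speed`, set by the most recent `full_inference()` call, rather than receiving it as an argument.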