Added `speed` as an argument

This commit is contained in:
Moctar 2024-04-07 18:57:53 +02:00
parent dbf1a08a0d
commit 8a078be695
3 changed files with 22 additions and 17 deletions

View File

@ -168,9 +168,7 @@ class TTS(nn.Module):
self.synthesizer = None
self.model_name = model_name
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
model_name
)
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
# init synthesizer
# None values are fetched from the model
@ -283,6 +281,7 @@ class TTS(nn.Module):
style_text=None,
reference_speaker_name=None,
split_sentences=split_sentences,
speed=1.0,
**kwargs,
)
return wav
@ -337,6 +336,7 @@ class TTS(nn.Module):
language=language,
speaker_wav=speaker_wav,
split_sentences=split_sentences,
speed=1.0,
**kwargs,
)
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)

View File

@ -379,7 +379,7 @@ class Xtts(BaseTTS):
return gpt_cond_latents, speaker_embedding
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed=1.0, **kwargs):
"""Synthesize speech with the given input text.
Args:
@ -410,13 +410,15 @@ class Xtts(BaseTTS):
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
settings.update({
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
})
return self.full_inference(text, speaker_wav, language, **settings)
settings.update(
{
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
)
return self.full_inference(text, speaker_wav, language, speed, **settings)
@torch.inference_mode()
def full_inference(
@ -424,6 +426,7 @@ class Xtts(BaseTTS):
text,
ref_audio_path,
language,
speed,
# GPT inference
temperature=0.75,
length_penalty=1.0,
@ -484,6 +487,7 @@ class Xtts(BaseTTS):
max_ref_length=max_ref_len,
sound_norm_refs=sound_norm_refs,
)
self.speed = speed
return self.inference(
text,
@ -518,6 +522,7 @@ class Xtts(BaseTTS):
enable_text_splitting=False,
**hf_generate_kwargs,
):
speed = self.speed
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05)
gpt_cond_latent = gpt_cond_latent.to(self.device)
@ -756,13 +761,11 @@ class Xtts(BaseTTS):
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
if speaker_file_path is None and checkpoint_dir is not None:
speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
self.language_manager = LanguageManager(config)
self.speaker_manager = None
if speaker_file_path is not None and os.path.exists(speaker_file_path):
if os.path.exists(speaker_file_path):
self.speaker_manager = SpeakerManager(speaker_file_path)
if os.path.exists(vocab_path):

View File

@ -265,6 +265,7 @@ class Synthesizer(nn.Module):
reference_wav=None,
reference_speaker_name=None,
split_sentences: bool = True,
speed=1.0,
**kwargs,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@ -391,6 +392,7 @@ class Synthesizer(nn.Module):
d_vector=speaker_embedding,
speaker_wav=speaker_wav,
language=language_name,
speed=1.0,
**kwargs,
)
else: