mirror of https://github.com/coqui-ai/TTS.git
added speed as argument
This commit is contained in:
parent
dbf1a08a0d
commit
8a078be695
|
@ -168,9 +168,7 @@ class TTS(nn.Module):
|
|||
self.synthesizer = None
|
||||
self.model_name = model_name
|
||||
|
||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
|
||||
model_name
|
||||
)
|
||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
|
||||
|
||||
# init synthesizer
|
||||
# None values are fetch from the model
|
||||
|
@ -283,6 +281,7 @@ class TTS(nn.Module):
|
|||
style_text=None,
|
||||
reference_speaker_name=None,
|
||||
split_sentences=split_sentences,
|
||||
speed=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
return wav
|
||||
|
@ -337,6 +336,7 @@ class TTS(nn.Module):
|
|||
language=language,
|
||||
speaker_wav=speaker_wav,
|
||||
split_sentences=split_sentences,
|
||||
speed=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
|
||||
|
|
|
@ -379,7 +379,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
return gpt_cond_latents, speaker_embedding
|
||||
|
||||
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
|
||||
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed=1.0, **kwargs):
|
||||
"""Synthesize speech with the given input text.
|
||||
|
||||
Args:
|
||||
|
@ -410,13 +410,15 @@ class Xtts(BaseTTS):
|
|||
if speaker_id is not None:
|
||||
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
||||
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
||||
settings.update({
|
||||
"gpt_cond_len": config.gpt_cond_len,
|
||||
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
||||
"max_ref_len": config.max_ref_len,
|
||||
"sound_norm_refs": config.sound_norm_refs,
|
||||
})
|
||||
return self.full_inference(text, speaker_wav, language, **settings)
|
||||
settings.update(
|
||||
{
|
||||
"gpt_cond_len": config.gpt_cond_len,
|
||||
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
||||
"max_ref_len": config.max_ref_len,
|
||||
"sound_norm_refs": config.sound_norm_refs,
|
||||
}
|
||||
)
|
||||
return self.full_inference(text, speaker_wav, language, speed, **settings)
|
||||
|
||||
@torch.inference_mode()
|
||||
def full_inference(
|
||||
|
@ -424,6 +426,7 @@ class Xtts(BaseTTS):
|
|||
text,
|
||||
ref_audio_path,
|
||||
language,
|
||||
speed,
|
||||
# GPT inference
|
||||
temperature=0.75,
|
||||
length_penalty=1.0,
|
||||
|
@ -484,6 +487,7 @@ class Xtts(BaseTTS):
|
|||
max_ref_length=max_ref_len,
|
||||
sound_norm_refs=sound_norm_refs,
|
||||
)
|
||||
self.speed = speed
|
||||
|
||||
return self.inference(
|
||||
text,
|
||||
|
@ -518,6 +522,7 @@ class Xtts(BaseTTS):
|
|||
enable_text_splitting=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
speed = self.speed
|
||||
language = language.split("-")[0] # remove the country code
|
||||
length_scale = 1.0 / max(speed, 0.05)
|
||||
gpt_cond_latent = gpt_cond_latent.to(self.device)
|
||||
|
@ -756,13 +761,11 @@ class Xtts(BaseTTS):
|
|||
|
||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||
|
||||
if speaker_file_path is None and checkpoint_dir is not None:
|
||||
speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")
|
||||
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
|
||||
|
||||
self.language_manager = LanguageManager(config)
|
||||
self.speaker_manager = None
|
||||
if speaker_file_path is not None and os.path.exists(speaker_file_path):
|
||||
if os.path.exists(speaker_file_path):
|
||||
self.speaker_manager = SpeakerManager(speaker_file_path)
|
||||
|
||||
if os.path.exists(vocab_path):
|
||||
|
|
|
@ -265,6 +265,7 @@ class Synthesizer(nn.Module):
|
|||
reference_wav=None,
|
||||
reference_speaker_name=None,
|
||||
split_sentences: bool = True,
|
||||
speed=1.0,
|
||||
**kwargs,
|
||||
) -> List[int]:
|
||||
"""🐸 TTS magic. Run all the models and generate speech.
|
||||
|
@ -391,6 +392,7 @@ class Synthesizer(nn.Module):
|
|||
d_vector=speaker_embedding,
|
||||
speaker_wav=speaker_wav,
|
||||
language=language_name,
|
||||
speed=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue