mirror of https://github.com/coqui-ai/TTS.git
added speed as argument
This commit is contained in:
parent
dbf1a08a0d
commit
8a078be695
|
@ -168,9 +168,7 @@ class TTS(nn.Module):
|
||||||
self.synthesizer = None
|
self.synthesizer = None
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
|
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
|
||||||
model_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# init synthesizer
|
# init synthesizer
|
||||||
# None values are fetched from the model
|
# None values are fetched from the model
|
||||||
|
@ -283,6 +281,7 @@ class TTS(nn.Module):
|
||||||
style_text=None,
|
style_text=None,
|
||||||
reference_speaker_name=None,
|
reference_speaker_name=None,
|
||||||
split_sentences=split_sentences,
|
split_sentences=split_sentences,
|
||||||
|
speed=1.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return wav
|
return wav
|
||||||
|
@ -337,6 +336,7 @@ class TTS(nn.Module):
|
||||||
language=language,
|
language=language,
|
||||||
speaker_wav=speaker_wav,
|
speaker_wav=speaker_wav,
|
||||||
split_sentences=split_sentences,
|
split_sentences=split_sentences,
|
||||||
|
speed=1.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
|
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
|
||||||
|
|
|
@ -379,7 +379,7 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
return gpt_cond_latents, speaker_embedding
|
return gpt_cond_latents, speaker_embedding
|
||||||
|
|
||||||
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
|
def synthesize(self, text, config, speaker_wav, language, speaker_id=None, speed=1.0, **kwargs):
|
||||||
"""Synthesize speech with the given input text.
|
"""Synthesize speech with the given input text.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -410,13 +410,15 @@ class Xtts(BaseTTS):
|
||||||
if speaker_id is not None:
|
if speaker_id is not None:
|
||||||
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
|
||||||
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
|
||||||
settings.update({
|
settings.update(
|
||||||
"gpt_cond_len": config.gpt_cond_len,
|
{
|
||||||
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
"gpt_cond_len": config.gpt_cond_len,
|
||||||
"max_ref_len": config.max_ref_len,
|
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
|
||||||
"sound_norm_refs": config.sound_norm_refs,
|
"max_ref_len": config.max_ref_len,
|
||||||
})
|
"sound_norm_refs": config.sound_norm_refs,
|
||||||
return self.full_inference(text, speaker_wav, language, **settings)
|
}
|
||||||
|
)
|
||||||
|
return self.full_inference(text, speaker_wav, language, speed, **settings)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def full_inference(
|
def full_inference(
|
||||||
|
@ -424,6 +426,7 @@ class Xtts(BaseTTS):
|
||||||
text,
|
text,
|
||||||
ref_audio_path,
|
ref_audio_path,
|
||||||
language,
|
language,
|
||||||
|
speed,
|
||||||
# GPT inference
|
# GPT inference
|
||||||
temperature=0.75,
|
temperature=0.75,
|
||||||
length_penalty=1.0,
|
length_penalty=1.0,
|
||||||
|
@ -484,6 +487,7 @@ class Xtts(BaseTTS):
|
||||||
max_ref_length=max_ref_len,
|
max_ref_length=max_ref_len,
|
||||||
sound_norm_refs=sound_norm_refs,
|
sound_norm_refs=sound_norm_refs,
|
||||||
)
|
)
|
||||||
|
self.speed = speed
|
||||||
|
|
||||||
return self.inference(
|
return self.inference(
|
||||||
text,
|
text,
|
||||||
|
@ -518,6 +522,7 @@ class Xtts(BaseTTS):
|
||||||
enable_text_splitting=False,
|
enable_text_splitting=False,
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
):
|
):
|
||||||
|
speed = self.speed
|
||||||
language = language.split("-")[0] # remove the country code
|
language = language.split("-")[0] # remove the country code
|
||||||
length_scale = 1.0 / max(speed, 0.05)
|
length_scale = 1.0 / max(speed, 0.05)
|
||||||
gpt_cond_latent = gpt_cond_latent.to(self.device)
|
gpt_cond_latent = gpt_cond_latent.to(self.device)
|
||||||
|
@ -756,13 +761,11 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||||
|
speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
|
||||||
if speaker_file_path is None and checkpoint_dir is not None:
|
|
||||||
speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")
|
|
||||||
|
|
||||||
self.language_manager = LanguageManager(config)
|
self.language_manager = LanguageManager(config)
|
||||||
self.speaker_manager = None
|
self.speaker_manager = None
|
||||||
if speaker_file_path is not None and os.path.exists(speaker_file_path):
|
if os.path.exists(speaker_file_path):
|
||||||
self.speaker_manager = SpeakerManager(speaker_file_path)
|
self.speaker_manager = SpeakerManager(speaker_file_path)
|
||||||
|
|
||||||
if os.path.exists(vocab_path):
|
if os.path.exists(vocab_path):
|
||||||
|
|
|
@ -265,6 +265,7 @@ class Synthesizer(nn.Module):
|
||||||
reference_wav=None,
|
reference_wav=None,
|
||||||
reference_speaker_name=None,
|
reference_speaker_name=None,
|
||||||
split_sentences: bool = True,
|
split_sentences: bool = True,
|
||||||
|
speed=1.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""🐸 TTS magic. Run all the models and generate speech.
|
"""🐸 TTS magic. Run all the models and generate speech.
|
||||||
|
@ -391,6 +392,7 @@ class Synthesizer(nn.Module):
|
||||||
d_vector=speaker_embedding,
|
d_vector=speaker_embedding,
|
||||||
speaker_wav=speaker_wav,
|
speaker_wav=speaker_wav,
|
||||||
language=language_name,
|
language=language_name,
|
||||||
|
speed=1.0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue