mirror of https://github.com/coqui-ai/TTS.git
Make CLI work
commit e3c9dab7a3
parent 0a90359a42
@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
-from TTS.tts.layers.xtts.speaker_manager import SpeakerManager
+from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
 
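The rewritten import pulls SpeakerManager (plus a new LanguageManager) from an xtts_manager module whose body is not part of this diff. Below is a minimal sketch of what the call sites in the rest of the commit require from it: construction from a speaker file, a speakers mapping whose per-speaker values unpack to (gpt_cond_latent, speaker_embedding), and a name_to_id attribute probed by the Synthesizer. Everything beyond those requirements is an assumption, not the actual module.

import torch


class SpeakerManager:
    # Hypothetical sketch, not the shipped xtts_manager code. It holds the
    # precomputed latents that load_checkpoint() below reads from speakers_xtts.pth.
    def __init__(self, speaker_file_path):
        # Assumed file layout: {name: {"gpt_cond_latent": ..., "speaker_embedding": ...}}
        self.speakers = torch.load(speaker_file_path)

    @property
    def name_to_id(self):
        # The Synthesizer only checks hasattr(manager, "name_to_id"),
        # so a plain name-to-index mapping is enough for this sketch.
        return {name: i for i, name in enumerate(self.speakers)}


class LanguageManager:
    # Hypothetical sketch: exposes the languages declared in the model config.
    def __init__(self, config):
        self.name_to_id = {lang: i for i, lang in enumerate(config.languages)}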
@@ -379,7 +379,7 @@ class Xtts(BaseTTS):
 
         return gpt_cond_latents, speaker_embedding
 
-    def synthesize(self, text, config, speaker_wav, language, **kwargs):
+    def synthesize(self, text, config, speaker_wav, language, speaker_id, **kwargs):
         """Synthesize speech with the given input text.
 
         Args:
@@ -394,12 +394,6 @@ class Xtts(BaseTTS):
             `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents`
             as latents used at inference.
 
-        """
-        return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
-
-    def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
-        """
-        inference with config
         """
         assert (
             "zh-cn" if language == "zh" else language in self.config.languages
@@ -411,13 +405,18 @@ class Xtts(BaseTTS):
             "repetition_penalty": config.repetition_penalty,
             "top_k": config.top_k,
             "top_p": config.top_p,
+        }
+        settings.update(kwargs)  # allow overriding of preset settings with kwargs
+        if speaker_id is not None:
+            gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
+            return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
+        settings.update({
             "gpt_cond_len": config.gpt_cond_len,
             "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
-        }
-        settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.full_inference(text, ref_audio_path, language, **settings)
+        })
+        return self.full_inference(text, speaker_wav, language, **settings)
 
 
     @torch.inference_mode()
     def full_inference(
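With inference_with_config removed, synthesize now covers both entry points itself: a given speaker_id selects precomputed latents and goes straight to inference(), while speaker_wav falls through to full_inference(), which derives the conditioning latents from the reference audio. A hedged usage sketch follows; the loading calls are the standard XTTS pattern, and the paths and speaker name are placeholders.

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/config.json")  # placeholder path
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/checkpoint_dir", eval=True)  # placeholder path

# Fast path: reuse latents stored under this name in speakers_xtts.pth.
out = model.synthesize("Hello world.", config, speaker_wav=None, language="en", speaker_id="my_speaker")

# Cloning path: compute the latents from a reference clip instead.
out = model.synthesize("Hello world.", config, speaker_wav="ref.wav", language="en", speaker_id=None)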
@@ -753,8 +752,9 @@ class Xtts(BaseTTS):
 
         model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
         vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
-        speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers.json")
+        speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth")
 
+        self.language_manager = LanguageManager(config)
         self.speaker_manager = None
         if os.path.exists(speaker_file_path):
             self.speaker_manager = SpeakerManager(speaker_file_path)
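load_checkpoint now looks for speakers_xtts.pth, a torch-serialized table of ready-made latents, instead of speakers.json, and always builds a LanguageManager from the config. Because the fast path unpacks speakers[speaker_id].values() positionally, each per-speaker dict must keep the insertion order (gpt_cond_latent, speaker_embedding). A hedged sketch of producing such a file, reusing model from the sketch above; get_conditioning_latents is the public XTTS API, but the file layout is inferred from the unpacking, not confirmed by the diff.

import torch

gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["ref.wav"])
speakers = {
    "my_speaker": {
        # Key order matters: .values() must yield (gpt_cond_latent, speaker_embedding).
        "gpt_cond_latent": gpt_cond_latent,
        "speaker_embedding": speaker_embedding,
    }
}
torch.save(speakers, "speakers_xtts.pth")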
@@ -305,7 +305,7 @@ class Synthesizer(nn.Module):
         speaker_embedding = None
         speaker_id = None
         if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
-            if speaker_name and isinstance(speaker_name, str):
+            if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts":
                 if self.tts_config.use_d_vector_file:
                     # get the average speaker embedding from the saved d_vectors.
                     speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
@@ -335,7 +335,9 @@ class Synthesizer(nn.Module):
         # handle multi-lingual
         language_id = None
         if self.tts_languages_file or (
-            hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
+            hasattr(self.tts_model, "language_manager")
+            and self.tts_model.language_manager is not None
+            and not self.tts_config.model == "xtts"
         ):
             if len(self.tts_model.language_manager.name_to_id) == 1:
                 language_id = list(self.tts_model.language_manager.name_to_id.values())[0]
@@ -366,6 +368,7 @@ class Synthesizer(nn.Module):
         if (
             speaker_wav is not None
             and self.tts_model.speaker_manager is not None
+            and hasattr(self.tts_model.speaker_manager, "encoder_ap")
             and self.tts_model.speaker_manager.encoder_ap is not None
         ):
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
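Taken together, the Synthesizer edits are what make the CLI work for XTTS: the generic speaker-name and language-name lookups are skipped when tts_config.model == "xtts", since the model resolves speaker_id and language itself inside synthesize, and the new hasattr guard keeps the voice-cloning branch from touching encoder_ap, which the XTTS SpeakerManager does not define. The equivalent of the CLI call (tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 --text "Hello world." --speaker_idx "Ana Florence" --language_idx en) can be sketched through the Python API; the speaker name must match an entry in the bundled speakers_xtts.pth and is a placeholder here.

from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
tts.tts_to_file(
    text="Hello world.",
    speaker="Ana Florence",  # placeholder: any name present in speakers_xtts.pth
    language="en",
    file_path="out.wav",
)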