mirror of https://github.com/coqui-ai/TTS.git
refactor(bin.synthesize): use Python API for CLI
This commit is contained in:
parent
806af96e4c
commit
e0f621180f
16
TTS/api.py
16
TTS/api.py
|
@ -109,7 +109,11 @@ class TTS(nn.Module):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_multi_speaker(self) -> bool:
|
def is_multi_speaker(self) -> bool:
|
||||||
if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
|
if (
|
||||||
|
self.synthesizer is not None
|
||||||
|
and hasattr(self.synthesizer.tts_model, "speaker_manager")
|
||||||
|
and self.synthesizer.tts_model.speaker_manager
|
||||||
|
):
|
||||||
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
|
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -123,7 +127,11 @@ class TTS(nn.Module):
|
||||||
and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1)
|
and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1)
|
||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
|
if (
|
||||||
|
self.synthesizer is not None
|
||||||
|
and hasattr(self.synthesizer.tts_model, "language_manager")
|
||||||
|
and self.synthesizer.tts_model.language_manager
|
||||||
|
):
|
||||||
return self.synthesizer.tts_model.language_manager.num_languages > 1
|
return self.synthesizer.tts_model.language_manager.num_languages > 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -306,10 +314,6 @@ class TTS(nn.Module):
|
||||||
speaker_name=speaker,
|
speaker_name=speaker,
|
||||||
language_name=language,
|
language_name=language,
|
||||||
speaker_wav=speaker_wav,
|
speaker_wav=speaker_wav,
|
||||||
reference_wav=None,
|
|
||||||
style_wav=None,
|
|
||||||
style_text=None,
|
|
||||||
reference_speaker_name=None,
|
|
||||||
split_sentences=split_sentences,
|
split_sentences=split_sentences,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,8 +9,6 @@ import sys
|
||||||
from argparse import RawTextHelpFormatter
|
from argparse import RawTextHelpFormatter
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name, unused-argument
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
|
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -312,7 +310,8 @@ def parse_args() -> argparse.Namespace:
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
|
"""Entry point for `tts` command line interface."""
|
||||||
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
|
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
@ -320,12 +319,11 @@ def main():
|
||||||
|
|
||||||
with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
|
with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
|
||||||
# Late-import to make things load faster
|
# Late-import to make things load faster
|
||||||
|
from TTS.api import TTS
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
from TTS.utils.synthesizer import Synthesizer
|
|
||||||
|
|
||||||
# load model manager
|
# load model manager
|
||||||
path = Path(__file__).parent / "../.models.json"
|
manager = ModelManager(models_file=TTS.get_models_file_path(), progress_bar=args.progress_bar)
|
||||||
manager = ModelManager(path, progress_bar=args.progress_bar)
|
|
||||||
|
|
||||||
tts_path = None
|
tts_path = None
|
||||||
tts_config_path = None
|
tts_config_path = None
|
||||||
|
@ -339,12 +337,12 @@ def main():
|
||||||
vc_config_path = None
|
vc_config_path = None
|
||||||
model_dir = None
|
model_dir = None
|
||||||
|
|
||||||
# CASE1 #list : list pre-trained TTS models
|
# 1) List pre-trained TTS models
|
||||||
if args.list_models:
|
if args.list_models:
|
||||||
manager.list_models()
|
manager.list_models()
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
# CASE2 #info : model info for pre-trained TTS models
|
# 2) Info about pre-trained TTS models (without loading a model)
|
||||||
if args.model_info_by_idx:
|
if args.model_info_by_idx:
|
||||||
model_query = args.model_info_by_idx
|
model_query = args.model_info_by_idx
|
||||||
manager.model_info_by_idx(model_query)
|
manager.model_info_by_idx(model_query)
|
||||||
|
@ -355,91 +353,50 @@ def main():
|
||||||
manager.model_info_by_full_name(model_query_full_name)
|
manager.model_info_by_full_name(model_query_full_name)
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
# CASE3: load pre-trained model paths
|
# 3) Load a model for further info or TTS/VC
|
||||||
if args.model_name is not None and not args.model_path:
|
|
||||||
model_path, config_path, model_item = manager.download_model(args.model_name)
|
|
||||||
# tts model
|
|
||||||
if model_item["model_type"] == "tts_models":
|
|
||||||
tts_path = model_path
|
|
||||||
tts_config_path = config_path
|
|
||||||
if args.vocoder_name is None and "default_vocoder" in model_item:
|
|
||||||
args.vocoder_name = model_item["default_vocoder"]
|
|
||||||
|
|
||||||
# voice conversion model
|
|
||||||
if model_item["model_type"] == "voice_conversion_models":
|
|
||||||
vc_path = model_path
|
|
||||||
vc_config_path = config_path
|
|
||||||
|
|
||||||
# tts model with multiple files to be loaded from the directory path
|
|
||||||
if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
|
|
||||||
model_dir = model_path
|
|
||||||
tts_path = None
|
|
||||||
tts_config_path = None
|
|
||||||
args.vocoder_name = None
|
|
||||||
|
|
||||||
# load vocoder
|
|
||||||
if args.vocoder_name is not None and not args.vocoder_path:
|
|
||||||
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
|
|
||||||
|
|
||||||
# CASE4: set custom model paths
|
|
||||||
if args.model_path is not None:
|
|
||||||
tts_path = args.model_path
|
|
||||||
tts_config_path = args.config_path
|
|
||||||
speakers_file_path = args.speakers_file_path
|
|
||||||
language_ids_file_path = args.language_ids_file_path
|
|
||||||
|
|
||||||
if args.vocoder_path is not None:
|
|
||||||
vocoder_path = args.vocoder_path
|
|
||||||
vocoder_config_path = args.vocoder_config_path
|
|
||||||
|
|
||||||
if args.encoder_path is not None:
|
|
||||||
encoder_path = args.encoder_path
|
|
||||||
encoder_config_path = args.encoder_config_path
|
|
||||||
|
|
||||||
device = args.device
|
device = args.device
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
# A local model will take precedence if specified via model_path
|
||||||
# load models
|
model_name = args.model_name if args.model_path is None else None
|
||||||
synthesizer = Synthesizer(
|
api = TTS(
|
||||||
tts_checkpoint=tts_path,
|
model_name=model_name,
|
||||||
tts_config_path=tts_config_path,
|
model_path=args.model_path,
|
||||||
tts_speakers_file=speakers_file_path,
|
config_path=args.config_path,
|
||||||
tts_languages_file=language_ids_file_path,
|
vocoder_name=args.vocoder_name,
|
||||||
vocoder_checkpoint=vocoder_path,
|
vocoder_path=args.vocoder_path,
|
||||||
vocoder_config=vocoder_config_path,
|
vocoder_config_path=args.vocoder_config_path,
|
||||||
encoder_checkpoint=encoder_path,
|
encoder_path=args.encoder_path,
|
||||||
encoder_config=encoder_config_path,
|
encoder_config_path=args.encoder_config_path,
|
||||||
vc_checkpoint=vc_path,
|
speakers_file_path=args.speakers_file_path,
|
||||||
vc_config=vc_config_path,
|
language_ids_file_path=args.language_ids_file_path,
|
||||||
model_dir=model_dir,
|
progress_bar=args.progress_bar,
|
||||||
voice_dir=args.voice_dir,
|
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# query speaker ids of a multi-speaker model.
|
# query speaker ids of a multi-speaker model.
|
||||||
if args.list_speaker_idxs:
|
if args.list_speaker_idxs:
|
||||||
if synthesizer.tts_model.speaker_manager is None:
|
if not api.is_multi_speaker:
|
||||||
logger.info("Model only has a single speaker.")
|
logger.info("Model only has a single speaker.")
|
||||||
return
|
return
|
||||||
logger.info(
|
logger.info(
|
||||||
"Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
|
"Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
|
||||||
)
|
)
|
||||||
logger.info(list(synthesizer.tts_model.speaker_manager.name_to_id.keys()))
|
logger.info(api.speakers)
|
||||||
return
|
return
|
||||||
|
|
||||||
# query language ids of a multi-lingual model.
|
# query language ids of a multi-lingual model.
|
||||||
if args.list_language_idxs:
|
if args.list_language_idxs:
|
||||||
if synthesizer.tts_model.language_manager is None:
|
if not api.is_multi_lingual:
|
||||||
logger.info("Monolingual model.")
|
logger.info("Monolingual model.")
|
||||||
return
|
return
|
||||||
logger.info(
|
logger.info(
|
||||||
"Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
"Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||||
)
|
)
|
||||||
logger.info(synthesizer.tts_model.language_manager.name_to_id)
|
logger.info(api.languages)
|
||||||
return
|
return
|
||||||
|
|
||||||
# check the arguments against a multi-speaker model.
|
# check the arguments against a multi-speaker model.
|
||||||
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Looks like you use a multi-speaker model. Define `--speaker_idx` to "
|
"Looks like you use a multi-speaker model. Define `--speaker_idx` to "
|
||||||
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
|
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
|
||||||
|
@ -450,31 +407,29 @@ def main():
|
||||||
if args.text:
|
if args.text:
|
||||||
logger.info("Text: %s", args.text)
|
logger.info("Text: %s", args.text)
|
||||||
|
|
||||||
# kick it
|
if args.text is not None:
|
||||||
if tts_path is not None:
|
api.tts_to_file(
|
||||||
wav = synthesizer.tts(
|
text=args.text,
|
||||||
args.text,
|
speaker=args.speaker_idx,
|
||||||
speaker_name=args.speaker_idx,
|
language=args.language_idx,
|
||||||
language_name=args.language_idx,
|
|
||||||
speaker_wav=args.speaker_wav,
|
speaker_wav=args.speaker_wav,
|
||||||
|
pipe_out=pipe_out,
|
||||||
|
file_path=args.out_path,
|
||||||
reference_wav=args.reference_wav,
|
reference_wav=args.reference_wav,
|
||||||
style_wav=args.capacitron_style_wav,
|
style_wav=args.capacitron_style_wav,
|
||||||
style_text=args.capacitron_style_text,
|
style_text=args.capacitron_style_text,
|
||||||
reference_speaker_name=args.reference_speaker_idx,
|
reference_speaker_name=args.reference_speaker_idx,
|
||||||
|
voice_dir=args.voice_dir,
|
||||||
)
|
)
|
||||||
elif vc_path is not None:
|
logger.info("Saved TTS output to %s", args.out_path)
|
||||||
wav = synthesizer.voice_conversion(
|
elif args.source_wav is not None and args.target_wav is not None:
|
||||||
|
api.voice_conversion_to_file(
|
||||||
source_wav=args.source_wav,
|
source_wav=args.source_wav,
|
||||||
target_wav=args.target_wav,
|
target_wav=args.target_wav,
|
||||||
|
file_path=args.out_path,
|
||||||
|
pipe_out=pipe_out,
|
||||||
)
|
)
|
||||||
elif model_dir is not None:
|
logger.info("Saved VC output to %s", args.out_path)
|
||||||
wav = synthesizer.tts(
|
|
||||||
args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
|
|
||||||
)
|
|
||||||
|
|
||||||
# save the results
|
|
||||||
synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
|
|
||||||
logger.info("Saved output to %s", args.out_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -34,30 +34,27 @@ def run_models(offset=0, step=1):
|
||||||
# download and run the model
|
# download and run the model
|
||||||
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
||||||
language_files = glob.glob(local_download_dir + "/language*")
|
language_files = glob.glob(local_download_dir + "/language*")
|
||||||
language_id = ""
|
speaker_arg = ""
|
||||||
|
language_arg = ""
|
||||||
if len(speaker_files) > 0:
|
if len(speaker_files) > 0:
|
||||||
# multi-speaker model
|
# multi-speaker model
|
||||||
if "speaker_ids" in speaker_files[0]:
|
if "speaker_ids" in speaker_files[0]:
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
||||||
elif "speakers" in speaker_files[0]:
|
elif "speakers" in speaker_files[0]:
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
||||||
|
speakers = list(speaker_manager.name_to_id.keys())
|
||||||
# multi-lingual model - Assuming multi-lingual models are also multi-speaker
|
if len(speakers) > 1:
|
||||||
if len(language_files) > 0 and "language_ids" in language_files[0]:
|
speaker_arg = f'--speaker_idx "{speakers[0]}"'
|
||||||
language_manager = LanguageManager(language_ids_file_path=language_files[0])
|
if len(language_files) > 0 and "language_ids" in language_files[0]:
|
||||||
language_id = language_manager.language_names[0]
|
# multi-lingual model
|
||||||
|
language_manager = LanguageManager(language_ids_file_path=language_files[0])
|
||||||
speaker_id = list(speaker_manager.name_to_id.keys())[0]
|
languages = language_manager.language_names
|
||||||
run_cli(
|
if len(languages) > 1:
|
||||||
f"tts --model_name {model_name} "
|
language_arg = f'--language_idx "{languages[0]}"'
|
||||||
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --no-progress_bar'
|
run_cli(
|
||||||
)
|
f'tts --model_name {model_name} --text "This is an example." '
|
||||||
else:
|
f'--out_path "{output_path}" {speaker_arg} {language_arg} --no-progress_bar'
|
||||||
# single-speaker model
|
)
|
||||||
run_cli(
|
|
||||||
f"tts --model_name {model_name} "
|
|
||||||
f'--text "This is an example." --out_path "{output_path}" --no-progress_bar'
|
|
||||||
)
|
|
||||||
# remove downloaded models
|
# remove downloaded models
|
||||||
shutil.rmtree(local_download_dir)
|
shutil.rmtree(local_download_dir)
|
||||||
shutil.rmtree(get_user_data_dir("tts"))
|
shutil.rmtree(get_user_data_dir("tts"))
|
||||||
|
|
Loading…
Reference in New Issue