From 61874bc0a0c703df56b55e51b9d67784575f93f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 31 Dec 2021 13:45:05 +0000 Subject: [PATCH] Fix your_tts inference from the listed models --- TTS/config/__init__.py | 10 ++++++++++ TTS/utils/manage.py | 8 ++++++++ TTS/utils/synthesizer.py | 19 ++++++++++++++----- tests/zoo_tests/test_models.py | 11 ++++++++++- 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py index 8ed3578f..5c905295 100644 --- a/TTS/config/__init__.py +++ b/TTS/config/__init__.py @@ -120,3 +120,13 @@ def get_from_config_or_model_args(config, arg_name): if arg_name in config.model_args: return config.model_args[arg_name] return config[arg_name] + + +def get_from_config_or_model_args_with_default(config, arg_name, def_val): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + if hasattr(config, arg_name): + return config[arg_name] + return def_val diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index d1dedbe0..7ad596f0 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -151,6 +151,8 @@ class ModelManager(object): output_stats_path = os.path.join(output_path, "scale_stats.npy") output_d_vector_file_path = os.path.join(output_path, "speakers.json") output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar") # update the scale_path.npy file path in the model config.json self._update_path("audio.stats_path", output_stats_path, config_path) @@ -163,6 +165,12 @@ class ModelManager(object): self._update_path("speakers_file", output_speaker_ids_file_path, config_path) self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + # update the speaker_encoder file path in the model config.json to the current path + self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + @staticmethod def _update_path(field_name, new_path, config_path): """Update the path in the model config.json for the current environment after download""" diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 7a2d3097..66579a1b 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,7 +5,7 @@ import numpy as np import pysbd import torch -from TTS.config import check_config_and_model_args, load_config +from TTS.config import check_config_and_model_args, load_config, get_from_config_or_model_args_with_default from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager @@ -117,6 +117,7 @@ class Synthesizer(object): speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() + self._set_speaker_encoder_paths_from_tts_config() speaker_manager = self._init_speaker_encoder(speaker_manager) if language_manager is not None: @@ -131,6 +132,12 @@ class Synthesizer(object): if use_cuda: self.tts_model.cuda() + def _set_speaker_encoder_paths_from_tts_config(self): + """Set the encoder paths from the tts model config for models with speaker encoders.""" + if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "speaker_encoder_config_path"): + self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path + self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path + def _is_use_speaker_embedding(self): """Check if the speaker embedding is used in the model""" # we handle here the case that some models use model_args some don't @@ -155,17 +162,19 @@ class Synthesizer(object): """Initialize the SpeakerManager""" # setup if multi-speaker settings are in the global model config speaker_manager = None + speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None) if self._is_use_speaker_embedding(): if self.tts_speakers_file: speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) - if self.tts_config.get("speakers_file", None): - speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file) + if speakers_file: + speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file) if self._is_use_d_vector_file(): + d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None) if self.tts_speakers_file: speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) - if self.tts_config.get("d_vector_file", None): - speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) + if d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file) return speaker_manager def _init_speaker_encoder(self, speaker_manager): diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index 886d1bb6..43273572 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -2,6 +2,7 @@ import glob import os import shutil +from TTS.tts.utils.languages import LanguageManager from tests import get_tests_output_path, run_cli from TTS.tts.utils.speakers import SpeakerManager @@ -22,16 +23,24 @@ def test_run_all_models(): local_download_dir = os.path.dirname(model_path) # download and run the model speaker_files = glob.glob(local_download_dir + "/speaker*") + language_files = glob.glob(local_download_dir + "/language*") + language_id = "" if len(speaker_files) > 0: # multi-speaker model if "speaker_ids" in speaker_files[0]: speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) elif "speakers" in speaker_files[0]: speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) + + # multi-lingual model - Assuming multi-lingual models are also multi-speaker + if len(language_files) > 0 and "language_ids" in language_files[0]: + language_manager = LanguageManager(language_ids_file_path=language_files[0]) + language_id = language_manager.language_names[0] + speaker_id = list(speaker_manager.speaker_ids.keys())[0] run_cli( f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"' + f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" ' ) else: # single-speaker model