mirror of https://github.com/coqui-ai/TTS.git
Fix your_tts inference from the listed models
This commit is contained in:
parent
8100135a7e
commit
61874bc0a0
|
@ -120,3 +120,13 @@ def get_from_config_or_model_args(config, arg_name):
|
|||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name]
|
||||
return config[arg_name]
|
||||
|
||||
|
||||
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
|
||||
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||
if hasattr(config, "model_args"):
|
||||
if arg_name in config.model_args:
|
||||
return config.model_args[arg_name]
|
||||
if hasattr(config, arg_name):
|
||||
return config[arg_name]
|
||||
return def_val
|
||||
|
|
|
@ -151,6 +151,8 @@ class ModelManager(object):
|
|||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
||||
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
|
||||
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
|
||||
|
||||
# update the scale_path.npy file path in the model config.json
|
||||
self._update_path("audio.stats_path", output_stats_path, config_path)
|
||||
|
@ -163,6 +165,12 @@ class ModelManager(object):
|
|||
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
||||
|
||||
# update the speaker_encoder file path in the model config.json to the current path
|
||||
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||
|
||||
@staticmethod
|
||||
def _update_path(field_name, new_path, config_path):
|
||||
"""Update the path in the model config.json for the current environment after download"""
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
import pysbd
|
||||
import torch
|
||||
|
||||
from TTS.config import check_config_and_model_args, load_config
|
||||
from TTS.config import check_config_and_model_args, load_config, get_from_config_or_model_args_with_default
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
@ -117,6 +117,7 @@ class Synthesizer(object):
|
|||
|
||||
speaker_manager = self._init_speaker_manager()
|
||||
language_manager = self._init_language_manager()
|
||||
self._set_speaker_encoder_paths_from_tts_config()
|
||||
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
||||
|
||||
if language_manager is not None:
|
||||
|
@ -131,6 +132,12 @@ class Synthesizer(object):
|
|||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
||||
def _set_speaker_encoder_paths_from_tts_config(self):
|
||||
"""Set the encoder paths from the tts model config for models with speaker encoders."""
|
||||
if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "speaker_encoder_config_path"):
|
||||
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
|
||||
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
|
||||
|
||||
def _is_use_speaker_embedding(self):
|
||||
"""Check if the speaker embedding is used in the model"""
|
||||
# we handle here the case that some models use model_args some don't
|
||||
|
@ -155,17 +162,19 @@ class Synthesizer(object):
|
|||
"""Initialize the SpeakerManager"""
|
||||
# setup if multi-speaker settings are in the global model config
|
||||
speaker_manager = None
|
||||
speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
|
||||
if self._is_use_speaker_embedding():
|
||||
if self.tts_speakers_file:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
||||
if self.tts_config.get("speakers_file", None):
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
|
||||
if speakers_file:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
|
||||
|
||||
if self._is_use_d_vector_file():
|
||||
d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
|
||||
if self.tts_speakers_file:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
||||
if self.tts_config.get("d_vector_file", None):
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
||||
if d_vector_file:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
|
||||
return speaker_manager
|
||||
|
||||
def _init_speaker_encoder(self, speaker_manager):
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
|
||||
from tests import get_tests_output_path, run_cli
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
|
@ -22,16 +23,24 @@ def test_run_all_models():
|
|||
local_download_dir = os.path.dirname(model_path)
|
||||
# download and run the model
|
||||
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
||||
language_files = glob.glob(local_download_dir + "/language*")
|
||||
language_id = ""
|
||||
if len(speaker_files) > 0:
|
||||
# multi-speaker model
|
||||
if "speaker_ids" in speaker_files[0]:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
||||
elif "speakers" in speaker_files[0]:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
||||
|
||||
# multi-lingual model - Assuming multi-lingual models are also multi-speaker
|
||||
if len(language_files) > 0 and "language_ids" in language_files[0]:
|
||||
language_manager = LanguageManager(language_ids_file_path=language_files[0])
|
||||
language_id = language_manager.language_names[0]
|
||||
|
||||
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
|
||||
run_cli(
|
||||
f"tts --model_name {model_name} "
|
||||
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
|
||||
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
|
||||
)
|
||||
else:
|
||||
# single-speaker model
|
||||
|
|
Loading…
Reference in New Issue