mirror of https://github.com/coqui-ai/TTS.git
Fix your_tts inference from the listed models
This commit is contained in:
parent
8100135a7e
commit
61874bc0a0
|
@ -120,3 +120,13 @@ def get_from_config_or_model_args(config, arg_name):
|
||||||
if arg_name in config.model_args:
|
if arg_name in config.model_args:
|
||||||
return config.model_args[arg_name]
|
return config.model_args[arg_name]
|
||||||
return config[arg_name]
|
return config[arg_name]
|
||||||
|
|
||||||
|
|
||||||
|
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
|
||||||
|
"""Get the given argument from `config.model_args` if exist or in `config`."""
|
||||||
|
if hasattr(config, "model_args"):
|
||||||
|
if arg_name in config.model_args:
|
||||||
|
return config.model_args[arg_name]
|
||||||
|
if hasattr(config, arg_name):
|
||||||
|
return config[arg_name]
|
||||||
|
return def_val
|
||||||
|
|
|
@ -151,6 +151,8 @@ class ModelManager(object):
|
||||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
||||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
||||||
|
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
|
||||||
|
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
|
||||||
|
|
||||||
# update the scale_path.npy file path in the model config.json
|
# update the scale_path.npy file path in the model config.json
|
||||||
self._update_path("audio.stats_path", output_stats_path, config_path)
|
self._update_path("audio.stats_path", output_stats_path, config_path)
|
||||||
|
@ -163,6 +165,12 @@ class ModelManager(object):
|
||||||
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
||||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
||||||
|
|
||||||
|
# update the speaker_encoder file path in the model config.json to the current path
|
||||||
|
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||||
|
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
|
||||||
|
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||||
|
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_path(field_name, new_path, config_path):
|
def _update_path(field_name, new_path, config_path):
|
||||||
"""Update the path in the model config.json for the current environment after download"""
|
"""Update the path in the model config.json for the current environment after download"""
|
||||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
||||||
import pysbd
|
import pysbd
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from TTS.config import check_config_and_model_args, load_config
|
from TTS.config import check_config_and_model_args, load_config, get_from_config_or_model_args_with_default
|
||||||
from TTS.tts.models import setup_model as setup_tts_model
|
from TTS.tts.models import setup_model as setup_tts_model
|
||||||
from TTS.tts.utils.languages import LanguageManager
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
@ -117,6 +117,7 @@ class Synthesizer(object):
|
||||||
|
|
||||||
speaker_manager = self._init_speaker_manager()
|
speaker_manager = self._init_speaker_manager()
|
||||||
language_manager = self._init_language_manager()
|
language_manager = self._init_language_manager()
|
||||||
|
self._set_speaker_encoder_paths_from_tts_config()
|
||||||
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
||||||
|
|
||||||
if language_manager is not None:
|
if language_manager is not None:
|
||||||
|
@ -131,6 +132,12 @@ class Synthesizer(object):
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.tts_model.cuda()
|
self.tts_model.cuda()
|
||||||
|
|
||||||
|
def _set_speaker_encoder_paths_from_tts_config(self):
|
||||||
|
"""Set the encoder paths from the tts model config for models with speaker encoders."""
|
||||||
|
if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "speaker_encoder_config_path"):
|
||||||
|
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
|
||||||
|
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
|
||||||
|
|
||||||
def _is_use_speaker_embedding(self):
|
def _is_use_speaker_embedding(self):
|
||||||
"""Check if the speaker embedding is used in the model"""
|
"""Check if the speaker embedding is used in the model"""
|
||||||
# we handle here the case that some models use model_args some don't
|
# we handle here the case that some models use model_args some don't
|
||||||
|
@ -155,17 +162,19 @@ class Synthesizer(object):
|
||||||
"""Initialize the SpeakerManager"""
|
"""Initialize the SpeakerManager"""
|
||||||
# setup if multi-speaker settings are in the global model config
|
# setup if multi-speaker settings are in the global model config
|
||||||
speaker_manager = None
|
speaker_manager = None
|
||||||
|
speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
|
||||||
if self._is_use_speaker_embedding():
|
if self._is_use_speaker_embedding():
|
||||||
if self.tts_speakers_file:
|
if self.tts_speakers_file:
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
||||||
if self.tts_config.get("speakers_file", None):
|
if speakers_file:
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
|
speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
|
||||||
|
|
||||||
if self._is_use_d_vector_file():
|
if self._is_use_d_vector_file():
|
||||||
|
d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
|
||||||
if self.tts_speakers_file:
|
if self.tts_speakers_file:
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
||||||
if self.tts_config.get("d_vector_file", None):
|
if d_vector_file:
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
def _init_speaker_encoder(self, speaker_manager):
|
def _init_speaker_encoder(self, speaker_manager):
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
|
|
||||||
from tests import get_tests_output_path, run_cli
|
from tests import get_tests_output_path, run_cli
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
@ -22,16 +23,24 @@ def test_run_all_models():
|
||||||
local_download_dir = os.path.dirname(model_path)
|
local_download_dir = os.path.dirname(model_path)
|
||||||
# download and run the model
|
# download and run the model
|
||||||
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
speaker_files = glob.glob(local_download_dir + "/speaker*")
|
||||||
|
language_files = glob.glob(local_download_dir + "/language*")
|
||||||
|
language_id = ""
|
||||||
if len(speaker_files) > 0:
|
if len(speaker_files) > 0:
|
||||||
# multi-speaker model
|
# multi-speaker model
|
||||||
if "speaker_ids" in speaker_files[0]:
|
if "speaker_ids" in speaker_files[0]:
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
|
||||||
elif "speakers" in speaker_files[0]:
|
elif "speakers" in speaker_files[0]:
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
|
||||||
|
|
||||||
|
# multi-lingual model - Assuming multi-lingual models are also multi-speaker
|
||||||
|
if len(language_files) > 0 and "language_ids" in language_files[0]:
|
||||||
|
language_manager = LanguageManager(language_ids_file_path=language_files[0])
|
||||||
|
language_id = language_manager.language_names[0]
|
||||||
|
|
||||||
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
|
speaker_id = list(speaker_manager.speaker_ids.keys())[0]
|
||||||
run_cli(
|
run_cli(
|
||||||
f"tts --model_name {model_name} "
|
f"tts --model_name {model_name} "
|
||||||
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"'
|
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# single-speaker model
|
# single-speaker model
|
||||||
|
|
Loading…
Reference in New Issue