Fix your_tts inference from the listed models

This commit is contained in:
Eren Gölge 2021-12-31 13:45:05 +00:00
parent 8100135a7e
commit 61874bc0a0
4 changed files with 42 additions and 6 deletions

View File

@ -120,3 +120,13 @@ def get_from_config_or_model_args(config, arg_name):
if arg_name in config.model_args: if arg_name in config.model_args:
return config.model_args[arg_name] return config.model_args[arg_name]
return config[arg_name] return config[arg_name]
def get_from_config_or_model_args_with_default(config, arg_name, def_val):
"""Get the given argument from `config.model_args` if exist or in `config`."""
if hasattr(config, "model_args"):
if arg_name in config.model_args:
return config.model_args[arg_name]
if hasattr(config, arg_name):
return config[arg_name]
return def_val

View File

@ -151,6 +151,8 @@ class ModelManager(object):
output_stats_path = os.path.join(output_path, "scale_stats.npy") output_stats_path = os.path.join(output_path, "scale_stats.npy")
output_d_vector_file_path = os.path.join(output_path, "speakers.json") output_d_vector_file_path = os.path.join(output_path, "speakers.json")
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
speaker_encoder_model_path = os.path.join(output_path, "model_se.pth.tar")
# update the scale_path.npy file path in the model config.json # update the scale_path.npy file path in the model config.json
self._update_path("audio.stats_path", output_stats_path, config_path) self._update_path("audio.stats_path", output_stats_path, config_path)
@ -163,6 +165,12 @@ class ModelManager(object):
self._update_path("speakers_file", output_speaker_ids_file_path, config_path) self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
# update the speaker_encoder file path in the model config.json to the current path
self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)
@staticmethod @staticmethod
def _update_path(field_name, new_path, config_path): def _update_path(field_name, new_path, config_path):
"""Update the path in the model config.json for the current environment after download""" """Update the path in the model config.json for the current environment after download"""

View File

@ -5,7 +5,7 @@ import numpy as np
import pysbd import pysbd
import torch import torch
from TTS.config import check_config_and_model_args, load_config from TTS.config import check_config_and_model_args, load_config, get_from_config_or_model_args_with_default
from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
@ -117,6 +117,7 @@ class Synthesizer(object):
speaker_manager = self._init_speaker_manager() speaker_manager = self._init_speaker_manager()
language_manager = self._init_language_manager() language_manager = self._init_language_manager()
self._set_speaker_encoder_paths_from_tts_config()
speaker_manager = self._init_speaker_encoder(speaker_manager) speaker_manager = self._init_speaker_encoder(speaker_manager)
if language_manager is not None: if language_manager is not None:
@ -131,6 +132,12 @@ class Synthesizer(object):
if use_cuda: if use_cuda:
self.tts_model.cuda() self.tts_model.cuda()
def _set_speaker_encoder_paths_from_tts_config(self):
"""Set the encoder paths from the tts model config for models with speaker encoders."""
if hasattr(self.tts_config, "model_args") and hasattr(self.tts_config.model_args, "speaker_encoder_config_path"):
self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path
self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path
def _is_use_speaker_embedding(self): def _is_use_speaker_embedding(self):
"""Check if the speaker embedding is used in the model""" """Check if the speaker embedding is used in the model"""
# we handle here the case that some models use model_args some don't # we handle here the case that some models use model_args some don't
@ -155,17 +162,19 @@ class Synthesizer(object):
"""Initialize the SpeakerManager""" """Initialize the SpeakerManager"""
# setup if multi-speaker settings are in the global model config # setup if multi-speaker settings are in the global model config
speaker_manager = None speaker_manager = None
speakers_file = get_from_config_or_model_args_with_default(self.tts_config, "speakers_file", None)
if self._is_use_speaker_embedding(): if self._is_use_speaker_embedding():
if self.tts_speakers_file: if self.tts_speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
if self.tts_config.get("speakers_file", None): if speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file) speaker_manager = SpeakerManager(speaker_id_file_path=speakers_file)
if self._is_use_d_vector_file(): if self._is_use_d_vector_file():
d_vector_file = get_from_config_or_model_args_with_default(self.tts_config, "d_vector_file", None)
if self.tts_speakers_file: if self.tts_speakers_file:
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
if self.tts_config.get("d_vector_file", None): if d_vector_file:
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) speaker_manager = SpeakerManager(d_vectors_file_path=d_vector_file)
return speaker_manager return speaker_manager
def _init_speaker_encoder(self, speaker_manager): def _init_speaker_encoder(self, speaker_manager):

View File

@ -2,6 +2,7 @@
import glob import glob
import os import os
import shutil import shutil
from TTS.tts.utils.languages import LanguageManager
from tests import get_tests_output_path, run_cli from tests import get_tests_output_path, run_cli
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
@ -22,16 +23,24 @@ def test_run_all_models():
local_download_dir = os.path.dirname(model_path) local_download_dir = os.path.dirname(model_path)
# download and run the model # download and run the model
speaker_files = glob.glob(local_download_dir + "/speaker*") speaker_files = glob.glob(local_download_dir + "/speaker*")
language_files = glob.glob(local_download_dir + "/language*")
language_id = ""
if len(speaker_files) > 0: if len(speaker_files) > 0:
# multi-speaker model # multi-speaker model
if "speaker_ids" in speaker_files[0]: if "speaker_ids" in speaker_files[0]:
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0]) speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
elif "speakers" in speaker_files[0]: elif "speakers" in speaker_files[0]:
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0]) speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
# multi-lingual model - Assuming multi-lingual models are also multi-speaker
if len(language_files) > 0 and "language_ids" in language_files[0]:
language_manager = LanguageManager(language_ids_file_path=language_files[0])
language_id = language_manager.language_names[0]
speaker_id = list(speaker_manager.speaker_ids.keys())[0] speaker_id = list(speaker_manager.speaker_ids.keys())[0]
run_cli( run_cli(
f"tts --model_name {model_name} " f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}"' f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
) )
else: else:
# single-speaker model # single-speaker model