Fix multi-speaker init in Synthesizer

This commit is contained in:
Eren Gölge 2021-12-21 09:44:07 +00:00
parent f769595112
commit c9c1fa0548
2 changed files with 16 additions and 16 deletions

View File

@ -101,7 +101,8 @@ def check_config_and_model_args(config, arg_name, value):
"""Check the give argument in `config.model_args` if exist or in `config` for """Check the give argument in `config.model_args` if exist or in `config` for
the given value. the given value.
It is to patch up the compatibility between models with and without `model_args`. Return False if the argument does not exist in `config.model_args` or `config`.
This is to patch up the compatibility between models with and without `model_args`.
TODO: Remove this in the future with a unified approach. TODO: Remove this in the future with a unified approach.
""" """
@ -110,7 +111,7 @@ def check_config_and_model_args(config, arg_name, value):
return config.model_args[arg_name] == value return config.model_args[arg_name] == value
if hasattr(config, arg_name): if hasattr(config, arg_name):
return config[arg_name] == value return config[arg_name] == value
raise ValueError(f" [!] {arg_name} is not found in config or config.model_args") return False
def get_from_config_or_model_args(config, arg_name): def get_from_config_or_model_args(config, arg_name):

View File

@ -5,7 +5,7 @@ import numpy as np
import pysbd import pysbd
import torch import torch
from TTS.config import load_config from TTS.config import check_config_and_model_args, load_config
from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
@ -133,21 +133,23 @@ class Synthesizer(object):
def _is_use_speaker_embedding(self): def _is_use_speaker_embedding(self):
"""Check if the speaker embedding is used in the model""" """Check if the speaker embedding is used in the model"""
# some models use model_args some don't # we handle here the case that some models use model_args some don't
use_speaker_embedding = False
if hasattr(self.tts_config, "model_args"): if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args use_speaker_embedding = self.tts_config["model_args"].get("use_speaker_embedding", False)
else: use_speaker_embedding = use_speaker_embedding or self.tts_config.get("use_speaker_embedding", False)
config = self.tts_config return use_speaker_embedding
return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True
def _is_use_d_vector_file(self): def _is_use_d_vector_file(self):
"""Check if the d-vector file is used in the model""" """Check if the d-vector file is used in the model"""
# some models use model_args some don't # we handle here the case that some models use model_args some don't
use_d_vector_file = False
if hasattr(self.tts_config, "model_args"): if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args config = self.tts_config.model_args
else: use_d_vector_file = config.get("use_d_vector_file", False)
config = self.tts_config config = self.tts_config
return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True use_d_vector_file = use_d_vector_file or config.get("use_d_vector_file", False)
return use_d_vector_file
def _init_speaker_manager(self): def _init_speaker_manager(self):
"""Initialize the SpeakerManager""" """Initialize the SpeakerManager"""
@ -176,10 +178,7 @@ class Synthesizer(object):
"""Initialize the LanguageManager""" """Initialize the LanguageManager"""
# setup if multi-lingual settings are in the global model config # setup if multi-lingual settings are in the global model config
language_manager = None language_manager = None
if ( if check_config_and_model_args(self.tts_config, "use_language_embedding", True):
hasattr(self.tts_config.model_args, "use_language_embedding")
and self.tts_config.model_args.use_language_embedding is True
):
if self.tts_languages_file: if self.tts_languages_file:
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file) language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
elif self.tts_config.get("language_ids_file", None): elif self.tts_config.get("language_ids_file", None):