This commit is contained in:
bivashy 2024-02-10 15:55:55 +01:00 committed by GitHub
commit d656b7dc20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 10 deletions

View File

@ -113,15 +113,15 @@ synthesizer = Synthesizer(
use_cuda=args.use_cuda,
)
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
)
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
use_multi_speaker = (hasattr(synthesizer.tts_model, "num_speakers") and (
synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
)) or (speaker_manager is not None)
use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and (
synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
)
language_manager = getattr(synthesizer.tts_model, "language_manager", None)
use_multi_language = (hasattr(synthesizer.tts_model, "num_languages") and (
synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
)) or (language_manager is not None)
# TODO: set this from SpeakerManager
use_gst = synthesizer.tts_config.get("use_gst", False)

View File

@ -731,8 +731,8 @@ class Xtts(BaseTTS):
def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
checkpoint_dir=None,
vocab_path=None,
eval=True,
strict=True,
@ -744,8 +744,8 @@ class Xtts(BaseTTS):
Args:
config (dict): The configuration dictionary for the model.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
@ -753,7 +753,8 @@ class Xtts(BaseTTS):
Returns:
None
"""
if checkpoint_dir is None and checkpoint_path:
checkpoint_dir = os.path.dirname(checkpoint_path)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")

View File

@ -414,7 +414,7 @@ class ModelManager(object):
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts1" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json