diff --git a/TTS/server/server.py b/TTS/server/server.py index 6b2141a9..66b7dcb2 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -113,15 +113,15 @@ synthesizer = Synthesizer( use_cuda=args.use_cuda, ) -use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and ( - synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None -) speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) +use_multi_speaker = (hasattr(synthesizer.tts_model, "num_speakers") and ( + synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None +)) or (speaker_manager is not None) -use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and ( - synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None -) language_manager = getattr(synthesizer.tts_model, "language_manager", None) +use_multi_language = (hasattr(synthesizer.tts_model, "num_languages") and ( + synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None +)) or (language_manager is not None) # TODO: set this from SpeakerManager use_gst = synthesizer.tts_config.get("use_gst", False) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 8e9d6bd3..6b4cdb8a 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -731,8 +731,8 @@ class Xtts(BaseTTS): def load_checkpoint( self, config, - checkpoint_dir=None, checkpoint_path=None, + checkpoint_dir=None, vocab_path=None, eval=True, strict=True, @@ -744,8 +744,8 @@ class Xtts(BaseTTS): Args: config (dict): The configuration dictionary for the model. - checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. + checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. vocab_path (str, optional): The path to the vocabulary file. Defaults to None. eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. @@ -753,7 +753,8 @@ class Xtts(BaseTTS): Returns: None """ - + if checkpoint_dir is None and checkpoint_path: + checkpoint_dir = os.path.dirname(checkpoint_path) model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 3a527f46..e8c009df 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -414,7 +414,7 @@ class ModelManager(object): output_model_path = output_path output_config_path = None if ( - model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name + model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts1" not in model_name ): # TODO:This is stupid but don't care for now. output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json