mirror of https://github.com/coqui-ai/TTS.git
Merge 3ebc28608c
into dbf1a08a0d
This commit is contained in:
commit
d656b7dc20
|
@ -113,15 +113,15 @@ synthesizer = Synthesizer(
|
||||||
use_cuda=args.use_cuda,
|
use_cuda=args.use_cuda,
|
||||||
)
|
)
|
||||||
|
|
||||||
use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
|
|
||||||
synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
|
|
||||||
)
|
|
||||||
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
|
speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
|
||||||
|
use_multi_speaker = (hasattr(synthesizer.tts_model, "num_speakers") and (
|
||||||
|
synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
|
||||||
|
)) or (speaker_manager is not None)
|
||||||
|
|
||||||
use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and (
|
|
||||||
synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
|
|
||||||
)
|
|
||||||
language_manager = getattr(synthesizer.tts_model, "language_manager", None)
|
language_manager = getattr(synthesizer.tts_model, "language_manager", None)
|
||||||
|
use_multi_language = (hasattr(synthesizer.tts_model, "num_languages") and (
|
||||||
|
synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
|
||||||
|
)) or (language_manager is not None)
|
||||||
|
|
||||||
# TODO: set this from SpeakerManager
|
# TODO: set this from SpeakerManager
|
||||||
use_gst = synthesizer.tts_config.get("use_gst", False)
|
use_gst = synthesizer.tts_config.get("use_gst", False)
|
||||||
|
|
|
@ -731,8 +731,8 @@ class Xtts(BaseTTS):
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
checkpoint_dir=None,
|
|
||||||
checkpoint_path=None,
|
checkpoint_path=None,
|
||||||
|
checkpoint_dir=None,
|
||||||
vocab_path=None,
|
vocab_path=None,
|
||||||
eval=True,
|
eval=True,
|
||||||
strict=True,
|
strict=True,
|
||||||
|
@ -744,8 +744,8 @@ class Xtts(BaseTTS):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (dict): The configuration dictionary for the model.
|
config (dict): The configuration dictionary for the model.
|
||||||
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
|
||||||
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
|
||||||
|
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
|
||||||
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
|
||||||
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
|
||||||
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
|
||||||
|
@ -753,7 +753,8 @@ class Xtts(BaseTTS):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
if checkpoint_dir is None and checkpoint_path:
|
||||||
|
checkpoint_dir = os.path.dirname(checkpoint_path)
|
||||||
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
|
||||||
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
|
||||||
|
|
||||||
|
|
|
@ -414,7 +414,7 @@ class ModelManager(object):
|
||||||
output_model_path = output_path
|
output_model_path = output_path
|
||||||
output_config_path = None
|
output_config_path = None
|
||||||
if (
|
if (
|
||||||
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name
|
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts1" not in model_name
|
||||||
): # TODO:This is stupid but don't care for now.
|
): # TODO:This is stupid but don't care for now.
|
||||||
output_model_path, output_config_path = self._find_files(output_path)
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
|
|
Loading…
Reference in New Issue