small fixes

This commit is contained in:
WeberJulian 2023-10-06 14:40:42 +02:00
parent be51205607
commit a097541ed4
2 changed files with 4 additions and 5 deletions

View File

@ -593,7 +593,6 @@ class Xtts(BaseTTS):
cond_free_k=2, cond_free_k=2,
diffusion_temperature=1.0, diffusion_temperature=1.0,
decoder_sampler="ddim", decoder_sampler="ddim",
use_hifigan=True,
**hf_generate_kwargs, **hf_generate_kwargs,
): ):
text = f"[{language}]{text.strip().lower()}" text = f"[{language}]{text.strip().lower()}"
@ -765,7 +764,7 @@ class Xtts(BaseTTS):
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None, checkpoint_path=None,
vocab_path=None, vocab_path=None,
eval=False, eval=True,
strict=True, strict=True,
use_deepspeed=False, use_deepspeed=False,
): ):
@ -777,7 +776,7 @@ class Xtts(BaseTTS):
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None. vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False. eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
Returns: Returns:

View File

@ -71,7 +71,7 @@ print("Loading model...")
config = XttsConfig() config = XttsConfig()
config.load_json("/path/to/xtts/config.json") config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config) model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True, use_deepspeed=True) model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
model.cuda() model.cuda()
print("Computing speaker latents...") print("Computing speaker latents...")
@ -108,7 +108,7 @@ print("Loading model...")
config = XttsConfig() config = XttsConfig()
config.load_json("/path/to/xtts/config.json") config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config) model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True, use_deepspeed=True) model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
model.cuda() model.cuda()
print("Computing speaker latents...") print("Computing speaker latents...")