Make cloning configurable

This commit is contained in:
Eren G??lge 2023-11-04 13:41:03 +01:00
parent c182535e2a
commit b1b6876489
1 changed files with 12 additions and 4 deletions

View File

@ -189,7 +189,8 @@ class XttsArgs(Coqpit):
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True. use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False.
For GPT model: For GPT model:
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@ -373,6 +374,7 @@ class Xtts(BaseTTS):
Args: Args:
audio_path (str): Path to the audio file. audio_path (str): Path to the audio file.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3. length (int): Length of the audio in seconds. Defaults to 3.
""" """
@ -505,6 +507,9 @@ class Xtts(BaseTTS):
"diffusion_temperature": config.diffusion_temperature, "diffusion_temperature": config.diffusion_temperature,
"decoder_iterations": config.decoder_iterations, "decoder_iterations": config.decoder_iterations,
"decoder_sampler": config.decoder_sampler, "decoder_sampler": config.decoder_sampler,
"gpt_cond_len": config.gpt_cond_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
} }
settings.update(kwargs) # allow overriding of preset settings with kwargs settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.full_inference(text, ref_audio_path, language, **settings) return self.full_inference(text, ref_audio_path, language, **settings)
@ -521,8 +526,11 @@ class Xtts(BaseTTS):
repetition_penalty=2.0, repetition_penalty=2.0,
top_k=50, top_k=50,
top_p=0.85, top_p=0.85,
gpt_cond_len=6,
do_sample=True, do_sample=True,
# Cloning
gpt_cond_len=6,
max_ref_len=10,
sound_norm_refs=False,
# Decoder inference # Decoder inference
decoder_iterations=100, decoder_iterations=100,
cond_free=True, cond_free=True,
@ -590,7 +598,7 @@ class Xtts(BaseTTS):
Sample rate is 24kHz. Sample rate is 24kHz.
""" """
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents( (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len, max_ref_length=max_ref_len, sound_norm_refs=sound_norm_refs
) )
return self.inference( return self.inference(
text, text,