Make cloning configurable

This commit is contained in:
Eren G??lge 2023-11-04 13:41:03 +01:00
parent c182535e2a
commit b1b6876489
1 changed files with 12 additions and 4 deletions

View File

@ -25,7 +25,7 @@ init_stream_support()
def wav_to_mel_cloning(
wav,
mel_norms_file="../experiments/clips_mel_norms.pth",
mel_norms=None,
mel_norms=None,
device=torch.device("cpu"),
n_fft=4096,
hop_length=1024,
@ -189,7 +189,8 @@ class XttsArgs(Coqpit):
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True.
use_ne_hifigan (bool, optional): Whether to use regular hifigan or diffusion + univnet as a decoder. Defaults to False.
For GPT model:
gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@ -373,6 +374,7 @@ class Xtts(BaseTTS):
Args:
audio_path (str): Path to the audio file.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3.
"""
@ -505,6 +507,9 @@ class Xtts(BaseTTS):
"diffusion_temperature": config.diffusion_temperature,
"decoder_iterations": config.decoder_iterations,
"decoder_sampler": config.decoder_sampler,
"gpt_cond_len": config.gpt_cond_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.full_inference(text, ref_audio_path, language, **settings)
@ -521,8 +526,11 @@ class Xtts(BaseTTS):
repetition_penalty=2.0,
top_k=50,
top_p=0.85,
gpt_cond_len=6,
do_sample=True,
# Cloning
gpt_cond_len=6,
max_ref_len=10,
sound_norm_refs=False,
# Decoder inference
decoder_iterations=100,
cond_free=True,
@ -590,7 +598,7 @@ class Xtts(BaseTTS):
Sample rate is 24kHz.
"""
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len, max_ref_length=max_ref_len, sound_norm_refs=sound_norm_refs
)
return self.inference(
text,