Add lang code in XTTS doc (#3158)

* Add lang code in XTTS doc

* Remove unused config and args

* update docs

* woops
Julian Weber 2023-11-08 13:47:33 +01:00 committed by GitHub
parent 78a596618a
commit 03ad90135b
3 changed files with 4 additions and 66 deletions


@@ -37,29 +37,11 @@ class XttsConfig(BaseTTSConfig):
 If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
 Defaults to `0.8`.
-cond_free_k (float):
-    Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
-    As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
-    Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k. Defaults to `2.0`.
-diffusion_temperature (float):
-    Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
-    are the "mean" prediction of the diffusion network and will sound bland and smeared.
-    Defaults to `1.0`.
 num_gpt_outputs (int):
 Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
 As XTTS is a probabilistic model, more samples means a higher probability of creating something "great".
 Defaults to `16`.
-decoder_iterations (int):
-    Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
-    the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
-    however. Defaults to `30`.
-decoder_sampler (str):
-    Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
 gpt_cond_len (int):
 Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
@@ -110,11 +92,7 @@ class XttsConfig(BaseTTSConfig):
 repetition_penalty: float = 2.0
 top_k: int = 50
 top_p: float = 0.85
-cond_free_k: float = 2.0
-diffusion_temperature: float = 1.0
 num_gpt_outputs: int = 1
-decoder_iterations: int = 30
-decoder_sampler: str = "ddim"
 # cloning
 gpt_cond_len: int = 3
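
Taken together, the two hunks above drop every diffusion-related knob from `XttsConfig` and keep only the GPT sampling and cloning parameters. As a rough, hypothetical sketch of what building the trimmed config could look like after this commit (field names and defaults are taken from the diff; the import path is an assumption based on the usual Coqui TTS layout and is not part of this change):

```python
# Hedged sketch, not part of this commit: construct XttsConfig with only the
# fields that survive the cleanup. The import path is an assumption.
from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig(
    repetition_penalty=2.0,  # penalty applied to repeated GPT tokens
    top_k=50,                # keep only the 50 most probable tokens per step
    top_p=0.85,              # nucleus-sampling cutoff
    num_gpt_outputs=1,       # autoregressive samples filtered with CLVP
    gpt_cond_len=3,          # seconds of reference audio used as GPT conditioning
)

# The removed diffusion fields are no longer attributes of the config.
for name in ("cond_free_k", "diffusion_temperature", "decoder_iterations", "decoder_sampler"):
    assert not hasattr(config, name)
```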


@@ -152,19 +152,6 @@ class XttsArgs(Coqpit):
 gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
 gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
 gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
-For DiffTTS model:
-    diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
-    diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10.
-    diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100.
-    diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200.
-    diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024.
-    diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193.
-    diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0.
-    diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False.
-    diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16.
-    diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0.
-    diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0.
 """
 gpt_batch_size: int = 1
@@ -193,19 +180,6 @@ class XttsArgs(Coqpit):
 gpt_use_masking_gt_prompt_approach: bool = True
 gpt_use_perceiver_resampler: bool = False
-# Diffusion Decoder params
-diff_model_channels: int = 1024
-diff_num_layers: int = 10
-diff_in_channels: int = 100
-diff_out_channels: int = 200
-diff_in_latent_channels: int = 1024
-diff_in_tokens: int = 8193
-diff_dropout: int = 0
-diff_use_fp16: bool = False
-diff_num_heads: int = 16
-diff_layer_drop: int = 0
-diff_unconditioned_percentage: int = 0
 # HifiGAN Decoder params
 input_sample_rate: int = 22050
 output_sample_rate: int = 24000
@@ -426,10 +400,6 @@ class Xtts(BaseTTS):
 "repetition_penalty": config.repetition_penalty,
 "top_k": config.top_k,
 "top_p": config.top_p,
-"cond_free_k": config.cond_free_k,
-"diffusion_temperature": config.diffusion_temperature,
-"decoder_iterations": config.decoder_iterations,
-"decoder_sampler": config.decoder_sampler,
 "gpt_cond_len": config.gpt_cond_len,
 "max_ref_len": config.max_ref_len,
 "sound_norm_refs": config.sound_norm_refs,
@@ -454,13 +424,6 @@ class Xtts(BaseTTS):
 gpt_cond_len=6,
 max_ref_len=10,
 sound_norm_refs=False,
-# Decoder inference
-decoder_iterations=100,
-cond_free=True,
-cond_free_k=2,
-diffusion_temperature=1.0,
-decoder_sampler="ddim",
-decoder="hifigan",
 **hf_generate_kwargs,
 ):
 """


@@ -24,8 +24,7 @@ a few tricks to make it faster and support streaming inference.
 Current implementation only supports inference.
 ### Languages
-As of now, XTTS-v2 supports 16 languages: English, Spanish, French, German, Italian, Portuguese,
-Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified), Japanese, Hungarian, Korean
+As of now, XTTS-v2 supports 16 languages: English (en), Spanish (es), French (fr), German (de), Italian (it), Portuguese (pt), Polish (pl), Turkish (tr), Russian (ru), Dutch (nl), Czech (cs), Arabic (ar), Chinese (zh-cn), Japanese (ja), Hungarian (hu) and Korean (ko).
 Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
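
With the docs now spelling out the code next to each language, a reader can pass that code straight to the `language` argument. A minimal, hypothetical usage sketch with the high-level API (the model name and API call follow the usual Coqui TTS conventions and are assumptions, not part of this diff):

```python
# Hedged sketch: pass one of the codes listed above as the `language` argument.
# Model name and API usage are assumptions based on the standard Coqui TTS interface.
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
tts.tts_to_file(
    text="Bonjour, ceci est un test de synthèse vocale.",
    speaker_wav="reference.wav",   # short reference clip for voice cloning
    language="fr",                 # language code from the list above ("fr" = French)
    file_path="output_fr.wav",
)
```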
@@ -116,7 +115,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True
 model.cuda()
 print("Computing speaker latents...")
-gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
+gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
 print("Inference...")
 out = model.inference(
@@ -124,7 +123,6 @@ out = model.inference(
 "en",
 gpt_cond_latent,
 speaker_embedding,
-diffusion_conditioning,
 temperature=0.7, # Add custom parameters here
 )
 torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
@@ -153,7 +151,7 @@ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True
 model.cuda()
 print("Computing speaker latents...")
-gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
+gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["reference.wav"])
 print("Inference...")
 t0 = time.time()
@@ -210,7 +208,7 @@ model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENI
 model.cuda()
 print("Computing speaker latents...")
-gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])
+gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[SPEAKER_REFERENCE])
 print("Inference...")
 out = model.inference(
@@ -218,7 +216,6 @@ out = model.inference(
 "en",
 gpt_cond_latent,
 speaker_embedding,
-diffusion_conditioning,
 temperature=0.7, # Add custom parameters here
 )
 torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)