Update XTTS docs

This commit is contained in:
Edresson Casanova 2023-11-01 13:53:16 -03:00 committed by Eren G??lge
parent 8479a3702c
commit 5df8f76b0c
2 changed files with 20 additions and 16 deletions

View File

@ -88,7 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.max_wav_len = model_args.max_wav_length self.max_wav_len = model_args.max_wav_length
self.max_text_len = model_args.max_text_length self.max_text_len = model_args.max_text_length
self.use_masking_gt_as_prompt = model_args.gpt_use_masking_gt_as_prompt self.use_masking_gt_prompt_approach = model_args.gpt_use_masking_gt_prompt_approach
assert self.max_wav_len is not None and self.max_text_len is not None assert self.max_wav_len is not None and self.max_text_len is not None
self.samples = samples self.samples = samples
@ -141,7 +141,7 @@ class XTTSDataset(torch.utils.data.Dataset):
# Ultra short clips are also useless (and can cause problems within some models). # Ultra short clips are also useless (and can cause problems within some models).
raise ValueError raise ValueError
if self.use_masking_gt_as_prompt: if self.use_masking_gt_prompt_approach:
# get a slice from GT to condition the model # get a slice from GT to condition the model
cond, _, cond_idxs = get_prompt_slice( cond, _, cond_idxs = get_prompt_slice(
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval

View File

@ -112,7 +112,7 @@ def load_discrete_vocoder_diffuser(
return SpacedDiffusion( return SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type="epsilon", model_mean_type="epsilon",
model_var_type="learned_range", model_vgpt_type="learned_range",
loss_type="mse", loss_type="mse",
betas=get_named_beta_schedule("linear", trained_diffusion_steps), betas=get_named_beta_schedule("linear", trained_diffusion_steps),
conditioning_free=cond_free, conditioning_free=cond_free,
@ -192,16 +192,19 @@ class XttsArgs(Coqpit):
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True. use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
For GPT model: For GPT model:
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
ar_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
For DiffTTS model: For DiffTTS model:
diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
@ -241,7 +244,8 @@ class XttsArgs(Coqpit):
gpt_num_audio_tokens: int = 8194 gpt_num_audio_tokens: int = 8194
gpt_start_audio_token: int = 8192 gpt_start_audio_token: int = 8192
gpt_stop_audio_token: int = 8193 gpt_stop_audio_token: int = 8193
gpt_use_masking_gt_as_prompt: bool = True gpt_code_stride_len: int = 1024
gpt_use_masking_gt_prompt_approach: bool = True
gpt_use_perceiver_resampler: bool = False gpt_use_perceiver_resampler: bool = False
# Diffusion Decoder params # Diffusion Decoder params
@ -261,7 +265,6 @@ class XttsArgs(Coqpit):
input_sample_rate: int = 22050 input_sample_rate: int = 22050
output_sample_rate: int = 24000 output_sample_rate: int = 24000
output_hop_length: int = 256 output_hop_length: int = 256
ar_mel_length_compression: int = 1024
decoder_input_dim: int = 1024 decoder_input_dim: int = 1024
d_vector_dim: int = 512 d_vector_dim: int = 512
cond_d_vector_in_each_upsampling_layer: bool = True cond_d_vector_in_each_upsampling_layer: bool = True
@ -319,6 +322,7 @@ class Xtts(BaseTTS):
start_audio_token=self.args.gpt_start_audio_token, start_audio_token=self.args.gpt_start_audio_token,
stop_audio_token=self.args.gpt_stop_audio_token, stop_audio_token=self.args.gpt_stop_audio_token,
use_perceiver_resampler=self.args.gpt_use_perceiver_resampler, use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
code_stride_len=self.args.gpt_code_stride_len,
) )
if self.args.use_hifigan: if self.args.use_hifigan:
@ -326,7 +330,7 @@ class Xtts(BaseTTS):
input_sample_rate=self.args.input_sample_rate, input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate, output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length, output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.ar_mel_length_compression, ar_mel_length_compression=self.args.gpt_code_stride_len,
decoder_input_dim=self.args.decoder_input_dim, decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim, d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
@ -337,7 +341,7 @@ class Xtts(BaseTTS):
input_sample_rate=self.args.input_sample_rate, input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate, output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length, output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.ar_mel_length_compression, ar_mel_length_compression=self.args.gpt_code_stride_len,
decoder_input_dim=self.args.decoder_input_dim, decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim, d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,