mirror of https://github.com/coqui-ai/TTS.git

Update XTTS docs

This commit is contained in:
parent 8479a3702c
commit 5df8f76b0c
@@ -88,7 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
-        self.use_masking_gt_as_prompt = model_args.gpt_use_masking_gt_as_prompt
+        self.use_masking_gt_prompt_approach = model_args.gpt_use_masking_gt_prompt_approach
         assert self.max_wav_len is not None and self.max_text_len is not None

         self.samples = samples
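The renamed flag gates the "use ground truth as prompt" training trick: the conditioning prompt is cut from the target utterance itself, and the loss over the prompted token span is masked so the model is not rewarded for simply copying it. A minimal sketch of that masking step, with hypothetical names (loss_per_token, cond_idxs, code_stride); this is not the repository's exact implementation:

    import torch

    def mask_prompt_region(loss_per_token: torch.Tensor, cond_idxs: torch.Tensor, code_stride: int) -> torch.Tensor:
        """Zero out the loss on audio tokens overlapping the conditioning slice.

        loss_per_token: (B, T) per-token loss over the GPT audio-code sequence.
        cond_idxs: (B, 2) start/end *sample* indices of the prompt slice in the GT wav.
        code_stride: waveform samples covered by one audio code (hypothetical here).
        """
        mask = torch.ones_like(loss_per_token)
        for b in range(loss_per_token.shape[0]):
            start = int(cond_idxs[b, 0]) // code_stride  # sample index -> token index
            end = int(cond_idxs[b, 1]) // code_stride
            mask[b, start:end] = 0.0  # don't penalize (or reward) the copied prompt span
        return (loss_per_token * mask).sum() / mask.sum().clamp(min=1)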
@@ -141,7 +141,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             # Ultra short clips are also useless (and can cause problems within some models).
             raise ValueError

-        if self.use_masking_gt_as_prompt:
+        if self.use_masking_gt_prompt_approach:
             # get a slice from GT to condition the model
             cond, _, cond_idxs = get_prompt_slice(
                 audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
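get_prompt_slice cuts a bounded-length conditioning window out of the target utterance and also returns the window's sample indices, which the loss masking above needs. A rough sketch under assumed semantics (lengths in samples, eval-mode determinism simplified); the real helper may differ:

    import random
    import torch
    import torchaudio

    def get_prompt_slice_sketch(audiopath, max_len, min_len, sample_rate, is_eval=False):
        """Cut a random conditioning window of min_len..max_len samples from the utterance."""
        wav, sr = torchaudio.load(audiopath)
        if sr != sample_rate:
            wav = torchaudio.functional.resample(wav, sr, sample_rate)
        # In eval, use a fixed mid-range length and a fixed start so results are reproducible.
        slice_len = (max_len + min_len) // 2 if is_eval else random.randint(min_len, max_len)
        slice_len = min(slice_len, wav.shape[-1])
        start = 0 if is_eval else random.randint(0, wav.shape[-1] - slice_len)
        end = start + slice_len
        return wav[:, start:end], slice_len, torch.tensor([start, end])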
@@ -112,7 +112,7 @@ def load_discrete_vocoder_diffuser(
     return SpacedDiffusion(
         use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
         model_mean_type="epsilon",
-        model_var_type="learned_range",
+        model_vgpt_type="learned_range",
         loss_type="mse",
         betas=get_named_beta_schedule("linear", trained_diffusion_steps),
         conditioning_free=cond_free,
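space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]) re-spaces the schedule the model was trained on, so sampling can run in far fewer steps (say 100 instead of 4000). A generic sketch of that respacing idea, not the library's exact code:

    def space_timesteps_sketch(num_trained_steps, section_counts):
        """Pick sum(section_counts) timesteps out of num_trained_steps,
        spacing them evenly within equal-sized sections of the schedule."""
        size_per = num_trained_steps // len(section_counts)
        taken = set()
        start = 0
        for count in section_counts:
            stride = (size_per - 1) / max(count - 1, 1)
            taken.update(int(start + round(stride * i)) for i in range(count))
            start += size_per
        return taken

    # e.g. respace a 4000-step trained schedule down to 100 sampling steps
    assert len(space_timesteps_sketch(4000, [100])) == 100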
@@ -192,16 +192,19 @@ class XttsArgs(Coqpit):
        use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.

        For GPT model:
-        ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
-        ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
-        ar_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
-        ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
-        ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
-        ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
-        ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
-        ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
+        gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
+        gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
+        gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70.
+        gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
+        gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
+        gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
+        gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
+        gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
        gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
-        ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
+        gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
+        gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
+        gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
+        gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.

        For DiffTTS model:
        diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
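After this commit, every GPT-side hyperparameter carries the gpt_ prefix. A hypothetical construction using only fields visible in this diff; the import path is assumed from the repo layout and the values mirror the documented defaults:

    from TTS.tts.models.xtts import XttsArgs  # module path assumed

    args = XttsArgs(
        gpt_layers=30,
        gpt_n_model_channels=1024,
        gpt_n_heads=16,
        gpt_code_stride_len=1024,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=False,
    )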
@@ -241,7 +244,8 @@ class XttsArgs(Coqpit):
     gpt_num_audio_tokens: int = 8194
     gpt_start_audio_token: int = 8192
     gpt_stop_audio_token: int = 8193
-    gpt_use_masking_gt_as_prompt: bool = True
+    gpt_code_stride_len: int = 1024
+    gpt_use_masking_gt_prompt_approach: bool = True
     gpt_use_perceiver_resampler: bool = False

     # Diffusion Decoder params
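gpt_code_stride_len is the DVAE hop size: one GPT audio code covers that many waveform samples at the input rate, which makes token-length bookkeeping a one-liner. For illustration (the function name is ours, not the repo's):

    def num_audio_codes(wav_len_samples: int, code_stride_len: int = 1024) -> int:
        """How many GPT audio tokens a waveform of wav_len_samples maps to."""
        return wav_len_samples // code_stride_len

    # 6 s of 22050 Hz audio -> 129 codes
    assert num_audio_codes(6 * 22050) == 129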
@@ -261,7 +265,6 @@ class XttsArgs(Coqpit):
     input_sample_rate: int = 22050
     output_sample_rate: int = 24000
     output_hop_length: int = 256
-    ar_mel_length_compression: int = 1024
     decoder_input_dim: int = 1024
     d_vector_dim: int = 512
     cond_d_vector_in_each_upsampling_layer: bool = True
@@ -319,6 +322,7 @@ class Xtts(BaseTTS):
             start_audio_token=self.args.gpt_start_audio_token,
             stop_audio_token=self.args.gpt_stop_audio_token,
             use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
+            code_stride_len=self.args.gpt_code_stride_len,
         )

         if self.args.use_hifigan:
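use_perceiver_resampler refers to the Flamingo-style module (https://arxiv.org/abs/2204.14198) in which a small set of learned latents cross-attends to a variable-length conditioning sequence and emits a fixed-size summary. A self-contained sketch with illustrative dimensions, not XTTS's actual module:

    import torch
    from torch import nn

    class PerceiverResamplerSketch(nn.Module):
        """Learned latents cross-attend to variable-length conditioning features,
        yielding a fixed-size summary. Dimensions are illustrative."""

        def __init__(self, dim: int = 1024, num_latents: int = 32, num_heads: int = 8):
            super().__init__()
            self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.ff = nn.Sequential(
                nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
            )

        def forward(self, cond: torch.Tensor) -> torch.Tensor:  # cond: (B, T, dim)
            q = self.latents.unsqueeze(0).expand(cond.shape[0], -1, -1)
            out, _ = self.attn(q, cond, cond)  # latents attend over all T frames
            return out + self.ff(out)          # (B, num_latents, dim), independent of T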
@@ -326,7 +330,7 @@ class Xtts(BaseTTS):
                 input_sample_rate=self.args.input_sample_rate,
                 output_sample_rate=self.args.output_sample_rate,
                 output_hop_length=self.args.output_hop_length,
-                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                ar_mel_length_compression=self.args.gpt_code_stride_len,
                 decoder_input_dim=self.args.decoder_input_dim,
                 d_vector_dim=self.args.d_vector_dim,
                 cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
@@ -337,7 +341,7 @@ class Xtts(BaseTTS):
                 input_sample_rate=self.args.input_sample_rate,
                 output_sample_rate=self.args.output_sample_rate,
                 output_hop_length=self.args.output_hop_length,
-                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                ar_mel_length_compression=self.args.gpt_code_stride_len,
                 decoder_input_dim=self.args.decoder_input_dim,
                 d_vector_dim=self.args.d_vector_dim,
                 cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
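With ar_mel_length_compression now sourced from gpt_code_stride_len in both decoder branches above, latent-to-waveform upsampling stays in sync with the DVAE hop. Rough output-length arithmetic under the defaults shown in this diff (names and rounding are ours):

    def expected_output_samples(num_codes: int,
                                code_stride_len: int = 1024,
                                input_sample_rate: int = 22050,
                                output_sample_rate: int = 24000) -> int:
        """Waveform samples the decoder should emit for num_codes GPT codes:
        each code spans code_stride_len samples at the input rate, and the
        decoder renders at the output rate."""
        input_samples = num_codes * code_stride_len
        return round(input_samples * output_sample_rate / input_sample_rate)

    # 129 codes (~6 s at 22050 Hz) -> 143778 samples (~6 s at 24000 Hz)
    print(expected_output_samples(129))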