From cbbc9e017278155af3dac3dc33d13dda400c18d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sat, 11 Sep 2021 10:20:37 +0000
Subject: [PATCH] Add FastSpeechConfig

---
 TTS/tts/configs/fast_speech_config.py   | 151 ++++++++++++++++++++++++
 TTS/tts/configs/speedy_speech_config.py |   5 +-
 TTS/tts/models/forward_tts.py           |  17 ++-
 3 files changed, 165 insertions(+), 8 deletions(-)
 create mode 100644 TTS/tts/configs/fast_speech_config.py

diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py
new file mode 100644
index 00000000..bba47bb3
--- /dev/null
+++ b/TTS/tts/configs/fast_speech_config.py
@@ -0,0 +1,151 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.forward_tts import ForwardTTSArgs
+
+
+@dataclass
+class FastSpeechConfig(BaseTTSConfig):
+    """Configure `ForwardTTS` as a FastSpeech model.
+
+    Example:
+
+        >>> from TTS.tts.configs import FastSpeechConfig
+        >>> config = FastSpeechConfig()
+
+    Args:
+        model (str):
+            Model name used for selecting the right model at initialization. Defaults to `fast_speech`.
+
+        base_model (str):
+            Name of the base model being configured as this model so that 🐸 TTS knows it needs to instantiate
+            the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
+
+        model_args (Coqpit):
+            Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs(use_pitch=False)`.
+
+        use_speaker_embedding (bool):
+            Enable/disable using speaker embeddings for multi-speaker models. If set to True, the model is
+            in the multi-speaker mode. Defaults to False.
+
+        use_d_vector_file (bool):
+            Enable/disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
+
+        d_vector_file (str):
+            Path to the file including pre-computed speaker embeddings. Defaults to None.
+
+        d_vector_dim (int):
+            Dimension of the external speaker embeddings. Defaults to 0.
+
+        optimizer (str):
+            Name of the model optimizer. Defaults to `Adam`.
+
+        optimizer_params (dict):
+            Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
+
+        lr_scheduler (str):
+            Name of the learning rate scheduler. Defaults to `NoamLR`.
+
+        lr_scheduler_params (dict):
+            Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
+
+        lr (float):
+            Initial learning rate. Defaults to `1e-4`.
+
+        grad_clip (float):
+            Gradient norm clipping value. Defaults to `5.0`.
+
+        spec_loss_type (str):
+            Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
+
+        duration_loss_type (str):
+            Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
+
+        use_ssim_loss (bool):
+            Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
+
+        ssim_loss_alpha (float):
+            Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0.
+
+        dur_loss_alpha (float):
+            Weight for the duration predictor's loss. If set to 0, disables the duration loss. Defaults to 1.0.
+
+        spec_loss_alpha (float):
+            Weight for the spectrogram loss. If set to 0, disables the spectrogram loss. Defaults to 1.0.
+
+        pitch_loss_alpha (float):
+            Weight for the pitch predictor's loss. If set to 0, disables the pitch predictor. Defaults to 0.0
+            since FastSpeech does not use a pitch predictor.
+
+        aligner_loss_alpha (float):
+            Weight for the aligner loss. If set to 0, disables the aligner. Defaults to 1.0.
+
+        binary_align_loss_alpha (float):
+            Weight for the binary alignment loss. If set to 0, disables the binary alignment loss. Defaults to 1.0.
+
+        binary_align_loss_start_step (int):
+            Start the binary alignment loss after this many steps. Defaults to 20000.
+
+        min_seq_len (int):
+            Minimum input sequence length to be used at training. Defaults to 13.
+
+        max_seq_len (int):
+            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
+            Defaults to 200.
+    """
+
+    model: str = "fast_speech"
+    base_model: str = "forward_tts"
+
+    # model specific params
+    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False)
+
+    # multi-speaker settings
+    use_speaker_embedding: bool = False
+    use_d_vector_file: bool = False
+    d_vector_file: str = None
+    d_vector_dim: int = 0
+
+    # optimizer parameters
+    optimizer: str = "Adam"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
+    lr_scheduler: str = "NoamLR"
+    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
+    lr: float = 1e-4
+    grad_clip: float = 5.0
+
+    # loss params
+    spec_loss_type: str = "mse"
+    duration_loss_type: str = "mse"
+    use_ssim_loss: bool = True
+    ssim_loss_alpha: float = 1.0
+    dur_loss_alpha: float = 1.0
+    spec_loss_alpha: float = 1.0
+    pitch_loss_alpha: float = 0.0
+    aligner_loss_alpha: float = 1.0
+    binary_align_loss_alpha: float = 1.0
+    binary_align_loss_start_step: int = 20000
+
+    # overrides
+    min_seq_len: int = 13
+    max_seq_len: int = 200
+    r: int = 1  # DO NOT CHANGE
+
+    # dataset configs
+    compute_f0: bool = False
+    f0_cache_path: str = None
+
+    # testing
+    test_sentences: List[str] = field(
+        default_factory=lambda: [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "Be a voice, not an echo.",
+            "I'm sorry Dave. I'm afraid I can't do that.",
+            "This cake is great. It's so delicious and moist.",
+            "Prior to November 22, 1963.",
+        ]
+    )
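A minimal usage sketch for the new config, assuming only the class defined above
and Coqpit's standard `save_json` helper (the output file name is illustrative):

    from TTS.tts.configs.fast_speech_config import FastSpeechConfig

    # FastSpeech is ForwardTTS with the pitch predictor disabled.
    config = FastSpeechConfig()
    assert config.base_model == "forward_tts"
    assert config.model_args.use_pitch is False

    # Coqpit-based configs serialize to JSON for the CLI trainers.
    config.save_json("fast_speech_config.json")
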
diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py
index 1ec8f729..23a96ff1 100644
--- a/TTS/tts/configs/speedy_speech_config.py
+++ b/TTS/tts/configs/speedy_speech_config.py
@@ -119,6 +119,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
         hidden_channels=128,
         num_speakers=0,
         positional_encoding=True,
+        detach_duration_predictor=True,
     )
 
     # multi-speaker settings
@@ -128,7 +129,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
     d_vector_dim: int = 0
 
     # optimizer parameters
-    optimizer: str = "RAdam"
+    optimizer: str = "Adam"
     optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
     lr_scheduler: str = "NoamLR"
     lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
@@ -138,7 +139,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
     # loss params
     spec_loss_type: str = "l1"
     duration_loss_type: str = "huber"
-    use_ssim_loss: bool = True
+    use_ssim_loss: bool = False
     ssim_loss_alpha: float = 1.0
     dur_loss_alpha: float = 1.0
     spec_loss_alpha: float = 1.0
diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index c4411027..9dce36fa 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -99,6 +99,7 @@ class ForwardTTSArgs(Coqpit):
 
         max_duration (int):
             Maximum duration accepted by the model. Defaults to 75.
+ """ num_chars: int = None @@ -264,18 +265,22 @@ class ForwardTTS(BaseTTS): """Generate attention alignment map from durations and expand encoder outputs - Shapes + Shapes: - en: :math:`(B, D_{en}, T_{en})` - dr: :math:`(B, T_{en})` - x_mask: :math:`(B, T_{en})` - y_mask: :math:`(B, T_{de})` - Examples: - - encoder output: :math:`[a,b,c,d]` - - durations: :math:`[1, 3, 2, 1]` + Examples:: - - expanded: :math:`[a, b, b, b, c, c, d]` - - attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]` + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] """ attn = self.generate_attn(dr, x_mask, y_mask) o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
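The duration-to-alignment expansion documented above can be checked in
isolation. A minimal sketch in plain PyTorch, paraphrasing the alignment
construction with `repeat_interleave` rather than calling the model's
`generate_attn`:

    import torch

    # Durations [1, 3, 2, 1] stretch 4 encoder frames into 7 decoder frames,
    # mirroring the docstring example: [a, b, c, d] -> [a, b, b, b, c, c, d].
    dr = torch.tensor([1, 3, 2, 1])                           # (T_en,)
    en = torch.arange(1.0, 5.0).view(1, 1, 4)                 # (B, D_en, T_en)

    # Each decoder frame attends to the token whose duration span covers it.
    token_ids = torch.repeat_interleave(torch.arange(4), dr)  # (T_de,)
    attn = torch.nn.functional.one_hot(token_ids, 4).float()  # (T_de, T_en)

    # Expand encoder outputs: (B, D_en, T_en) @ (T_en, T_de) -> (B, D_en, T_de)
    o_en_ex = en @ attn.t().unsqueeze(0)
    print(o_en_ex)  # tensor([[[1., 2., 2., 2., 3., 3., 4.]]])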