From debf772ec5128963a0c3bec8e4f6bcaafe981221 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 3 Sep 2021 13:23:22 +0000
Subject: [PATCH] Implement binary alignment loss

---
 TTS/tts/configs/fast_pitch_config.py | 23 +++++++++++++++++++++++
 TTS/tts/layers/losses.py             | 17 ++++++++++++++++-
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py
index 2c54803a..873f298e 100644
--- a/TTS/tts/configs/fast_pitch_config.py
+++ b/TTS/tts/configs/fast_pitch_config.py
@@ -17,37 +17,58 @@ class FastPitchConfig(BaseTTSConfig):
     Args:
         model (str):
             Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
+
         model_args (Coqpit):
             Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
+
         data_dep_init_steps (int):
             Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS
             uses Activation Normalization that pre-computes normalization stats at the beginning and uses the same
             values for the rest of the training. Defaults to 10.
+
         use_speaker_embedding (bool):
             Enable / disable using speaker embeddings for multi-speaker models. If set to True, the model is
             in the multi-speaker mode. Defaults to False.
+
         use_d_vector_file (bool):
             Enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
+
         d_vector_file (str):
             Path to the file including pre-computed speaker embeddings. Defaults to None.
+
         noam_schedule (bool):
             Enable / disable the use of the Noam LR scheduler. Defaults to False.
+
         warmup_steps (int):
             Number of warm-up steps for the Noam scheduler. Defaults to 4000.
+
         lr (float):
             Initial learning rate. Defaults to `1e-3`.
+
         wd (float):
             Weight decay coefficient. Defaults to `1e-7`.
+
         ssim_loss_alpha (float):
             Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0.
+
         huber_loss_alpha (float):
             Weight for the duration predictor's loss. If set to 0, disables the Huber loss. Defaults to 1.0.
+
         spec_loss_alpha (float):
             Weight for the L1 spectrogram loss. If set to 0, disables the L1 loss. Defaults to 1.0.
+
         pitch_loss_alpha (float):
             Weight for the pitch predictor's loss. If set to 0, disables the pitch predictor. Defaults to 1.0.
+
+        binary_align_loss_alpha (float):
+            Weight for the binary alignment loss. If set to 0, disables the binary alignment loss. Defaults to 1.0.
+
+        binary_align_loss_start_step (int):
+            Start computing the binary alignment loss after this many training steps. Defaults to 20000.
+
         min_seq_len (int):
             Minimum input sequence length to be used at training.
+
         max_seq_len (int):
             Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
""" @@ -77,6 +98,8 @@ class FastPitchConfig(BaseTTSConfig): pitch_loss_alpha: float = 1.0 dur_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_align_loss_start_step: int = 20000 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 6ca010dd..805f36d6 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -705,6 +705,14 @@ class FastPitchLoss(nn.Module): self.dur_loss_alpha = c.dur_loss_alpha self.pitch_loss_alpha = c.pitch_loss_alpha self.aligner_loss_alpha = c.aligner_loss_alpha + self.binary_alignment_loss_alpha = c.binary_align_loss_alpha + + def _binary_alignment_loss(self, alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() def forward( self, @@ -717,6 +725,8 @@ class FastPitchLoss(nn.Module): pitch_target, input_lens, alignment_logprob=None, + alignment_hard=None, + alignment_soft=None, ): loss = 0 return_dict = {} @@ -743,8 +753,13 @@ class FastPitchLoss(nn.Module): if self.aligner_loss_alpha > 0: aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) - loss += self.aligner_loss_alpha * aligner_loss + loss = loss + self.aligner_loss_alpha * aligner_loss return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + if self.binary_alignment_loss_alpha > 0: + binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + return_dict["loss"] = loss return return_dict