Implement binary alignment loss

This commit is contained in:
Eren Gölge 2021-09-03 13:23:22 +00:00
parent 6e9d4062f2
commit debf772ec5
2 changed files with 39 additions and 1 deletions

View File

@ -17,37 +17,58 @@ class FastPitchConfig(BaseTTSConfig):
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
for the rest. Defaults to 10.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of Noam LR scheduler. Defaults to False.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults 4000.
lr (float):
Initial learning rate. Defaults to `1e-3`.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
huber_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
binary_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
@ -77,6 +98,8 @@ class FastPitchConfig(BaseTTSConfig):
pitch_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
# overrides
min_seq_len: int = 13

View File

@ -705,6 +705,14 @@ class FastPitchLoss(nn.Module):
self.dur_loss_alpha = c.dur_loss_alpha
self.pitch_loss_alpha = c.pitch_loss_alpha
self.aligner_loss_alpha = c.aligner_loss_alpha
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
def _binary_alignment_loss(self, alignment_hard, alignment_soft):
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()
def forward(
self,
@ -717,6 +725,8 @@ class FastPitchLoss(nn.Module):
pitch_target,
input_lens,
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
):
loss = 0
return_dict = {}
@ -743,8 +753,13 @@ class FastPitchLoss(nn.Module):
if self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss += self.aligner_loss_alpha * aligner_loss
loss = loss + self.aligner_loss_alpha * aligner_loss
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
if self.binary_alignment_loss_alpha > 0:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss"] = loss
return return_dict