Implement binary alignment loss

Eren Gölge 2021-09-03 13:23:22 +00:00
parent 6e9d4062f2
commit debf772ec5
2 changed files with 39 additions and 1 deletion


@@ -17,37 +17,58 @@ class FastPitchConfig(BaseTTSConfig):
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
for the rest. Defaults to 10.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set to True, the model is
in multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
noam_schedule (bool):
enable / disable the use of the Noam LR scheduler. Defaults to False.
warmup_steps (int):
Number of warm-up steps for the Noam scheduler. Defaults to 4000.
lr (float):
Initial learning rate. Defaults to `1e-3`.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0.
huber_loss_alpha (float):
Weight for the duration predictor's loss. If set to 0, disables the Huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set to 0, disables the L1 loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set to 0, disables the pitch predictor. Defaults to 1.0.
binary_align_loss_alpha (float):
Weight for the binary alignment loss. If set to 0, disables the binary alignment loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start the binary alignment loss after this many steps. Defaults to 20000.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
@@ -77,6 +98,8 @@ class FastPitchConfig(BaseTTSConfig):
pitch_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
# overrides
min_seq_len: int = 13
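
For context, a minimal sketch of enabling the new loss through this config. The import path and surrounding setup are assumptions, not part of this diff; only the two field names come from the hunk above:

from TTS.tts.configs.fast_pitch_config import FastPitchConfig  # path assumed

config = FastPitchConfig(
    binary_align_loss_alpha=1.0,        # weight of the binary alignment loss term
    binary_align_loss_start_step=20000, # delay the loss until the aligner has warmed up
)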


@@ -705,6 +705,14 @@ class FastPitchLoss(nn.Module):
self.dur_loss_alpha = c.dur_loss_alpha
self.pitch_loss_alpha = c.pitch_loss_alpha
self.aligner_loss_alpha = c.aligner_loss_alpha
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha

def _binary_alignment_loss(self, alignment_hard, alignment_soft):
"""Binary loss that forces soft alignments to match the hard alignments as
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
"""
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
return -log_sum / alignment_hard.sum()

def forward(
self,
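
_binary_alignment_loss implements the binarization term from the paper linked above ("One TTS Alignment To Rule Them All", arXiv:2108.10447): it gathers the soft attention probabilities at the positions selected by the hard (binarized) alignment and averages their negative logs, so the loss approaches 0 as the soft alignment collapses onto the hard one. A self-contained sanity check with toy tensors; shapes and values are illustrative only, and the per-frame argmax here stands in for the monotonic alignment search the real model uses:

import torch

# Toy soft alignment: batch of 1, 3 text tokens x 4 spectrogram frames.
alignment_soft = torch.tensor([[[0.7, 0.2, 0.1, 0.0],
                                [0.2, 0.7, 0.8, 0.1],
                                [0.1, 0.1, 0.1, 0.9]]])
# Hard alignment: a one-hot choice of token for each frame.
alignment_hard = torch.nn.functional.one_hot(
    alignment_soft.argmax(dim=1), num_classes=3
).transpose(1, 2).float()
# Same computation as _binary_alignment_loss above.
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
loss = -log_sum / alignment_hard.sum()
print(loss)  # ~0.26 here; shrinks toward 0 as the soft attention sharpens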
@@ -717,6 +725,8 @@ class FastPitchLoss(nn.Module):
pitch_target,
input_lens,
alignment_logprob=None,
alignment_hard=None,
alignment_soft=None,
):
loss = 0
return_dict = {}
@@ -743,8 +753,13 @@ class FastPitchLoss(nn.Module):
if self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss = loss + self.aligner_loss_alpha * aligner_loss
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss"] = loss return_dict["loss"] = loss
return return_dict return return_dict
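
Note that binary_align_loss_start_step is defined in the config but not consumed by FastPitchLoss itself, so the gating presumably happens in the training step by withholding the alignments until that step. A hedged sketch of such gating; criterion, global_step, and the tensor names are assumptions, and the other loss inputs are elided as in the hunk above:

# Hypothetical gating in the training step (not part of this diff):
past_start = global_step >= config.binary_align_loss_start_step
loss_dict = criterion(
    ...,  # decoder/duration/pitch inputs elided, as in the hunk above
    alignment_logprob=alignment_logprob,
    alignment_hard=alignment_hard if past_start else None,
    alignment_soft=alignment_soft if past_start else None,
)
# With the `alignment_hard is not None` guard in forward, passing None
# simply skips the binary alignment term before the start step.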