mirror of https://github.com/coqui-ai/TTS.git
Implement binary alignment loss
This commit is contained in:
parent
6e9d4062f2
commit
debf772ec5
|
@ -17,37 +17,58 @@ class FastPitchConfig(BaseTTSConfig):
|
|||
Args:
|
||||
model (str):
|
||||
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||
|
||||
model_args (Coqpit):
|
||||
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
||||
|
||||
data_dep_init_steps (int):
|
||||
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
|
||||
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
|
||||
for the rest. Defaults to 10.
|
||||
|
||||
use_speaker_embedding (bool):
|
||||
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
|
||||
in the multi-speaker mode. Defaults to False.
|
||||
|
||||
use_d_vector_file (bool):
|
||||
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
|
||||
|
||||
d_vector_file (str):
|
||||
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
||||
|
||||
noam_schedule (bool):
|
||||
enable / disable the use of Noam LR scheduler. Defaults to False.
|
||||
|
||||
warmup_steps (int):
|
||||
Number of warm-up steps for the Noam scheduler. Defaults 4000.
|
||||
|
||||
lr (float):
|
||||
Initial learning rate. Defaults to `1e-3`.
|
||||
|
||||
wd (float):
|
||||
Weight decay coefficient. Defaults to `1e-7`.
|
||||
|
||||
ssim_loss_alpha (float):
|
||||
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
||||
|
||||
huber_loss_alpha (float):
|
||||
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
||||
|
||||
spec_loss_alpha (float):
|
||||
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
|
||||
|
||||
pitch_loss_alpha (float):
|
||||
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
|
||||
|
||||
binary_loss_alpha (float):
|
||||
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
|
||||
|
||||
binary_align_loss_start_step (int):
|
||||
Start binary alignment loss after this many steps. Defaults to 20000.
|
||||
|
||||
min_seq_len (int):
|
||||
Minimum input sequence length to be used at training.
|
||||
|
||||
max_seq_len (int):
|
||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||
"""
|
||||
|
@ -77,6 +98,8 @@ class FastPitchConfig(BaseTTSConfig):
|
|||
pitch_loss_alpha: float = 1.0
|
||||
dur_loss_alpha: float = 1.0
|
||||
aligner_loss_alpha: float = 1.0
|
||||
binary_align_loss_alpha: float = 1.0
|
||||
binary_align_loss_start_step: int = 20000
|
||||
|
||||
# overrides
|
||||
min_seq_len: int = 13
|
||||
|
|
|
@ -705,6 +705,14 @@ class FastPitchLoss(nn.Module):
|
|||
self.dur_loss_alpha = c.dur_loss_alpha
|
||||
self.pitch_loss_alpha = c.pitch_loss_alpha
|
||||
self.aligner_loss_alpha = c.aligner_loss_alpha
|
||||
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||
|
||||
def _binary_alignment_loss(self, alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||
return -log_sum / alignment_hard.sum()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -717,6 +725,8 @@ class FastPitchLoss(nn.Module):
|
|||
pitch_target,
|
||||
input_lens,
|
||||
alignment_logprob=None,
|
||||
alignment_hard=None,
|
||||
alignment_soft=None,
|
||||
):
|
||||
loss = 0
|
||||
return_dict = {}
|
||||
|
@ -743,8 +753,13 @@ class FastPitchLoss(nn.Module):
|
|||
|
||||
if self.aligner_loss_alpha > 0:
|
||||
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
|
||||
loss += self.aligner_loss_alpha * aligner_loss
|
||||
loss = loss + self.aligner_loss_alpha * aligner_loss
|
||||
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
||||
|
||||
if self.binary_alignment_loss_alpha > 0:
|
||||
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
||||
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||
|
||||
return_dict["loss"] = loss
|
||||
return return_dict
|
||||
|
|
Loading…
Reference in New Issue