mirror of https://github.com/coqui-ai/TTS.git
Implement binary alignment loss
This commit is contained in:
parent
6e9d4062f2
commit
debf772ec5
|
@ -17,37 +17,58 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
Args:
|
Args:
|
||||||
model (str):
|
model (str):
|
||||||
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||||
|
|
||||||
model_args (Coqpit):
|
model_args (Coqpit):
|
||||||
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
||||||
|
|
||||||
data_dep_init_steps (int):
|
data_dep_init_steps (int):
|
||||||
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
|
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
|
||||||
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
|
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
|
||||||
for the rest. Defaults to 10.
|
for the rest. Defaults to 10.
|
||||||
|
|
||||||
use_speaker_embedding (bool):
|
use_speaker_embedding (bool):
|
||||||
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
|
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
|
||||||
in the multi-speaker mode. Defaults to False.
|
in the multi-speaker mode. Defaults to False.
|
||||||
|
|
||||||
use_d_vector_file (bool):
|
use_d_vector_file (bool):
|
||||||
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
|
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
|
||||||
|
|
||||||
d_vector_file (str):
|
d_vector_file (str):
|
||||||
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
||||||
|
|
||||||
noam_schedule (bool):
|
noam_schedule (bool):
|
||||||
enable / disable the use of Noam LR scheduler. Defaults to False.
|
enable / disable the use of Noam LR scheduler. Defaults to False.
|
||||||
|
|
||||||
warmup_steps (int):
|
warmup_steps (int):
|
||||||
Number of warm-up steps for the Noam scheduler. Defaults 4000.
|
Number of warm-up steps for the Noam scheduler. Defaults 4000.
|
||||||
|
|
||||||
lr (float):
|
lr (float):
|
||||||
Initial learning rate. Defaults to `1e-3`.
|
Initial learning rate. Defaults to `1e-3`.
|
||||||
|
|
||||||
wd (float):
|
wd (float):
|
||||||
Weight decay coefficient. Defaults to `1e-7`.
|
Weight decay coefficient. Defaults to `1e-7`.
|
||||||
|
|
||||||
ssim_loss_alpha (float):
|
ssim_loss_alpha (float):
|
||||||
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
||||||
|
|
||||||
huber_loss_alpha (float):
|
huber_loss_alpha (float):
|
||||||
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
||||||
|
|
||||||
spec_loss_alpha (float):
|
spec_loss_alpha (float):
|
||||||
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
|
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
|
||||||
|
|
||||||
pitch_loss_alpha (float):
|
pitch_loss_alpha (float):
|
||||||
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
|
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
|
||||||
|
|
||||||
|
binary_loss_alpha (float):
|
||||||
|
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
binary_align_loss_start_step (int):
|
||||||
|
Start binary alignment loss after this many steps. Defaults to 20000.
|
||||||
|
|
||||||
min_seq_len (int):
|
min_seq_len (int):
|
||||||
Minimum input sequence length to be used at training.
|
Minimum input sequence length to be used at training.
|
||||||
|
|
||||||
max_seq_len (int):
|
max_seq_len (int):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
"""
|
"""
|
||||||
|
@ -77,6 +98,8 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
pitch_loss_alpha: float = 1.0
|
pitch_loss_alpha: float = 1.0
|
||||||
dur_loss_alpha: float = 1.0
|
dur_loss_alpha: float = 1.0
|
||||||
aligner_loss_alpha: float = 1.0
|
aligner_loss_alpha: float = 1.0
|
||||||
|
binary_align_loss_alpha: float = 1.0
|
||||||
|
binary_align_loss_start_step: int = 20000
|
||||||
|
|
||||||
# overrides
|
# overrides
|
||||||
min_seq_len: int = 13
|
min_seq_len: int = 13
|
||||||
|
|
|
@ -705,6 +705,14 @@ class FastPitchLoss(nn.Module):
|
||||||
self.dur_loss_alpha = c.dur_loss_alpha
|
self.dur_loss_alpha = c.dur_loss_alpha
|
||||||
self.pitch_loss_alpha = c.pitch_loss_alpha
|
self.pitch_loss_alpha = c.pitch_loss_alpha
|
||||||
self.aligner_loss_alpha = c.aligner_loss_alpha
|
self.aligner_loss_alpha = c.aligner_loss_alpha
|
||||||
|
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||||
|
|
||||||
|
def _binary_alignment_loss(self, alignment_hard, alignment_soft):
|
||||||
|
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||||
|
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||||
|
"""
|
||||||
|
log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
|
||||||
|
return -log_sum / alignment_hard.sum()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -717,6 +725,8 @@ class FastPitchLoss(nn.Module):
|
||||||
pitch_target,
|
pitch_target,
|
||||||
input_lens,
|
input_lens,
|
||||||
alignment_logprob=None,
|
alignment_logprob=None,
|
||||||
|
alignment_hard=None,
|
||||||
|
alignment_soft=None,
|
||||||
):
|
):
|
||||||
loss = 0
|
loss = 0
|
||||||
return_dict = {}
|
return_dict = {}
|
||||||
|
@ -743,8 +753,13 @@ class FastPitchLoss(nn.Module):
|
||||||
|
|
||||||
if self.aligner_loss_alpha > 0:
|
if self.aligner_loss_alpha > 0:
|
||||||
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
|
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
|
||||||
loss += self.aligner_loss_alpha * aligner_loss
|
loss = loss + self.aligner_loss_alpha * aligner_loss
|
||||||
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
|
||||||
|
|
||||||
|
if self.binary_alignment_loss_alpha > 0:
|
||||||
|
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
|
||||||
|
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||||
|
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
|
||||||
|
|
||||||
return_dict["loss"] = loss
|
return_dict["loss"] = loss
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
Loading…
Reference in New Issue