Update forward tts binary loss

This commit is contained in:
Eren Gölge 2022-02-20 11:53:44 +01:00
parent c11944022d
commit c68962c574
4 changed files with 10 additions and 7 deletions

View File

@ -92,6 +92,9 @@ class FastPitchConfig(BaseTTSConfig):
binary_align_loss_alpha (float):
Weight for the binary loss. If set to 0, disables the binary loss. Defaults to 1.0.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.

View File

@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig):
binary_align_loss_alpha (float):
Weight for the binary loss. If set to 0, disables the binary loss. Defaults to 1.0.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig):
pitch_loss_alpha: float = 0.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 1.0
binary_align_loss_start_step: int = 20000
binary_align_loss_start_step: int = 50000
# overrides
min_seq_len: int = 13

View File

@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
binary_align_loss_alpha (float):
Weight for the binary loss. If set to 0, disables the binary loss. Defaults to 0.3.
binary_align_loss_start_step (int):
Start binary alignment loss after this many steps. Defaults to 20000.
binary_loss_warmup_epochs (int):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig):
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
binary_align_loss_alpha: float = 0.3
binary_align_loss_start_step: int = 50000
binary_loss_warmup_epochs: int = 150
# overrides
min_seq_len: int = 13

View File

@ -178,8 +178,8 @@ class ForwardTTS(BaseTTS):
tokenizer: "TTSTokenizer" = None,
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
self._set_model_args(config)
self.init_multispeaker(config)