From c68962c57409f34c8a88a0163f5f476e3b87d2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:53:44 +0100 Subject: [PATCH] Update forward tts binary loss --- TTS/tts/configs/fast_pitch_config.py | 3 +++ TTS/tts/configs/fast_speech_config.py | 6 +++--- TTS/tts/configs/speedy_speech_config.py | 6 +++--- TTS/tts/models/forward_tts.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index de870388..024040f8 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -92,6 +92,9 @@ class FastPitchConfig(BaseTTSConfig): binary_align_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + min_seq_len (int): Minimum input sequence length to be used at training. diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py index 31d99442..f0c23593 100644 --- a/TTS/tts/configs/fast_speech_config.py +++ b/TTS/tts/configs/fast_speech_config.py @@ -93,8 +93,8 @@ class FastSpeechConfig(BaseTTSConfig): binary_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. @@ -135,7 +135,7 @@ class FastSpeechConfig(BaseTTSConfig): pitch_loss_alpha: float = 0.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 1.0 - binary_align_loss_start_step: int = 20000 + binary_align_loss_start_step: int = 50000 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index ea6866ed..4bf5101f 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -89,8 +89,8 @@ class SpeedySpeechConfig(BaseTTSConfig): binary_loss_alpha (float): Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. - binary_align_loss_start_step (int): - Start binary alignment loss after this many steps. Defaults to 20000. + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. min_seq_len (int): Minimum input sequence length to be used at training. @@ -150,7 +150,7 @@ class SpeedySpeechConfig(BaseTTSConfig): spec_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 0.3 - binary_align_loss_start_step: int = 50000 + binary_loss_warmup_epochs: int = 150 # overrides min_seq_len: int = 13 diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 8d554f76..db8fef2d 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -178,8 +178,8 @@ class ForwardTTS(BaseTTS): tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, ): - super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) self.init_multispeaker(config)