diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index d02e54c9..7e294896 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -64,14 +64,14 @@ class FastPitchConfig(BaseTTSConfig): # optimizer parameters optimizer: str = "Adam" - optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.98], "weight_decay": 1e-6}) + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) lr_scheduler: str = "NoamLR" lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) lr: float = 1e-4 - grad_clip: float = 1000.0 + grad_clip: float = 5.0 # loss params - ssim_loss_alpha: float = 0.0 + ssim_loss_alpha: float = 1.0 dur_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0 pitch_loss_alpha: float = 1.0 diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 989866ae..60a1654a 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -505,7 +505,7 @@ class FastPitch(BaseTTS): o_en_dr, mask_en_dr = o_en, mask_en # Predict durations - o_dr_log = self.duration_predictor(o_en_dr.detach(), mask_en_dr) + o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) # TODO: move this to the dataset @@ -560,6 +560,7 @@ class FastPitch(BaseTTS): # Predict durations o_dr_log = self.duration_predictor(o_en, mask_en) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + o_dr = o_dr * self.length_scale # Pitch over chars o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) @@ -606,7 +607,7 @@ class FastPitch(BaseTTS): mel_input, mel_lengths, outputs["durations_log"], - torch.log(1 + durations), + durations, outputs["pitch"], outputs["pitch_gt"], text_lengths, diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 5bc5f448..4b852d12 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -42,8 +42,8 @@ config = FastPitchConfig( use_phonemes=True, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), - print_step=25, - print_eval=True, + print_step=50, + print_eval=False, mixed_precision=False, output_path=output_path, datasets=[dataset_config]