Update FastPitch: don't detach duration network inputs

Eren Gölge 2021-07-24 16:36:52 +00:00
parent ca29033ef4
commit 81c228a2d8
3 changed files with 8 additions and 7 deletions

View File

@@ -64,14 +64,14 @@ class FastPitchConfig(BaseTTSConfig):
     # optimizer parameters
     optimizer: str = "Adam"
-    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.98], "weight_decay": 1e-6})
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
     lr_scheduler: str = "NoamLR"
     lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
     lr: float = 1e-4
-    grad_clip: float = 1000.0
+    grad_clip: float = 5.0
     # loss params
-    ssim_loss_alpha: float = 0.0
+    ssim_loss_alpha: float = 1.0
     dur_loss_alpha: float = 1.0
     spec_loss_alpha: float = 1.0
     pitch_loss_alpha: float = 1.0
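
Taken together, these defaults raise Adam's beta2 from 0.98 to 0.998, tighten gradient clipping from 1000 to 5, and switch on the SSIM spectrogram loss (ssim_loss_alpha 0.0 -> 1.0). A quick sketch of inspecting the new defaults; the import path is an assumption based on how Coqui TTS exposed configs around this release:

    from TTS.tts.configs import FastPitchConfig  # import path is an assumption

    config = FastPitchConfig()
    print(config.optimizer_params)  # {'betas': [0.9, 0.998], 'weight_decay': 1e-06}
    print(config.grad_clip)         # 5.0 -> far tighter clipping than the old 1000.0
    print(config.ssim_loss_alpha)   # 1.0 -> SSIM loss now contributes to training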

View File

@@ -505,7 +505,7 @@ class FastPitch(BaseTTS):
         o_en_dr, mask_en_dr = o_en, mask_en
         # Predict durations
-        o_dr_log = self.duration_predictor(o_en_dr.detach(), mask_en_dr)
+        o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # TODO: move this to the dataset
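
This hunk is the change named in the commit title: previously the encoder output was detached before entering the duration predictor, so the duration loss only trained the predictor itself; now its gradients also reach the shared encoder that produces o_en. A toy standalone sketch of the difference (illustrative tensors, not the model's code):

    import torch

    enc_out = torch.ones(3, requires_grad=True)     # stands in for the encoder output o_en
    pred_w = torch.tensor(2.0, requires_grad=True)  # stands in for duration predictor weights

    # old behaviour: .detach() blocks the gradient at the predictor input
    (pred_w * enc_out.detach()).sum().backward()
    print(enc_out.grad)  # None -> the encoder receives no duration-loss gradient

    # new behaviour: the duration loss also updates the encoder
    (pred_w * enc_out).sum().backward()
    print(enc_out.grad)  # tensor([2., 2., 2.])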
@@ -560,6 +560,7 @@ class FastPitch(BaseTTS):
         # Predict durations
         o_dr_log = self.duration_predictor(o_en, mask_en)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
+        o_dr = o_dr * self.length_scale
         # Pitch over chars
         o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
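
The added line scales predicted durations by length_scale at inference, which stretches or compresses the output to control speaking pace. A hedged standalone illustration; the clamp bound of 75 mirrors the model's default max_duration, which is an assumption here:

    import torch

    o_dr_log = torch.tensor([[0.1, 0.9, 1.6]])          # predicted log durations per input token
    o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, 75)  # back to linear frame counts
    length_scale = 1.5                                  # >1.0 slows speech down, <1.0 speeds it up
    o_dr = o_dr * length_scale
    print(o_dr)  # tensor([[0.1578, 2.1894, 5.9295]])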
@@ -606,7 +607,7 @@ class FastPitch(BaseTTS):
             mel_input,
             mel_lengths,
             outputs["durations_log"],
-            torch.log(1 + durations),
+            durations,
             outputs["pitch"],
             outputs["pitch_gt"],
             text_lengths,
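
The loss call now passes raw ground-truth durations instead of torch.log(1 + durations), which implies the log transform moves inside the loss function so every caller can hand over plain frame counts. A sketch of a masked duration loss under that assumption; the function name and masking scheme are illustrative, not the repo's exact implementation:

    import torch
    import torch.nn.functional as F

    def duration_loss(durations_log_pred, durations, lengths):
        # assumed post-commit behaviour: log(1 + d) is applied here, not by the caller
        target = torch.log(1 + durations.float())
        mask = (torch.arange(durations.size(1))[None, :] < lengths[:, None]).float()
        loss = F.mse_loss(durations_log_pred * mask, target * mask, reduction="sum")
        return loss / mask.sum()

    pred = torch.zeros(2, 4)                          # predicted log durations
    dur = torch.tensor([[1, 2, 3, 0], [2, 2, 0, 0]])  # ground-truth frame counts
    lens = torch.tensor([3, 2])                       # valid token counts per sample
    print(duration_loss(pred, dur, lens))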

View File

@@ -42,8 +42,8 @@ config = FastPitchConfig(
     use_phonemes=True,
     phoneme_language="en-us",
     phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
-    print_step=25,
-    print_eval=True,
+    print_step=50,
+    print_eval=False,
     mixed_precision=False,
     output_path=output_path,
     datasets=[dataset_config]