mirror of https://github.com/coqui-ai/TTS.git
Update FastPitch don't detach duration network inputs
This commit is contained in:
parent
ca29033ef4
commit
81c228a2d8
|
@ -64,14 +64,14 @@ class FastPitchConfig(BaseTTSConfig):
|
|||
|
||||
# optimizer parameters
|
||||
optimizer: str = "Adam"
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.98], "weight_decay": 1e-6})
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||
lr_scheduler: str = "NoamLR"
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
|
||||
lr: float = 1e-4
|
||||
grad_clip: float = 1000.0
|
||||
grad_clip: float = 5.0
|
||||
|
||||
# loss params
|
||||
ssim_loss_alpha: float = 0.0
|
||||
ssim_loss_alpha: float = 1.0
|
||||
dur_loss_alpha: float = 1.0
|
||||
spec_loss_alpha: float = 1.0
|
||||
pitch_loss_alpha: float = 1.0
|
||||
|
|
|
@ -505,7 +505,7 @@ class FastPitch(BaseTTS):
|
|||
o_en_dr, mask_en_dr = o_en, mask_en
|
||||
|
||||
# Predict durations
|
||||
o_dr_log = self.duration_predictor(o_en_dr.detach(), mask_en_dr)
|
||||
o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
|
||||
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
||||
|
||||
# TODO: move this to the dataset
|
||||
|
@ -560,6 +560,7 @@ class FastPitch(BaseTTS):
|
|||
# Predict durations
|
||||
o_dr_log = self.duration_predictor(o_en, mask_en)
|
||||
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
||||
o_dr = o_dr * self.length_scale
|
||||
|
||||
# Pitch over chars
|
||||
o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
|
||||
|
@ -606,7 +607,7 @@ class FastPitch(BaseTTS):
|
|||
mel_input,
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
torch.log(1 + durations),
|
||||
durations,
|
||||
outputs["pitch"],
|
||||
outputs["pitch_gt"],
|
||||
text_lengths,
|
||||
|
|
|
@ -42,8 +42,8 @@ config = FastPitchConfig(
|
|||
use_phonemes=True,
|
||||
phoneme_language="en-us",
|
||||
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
|
||||
print_step=25,
|
||||
print_eval=True,
|
||||
print_step=50,
|
||||
print_eval=False,
|
||||
mixed_precision=False,
|
||||
output_path=output_path,
|
||||
datasets=[dataset_config]
|
||||
|
|
Loading…
Reference in New Issue