diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index d770a536..f4a472ad 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -740,6 +740,7 @@ class ForwardTTSLoss(nn.Module):
         alignment_logprob=None,
         alignment_hard=None,
         alignment_soft=None,
+        binary_loss_weight=None,
     ):
         loss = 0
         return_dict = {}
@@ -772,7 +773,10 @@ class ForwardTTSLoss(nn.Module):
         if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
             binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
-            return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
+            if binary_loss_weight:
+                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight
+            else:
+                return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
 
         return_dict["loss"] = loss
         return return_dict
diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index 699f3142..bb8640a3 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
+from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
 
 
 @dataclass
@@ -186,7 +186,7 @@ class ForwardTTS(BaseTTS):
         self.max_duration = self.args.max_duration
         self.use_aligner = self.args.use_aligner
         self.use_pitch = self.args.use_pitch
-        self.use_binary_alignment_loss = False
+        self.binary_loss_weight = 0.0
 
         self.length_scale = (
             float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
@@ -644,8 +644,9 @@ class ForwardTTS(BaseTTS):
             pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
             input_lens=text_lengths,
             alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
-            alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
-            alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
+            alignment_soft=outputs["alignment_soft"],
+            alignment_hard=outputs["alignment_mas"],
+            binary_loss_weight=self.binary_loss_weight,
         )
         # compute duration error
         durations_pred = outputs["durations"]
@@ -672,17 +673,12 @@ class ForwardTTS(BaseTTS):
 
         # plot pitch figures
         if self.args.use_pitch:
-            pitch = batch["pitch"]
-            pitch_avg_expanded, _ = self.expand_encoder_outputs(
-                outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
-            )
-            pitch = pitch[0, 0].data.cpu().numpy()
-            # TODO: denormalize before plotting
-            pitch = abs(pitch)
-            pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
+            pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
+            pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
+            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
             pitch_figures = {
-                "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
-                "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
+                "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
+                "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
             }
             figures.update(pitch_figures)
 
@@ -725,9 +721,8 @@ class ForwardTTS(BaseTTS):
         return ForwardTTSLoss(self.config)
 
     def on_train_step_start(self, trainer):
-        """Enable binary alignment loss when needed"""
-        if trainer.total_steps_done > self.config.binary_align_loss_start_step:
-            self.use_binary_alignment_loss = True
+        """Schedule binary loss weight."""
+        self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
 
     @staticmethod
     def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
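
Note (not part of the diff): the step-based switch (`use_binary_alignment_loss` toggled once `binary_align_loss_start_step` is passed) is replaced by an epoch-based linear warmup, where `binary_loss_weight` ramps from 0.0 to 1.0 over `config.binary_loss_warmup_epochs` and then saturates. A minimal sketch of that schedule, assuming `binary_loss_warmup_epochs` is a positive integer taken from the model config (the helper name below is illustrative, not part of the codebase):

```python
def binary_loss_weight_at(epochs_done: int, binary_loss_warmup_epochs: int) -> float:
    """Linearly ramp the binary alignment loss weight from 0.0 to 1.0, then hold it at 1.0."""
    return min(epochs_done / binary_loss_warmup_epochs, 1.0)


if __name__ == "__main__":
    # Example: with a 50-epoch warmup the weight is 0.0 at epoch 0,
    # 0.5 at epoch 25, and 1.0 from epoch 50 onwards.
    for epoch in (0, 25, 50, 100):
        print(epoch, binary_loss_weight_at(epoch, binary_loss_warmup_epochs=50))
```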