From 5d59100a883def5563397a78bdc60a3938cb8f72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 22 Jul 2021 14:21:49 +0200 Subject: [PATCH] Don't use align_score for models with duration predictor --- TTS/tts/models/align_tts.py | 4 ---- TTS/tts/models/glow_tts.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 2aa84cb2..2c3bed3d 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -13,7 +13,6 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec @@ -355,9 +354,6 @@ class AlignTTS(BaseTTS): phase=self.phase, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error return outputs, loss_dict def train_log( diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 92c42fa7..e6541871 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -10,7 +10,6 @@ from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -341,9 +340,6 @@ class GlowTTS(BaseTTS): text_lengths, ) - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) - loss_dict["align_error"] = align_error return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use