diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index a634aa6e..8e001630 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -11,6 +11,7 @@ from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec @@ -99,9 +100,10 @@ class AlignTTS(BaseTTS): # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit): + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None): super().__init__(config) + self.speaker_manager = speaker_manager self.config = config self.phase = -1 self.length_scale = (