From 95ad72f38feb7d839c5bfe51409a713a78565edf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 2 Jul 2021 10:45:37 +0200 Subject: [PATCH] Fix glow tts initialization --- TTS/tts/models/glow_tts.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 5f966c2c..d7406c73 100755 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -11,6 +11,7 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor @@ -50,9 +51,9 @@ class GlowTTS(BaseTTS): for key in config: setattr(self, key, config[key]) - chars, self.config = self.get_characters(config) - self.num_chars = len(chars) + chars, self.config, self.num_chars = self.get_characters(config) self.decoder_output_dim = config.out_channels + self.init_multispeaker(config) # if is a multispeaker and c_in_channels is 0, set to 256 @@ -91,9 +92,23 @@ class GlowTTS(BaseTTS): c_in_channels=self.c_in_channels, ) - if self.num_speakers > 1 and not self.d_vector_dim: - # speaker embedding layer - self.emb_g = nn.Embedding(self.num_speakers, self.c_in_channels) + def init_multispeaker(self, config: "Coqpit", data: list = None) -> None: + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + If you need a different behaviour, override this function for your model. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + # init speaker manager + self.speaker_manager = get_speaker_manager(config, data=data) + self.num_speakers = self.speaker_manager.num_speakers + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + self.embedded_speaker_dim = self.c_in_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @staticmethod @@ -260,6 +275,7 @@ class GlowTTS(BaseTTS): def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value x_lengths = aux_input["x_lengths"] g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None + if g is not None: if self.d_vector_dim: g = F.normalize(g).unsqueeze(-1)