From d847a68e42536805f3f301555201c001fcf8f055 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 6 Sep 2021 14:27:13 +0000
Subject: [PATCH] Reformat multi-speaker handling in GlowTTS

---
 TTS/tts/models/glow_tts.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 27012207..b063b6b4 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -109,6 +109,10 @@ class GlowTTS(BaseTTS):
         # init speaker manager
         self.speaker_manager = get_speaker_manager(config, data=data)
         self.num_speakers = self.speaker_manager.num_speakers
+        if config.use_d_vector_file:
+            self.external_d_vector_dim = config.d_vector_dim
+        else:
+            self.external_d_vector_dim = 0
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
             self.embedded_speaker_dim = self.c_in_channels
@@ -129,7 +133,7 @@ class GlowTTS(BaseTTS):
         return y_mean, y_log_scale, o_attn_dur
 
     def forward(
-        self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
+        self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
     ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
@@ -143,8 +147,8 @@
         y_max_length = y.size(2)
         # norm speaker embeddings
         g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
-        if g is not None:
-            if self.d_vector_dim:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if self.use_d_vector_file:
                 g = F.normalize(g).unsqueeze(-1)
             else:
                 g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
@@ -181,7 +185,7 @@
 
     @torch.no_grad()
     def inference_with_MAS(
-        self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
+        self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
     ):  # pylint: disable=dangerous-default-value
         """
         It's similar to the teacher forcing in Tacotron.
@@ -198,12 +202,11 @@
         y_max_length = y.size(2)
         # norm speaker embeddings
         g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
-        if g is not None:
-            if self.external_d_vector_dim:
+        if self.use_speaker_embedding or self.use_d_vector_file:
+            if self.use_d_vector_file:
                 g = F.normalize(g).unsqueeze(-1)
             else:
                 g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
-
         # embedding pass
         o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
         # drop redisual frames wrt num_squeeze and set y_lengths.
@@ -243,7 +246,7 @@ class GlowTTS(BaseTTS):
 
     @torch.no_grad()
     def decoder_inference(
-        self, y, y_lengths=None, aux_input={"d_vectors": None}
+        self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
     ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
@@ -275,7 +278,7 @@
         return outputs
 
     @torch.no_grad()
-    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}):  # pylint: disable=dangerous-default-value
+    def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}):  # pylint: disable=dangerous-default-value
         x_lengths = aux_input["x_lengths"]
         g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
 
@@ -326,8 +329,9 @@
         mel_input = batch["mel_input"]
         mel_lengths = batch["mel_lengths"]
         d_vectors = batch["d_vectors"]
+        speaker_ids = batch["speaker_ids"]
 
-        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
+        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids})
         loss_dict = criterion(
             outputs["model_outputs"],
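
The conditioning logic that the rewritten branches in forward() and inference_with_MAS() aim at can be summarized outside the class. The sketch below is illustrative only: it is not part of the commit, resolve_speaker_conditioning is a hypothetical helper rather than a GlowTTS API, and the sizes are made up. When an external d-vector file is used, the provided d-vector is L2-normalized and given a trailing time axis; otherwise an integer speaker ID is looked up in the learned embedding table (emb_g in the model) before the same normalization.

import torch
import torch.nn.functional as F
from torch import nn


def resolve_speaker_conditioning(emb_g, use_d_vector_file, d_vectors=None, speaker_ids=None):
    """Return the [B, C, 1] conditioning tensor g, or None if no speaker info was given."""
    if use_d_vector_file and d_vectors is not None:
        # external d-vectors: L2-normalize and add a length-1 time axis
        return F.normalize(d_vectors).unsqueeze(-1)  # [B, C, 1]
    if speaker_ids is not None:
        # learned speaker embedding: look up the ID, then normalize the same way
        return F.normalize(emb_g(speaker_ids)).unsqueeze(-1)  # [B, C, 1]
    return None


# usage sketch with made-up sizes
emb_g = nn.Embedding(num_embeddings=10, embedding_dim=192)
g_ids = resolve_speaker_conditioning(emb_g, use_d_vector_file=False, speaker_ids=torch.tensor([3, 7]))
g_dvec = resolve_speaker_conditioning(emb_g, use_d_vector_file=True, d_vectors=torch.randn(2, 192))
print(g_ids.shape, g_dvec.shape)  # both torch.Size([2, 192, 1])

Keeping the two sources explicit like this is what the extra "speaker_ids" key in aux_input and the speaker_ids = batch["speaker_ids"] line in training_step make possible.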