From 0ebc2a400eb4f44a04a6bbcadab649afa08eaae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 20 Oct 2021 18:15:20 +0000 Subject: [PATCH] Implement `_set_speaker_embedding` in GlowTTS --- TTS/tts/models/glow_tts.py | 133 +++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 65 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index e3a5ff3c..c1e4c2ac 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -1,17 +1,18 @@ import math -from typing import Dict, Tuple +from typing import Dict, Tuple, Union import torch +from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F -from TTS.tts.configs import GlowTTSConfig +from TTS.tts.configs.glow_tts_config import GlowTTSConfig from TTS.tts.layers.glow_tts.decoder import Decoder from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask -from TTS.tts.utils.speakers import get_speaker_manager +from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.io import load_fsspec @@ -38,17 +39,19 @@ class GlowTTS(BaseTTS): Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. Examples: - >>> from TTS.tts.configs import GlowTTSConfig + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig >>> from TTS.tts.models.glow_tts import GlowTTS >>> config = GlowTTSConfig() >>> model = GlowTTS(config) """ - def __init__(self, config: GlowTTSConfig): + def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None): super().__init__(config) + self.speaker_manager = speaker_manager + # pass all config fields to `self` # for fewer code change self.config = config @@ -58,19 +61,10 @@ class GlowTTS(BaseTTS): _, self.config, self.num_chars = self.get_characters(config) self.decoder_output_dim = config.out_channels + # init multi-speaker layers if necessary self.init_multispeaker(config) - # if is a multispeaker and c_in_channels is 0, set to 256 - self.c_in_channels = 0 - if self.num_speakers > 1: - if self.d_vector_dim: - self.c_in_channels = self.d_vector_dim - elif self.c_in_channels == 0 and not self.d_vector_dim: - # TODO: make this adjustable - self.c_in_channels = 256 - self.run_data_dep_init = config.data_dep_init_steps > 0 - self.encoder = Encoder( self.num_chars, out_channels=self.out_channels, @@ -98,28 +92,35 @@ class GlowTTS(BaseTTS): c_in_channels=self.c_in_channels, ) - def init_multispeaker(self, config: "Coqpit", data: list = None) -> None: - """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer - or with external `d_vectors` computed from a speaker encoder model. - - If you need a different behaviour, override this function for your model. + def init_multispeaker(self, config: Coqpit): + """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding + vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension. Args: config (Coqpit): Model configuration. - data (List, optional): Dataset items to infer number of speakers. Defaults to None. """ + self.embedded_speaker_dim = 0 # init speaker manager - self.speaker_manager = get_speaker_manager(config, data=data) - self.num_speakers = self.speaker_manager.num_speakers - if config.use_d_vector_file: - self.external_d_vector_dim = config.d_vector_dim - else: - self.external_d_vector_dim = 0 + if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file): + raise ValueError( + " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." + ) + # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) # init speaker embedding layer if config.use_speaker_embedding and not config.use_d_vector_file: - self.embedded_speaker_dim = self.c_in_channels - self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + print(" > Init speaker_embedding layer.") + self.embedded_speaker_dim = self.hidden_channels_enc + self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + # set conditioning dimensions + self.c_in_channels = self.embedded_speaker_dim @staticmethod def compute_outputs(attn, o_mean, o_log_scale, x_mask): @@ -146,6 +147,35 @@ class GlowTTS(BaseTTS): if getattr(f, "set_ddi", False): f.set_ddi(False) + def _set_speaker_input(self, aux_input: Dict): + if aux_input is None: + d_vectors = None + speaker_ids = None + else: + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]: + g = self._set_speaker_input(aux_input) + # speaker embedding + if g is not None: + if hasattr(self, "emb_g"): + # use speaker embedding layer + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + else: + # use d-vector + g = F.normalize(g).unsqueeze(-1) # [b, h, 1] + return g + def forward( self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value @@ -161,12 +191,7 @@ class GlowTTS(BaseTTS): y = y.transpose(1, 2) y_max_length = y.size(2) # norm speaker embeddings - g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - if self.use_speaker_embedding or self.use_d_vector_file: - if not self.use_d_vector_file: - g = F.normalize(g).unsqueeze(-1) - else: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + g = self._speaker_embedding(aux_input) # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop redisual frames wrt num_squeeze and set y_lengths. @@ -217,12 +242,7 @@ class GlowTTS(BaseTTS): y = y.transpose(1, 2) y_max_length = y.size(2) # norm speaker embeddings - g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - if self.use_speaker_embedding or self.use_d_vector_file: - if not self.use_d_vector_file: - g = F.normalize(g).unsqueeze(-1) - else: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + g = self._speaker_embedding(aux_input) # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop redisual frames wrt num_squeeze and set y_lengths. @@ -272,22 +292,12 @@ class GlowTTS(BaseTTS): """ y = y.transpose(1, 2) y_max_length = y.size(2) - g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - # norm speaker embeddings - if g is not None: - if self.external_d_vector_dim: - g = F.normalize(g).unsqueeze(-1) - else: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] - + g = self._speaker_embedding(aux_input) y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype) - # decoder pass z, logdet = self.decoder(y, y_mask, g=g, reverse=False) - # reverse decoder and predict y, logdet = self.decoder(z, y_mask, g=g, reverse=True) - outputs = {} outputs["model_outputs"] = y.transpose(1, 2) outputs["logdet"] = logdet @@ -298,14 +308,7 @@ class GlowTTS(BaseTTS): self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value x_lengths = aux_input["x_lengths"] - g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None - - if g is not None: - if self.d_vector_dim: - g = F.normalize(g).unsqueeze(-1) - else: - g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] - + g = self._speaker_embedding(aux_input) # embedding pass o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # compute output durations @@ -389,15 +392,15 @@ class GlowTTS(BaseTTS): def _create_logs(self, batch, outputs, ap): alignments = outputs["alignments"] - text_input = batch["text_input"] + text_input = batch["text_input"][:1] if batch["text_input"] is not None else None text_lengths = batch["text_lengths"] mel_input = batch["mel_input"] - d_vectors = batch["d_vectors"] - speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None + speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None # model runs reverse flow to predict spectrograms pred_outputs = self.inference( - text_input[:1], + text_input, aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, ) model_outputs = pred_outputs["model_outputs"] @@ -448,7 +451,7 @@ class GlowTTS(BaseTTS): test_audios = {} test_figures = {} test_sentences = self.config.test_sentences - aux_inputs = self.get_aux_input() + aux_inputs = self._get_test_aux_input() if len(test_sentences) == 0: print(" | [!] No test sentences provided.") else: