Implement `_set_speaker_embedding` in GlowTTS

This commit is contained in:
Eren Gölge 2021-10-20 18:15:20 +00:00
parent 3da79a4de4
commit 0ebc2a400e
1 changed file with 68 additions and 65 deletions

View File

@ -1,17 +1,18 @@
import math import math
from typing import Dict, Tuple from typing import Dict, Tuple, Union
import torch import torch
from coqpit import Coqpit
from torch import nn from torch import nn
from torch.cuda.amp.autocast_mode import autocast from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F from torch.nn import functional as F
from TTS.tts.configs import GlowTTSConfig from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.glow_tts.decoder import Decoder from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec from TTS.utils.io import load_fsspec
@ -38,17 +39,19 @@ class GlowTTS(BaseTTS):
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
Examples: Examples:
>>> from TTS.tts.configs import GlowTTSConfig >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS >>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig() >>> config = GlowTTSConfig()
>>> model = GlowTTS(config) >>> model = GlowTTS(config)
""" """
def __init__(self, config: GlowTTSConfig): def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
super().__init__(config) super().__init__(config)
self.speaker_manager = speaker_manager
# pass all config fields to `self` # pass all config fields to `self`
# for fewer code change # for fewer code change
self.config = config self.config = config
@ -58,19 +61,10 @@ class GlowTTS(BaseTTS):
_, self.config, self.num_chars = self.get_characters(config) _, self.config, self.num_chars = self.get_characters(config)
self.decoder_output_dim = config.out_channels self.decoder_output_dim = config.out_channels
# init multi-speaker layers if necessary
self.init_multispeaker(config) self.init_multispeaker(config)
# if is a multispeaker and c_in_channels is 0, set to 256
self.c_in_channels = 0
if self.num_speakers > 1:
if self.d_vector_dim:
self.c_in_channels = self.d_vector_dim
elif self.c_in_channels == 0 and not self.d_vector_dim:
# TODO: make this adjustable
self.c_in_channels = 256
self.run_data_dep_init = config.data_dep_init_steps > 0 self.run_data_dep_init = config.data_dep_init_steps > 0
self.encoder = Encoder( self.encoder = Encoder(
self.num_chars, self.num_chars,
out_channels=self.out_channels, out_channels=self.out_channels,
@ -98,28 +92,35 @@ class GlowTTS(BaseTTS):
c_in_channels=self.c_in_channels, c_in_channels=self.c_in_channels,
) )
def init_multispeaker(self, config: "Coqpit", data: list = None) -> None: def init_multispeaker(self, config: Coqpit):
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
or with external `d_vectors` computed from a speaker encoder model. vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
If you need a different behaviour, override this function for your model.
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
""" """
self.embedded_speaker_dim = 0
# init speaker manager # init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data) if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
self.num_speakers = self.speaker_manager.num_speakers raise ValueError(
if config.use_d_vector_file: " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
self.external_d_vector_dim = config.d_vector_dim )
else: # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
self.external_d_vector_dim = 0 if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = (
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
# init speaker embedding layer # init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file: if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = self.c_in_channels print(" > Init speaker_embedding layer.")
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.embedded_speaker_dim = self.hidden_channels_enc
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
# set conditioning dimensions
self.c_in_channels = self.embedded_speaker_dim
@staticmethod @staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask): def compute_outputs(attn, o_mean, o_log_scale, x_mask):
@ -146,6 +147,35 @@ class GlowTTS(BaseTTS):
if getattr(f, "set_ddi", False): if getattr(f, "set_ddi", False):
f.set_ddi(False) f.set_ddi(False)
def _set_speaker_input(self, aux_input: Dict):
if aux_input is None:
d_vectors = None
speaker_ids = None
else:
d_vectors = aux_input.get("d_vectors", None)
speaker_ids = aux_input.get("speaker_ids", None)
if d_vectors is not None and speaker_ids is not None:
raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
if speaker_ids is not None and not hasattr(self, "emb_g"):
raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
g = speaker_ids if speaker_ids is not None else d_vectors
return g
def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]:
g = self._set_speaker_input(aux_input)
# speaker embedding
if g is not None:
if hasattr(self, "emb_g"):
# use speaker embedding layer
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
else:
# use d-vector
g = F.normalize(g).unsqueeze(-1) # [b, h, 1]
return g
def forward( def forward(
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
@ -161,12 +191,7 @@ class GlowTTS(BaseTTS):
y = y.transpose(1, 2) y = y.transpose(1, 2)
y_max_length = y.size(2) y_max_length = y.size(2)
# norm speaker embeddings # norm speaker embeddings
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None g = self._speaker_embedding(aux_input)
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass # embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop residual frames wrt num_squeeze and set y_lengths. # drop residual frames wrt num_squeeze and set y_lengths.
@ -217,12 +242,7 @@ class GlowTTS(BaseTTS):
y = y.transpose(1, 2) y = y.transpose(1, 2)
y_max_length = y.size(2) y_max_length = y.size(2)
# norm speaker embeddings # norm speaker embeddings
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None g = self._speaker_embedding(aux_input)
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass # embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop residual frames wrt num_squeeze and set y_lengths. # drop residual frames wrt num_squeeze and set y_lengths.
@ -272,22 +292,12 @@ class GlowTTS(BaseTTS):
""" """
y = y.transpose(1, 2) y = y.transpose(1, 2)
y_max_length = y.size(2) y_max_length = y.size(2)
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None g = self._speaker_embedding(aux_input)
# norm speaker embeddings
if g is not None:
if self.external_d_vector_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype) y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
# decoder pass # decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False) z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# reverse decoder and predict # reverse decoder and predict
y, logdet = self.decoder(z, y_mask, g=g, reverse=True) y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = {} outputs = {}
outputs["model_outputs"] = y.transpose(1, 2) outputs["model_outputs"] = y.transpose(1, 2)
outputs["logdet"] = logdet outputs["logdet"] = logdet
@ -298,14 +308,7 @@ class GlowTTS(BaseTTS):
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"] x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None g = self._speaker_embedding(aux_input)
if g is not None:
if self.d_vector_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass # embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations # compute output durations
@ -389,15 +392,15 @@ class GlowTTS(BaseTTS):
def _create_logs(self, batch, outputs, ap): def _create_logs(self, batch, outputs, ap):
alignments = outputs["alignments"] alignments = outputs["alignments"]
text_input = batch["text_input"] text_input = batch["text_input"][:1] if batch["text_input"] is not None else None
text_lengths = batch["text_lengths"] text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"] mel_input = batch["mel_input"]
d_vectors = batch["d_vectors"] d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None
speaker_ids = batch["speaker_ids"] speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None
# model runs reverse flow to predict spectrograms # model runs reverse flow to predict spectrograms
pred_outputs = self.inference( pred_outputs = self.inference(
text_input[:1], text_input,
aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids},
) )
model_outputs = pred_outputs["model_outputs"] model_outputs = pred_outputs["model_outputs"]
@ -448,7 +451,7 @@ class GlowTTS(BaseTTS):
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input() aux_inputs = self._get_test_aux_input()
if len(test_sentences) == 0: if len(test_sentences) == 0:
print(" | [!] No test sentences provided.") print(" | [!] No test sentences provided.")
else: else: