Implement `_set_speaker_embedding` in GlowTTS

This commit is contained in:
Eren Gölge 2021-10-20 18:15:20 +00:00
parent 3da79a4de4
commit 0ebc2a400e
1 changed files with 68 additions and 65 deletions

View File

@ -1,17 +1,18 @@
import math
from typing import Dict, Tuple
from typing import Dict, Tuple, Union
import torch
from coqpit import Coqpit
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -38,17 +39,19 @@ class GlowTTS(BaseTTS):
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
Examples:
>>> from TTS.tts.configs import GlowTTSConfig
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig()
>>> model = GlowTTS(config)
"""
def __init__(self, config: GlowTTSConfig):
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager
# pass all config fields to `self`
# for fewer code change
self.config = config
@ -58,19 +61,10 @@ class GlowTTS(BaseTTS):
_, self.config, self.num_chars = self.get_characters(config)
self.decoder_output_dim = config.out_channels
# init multi-speaker layers if necessary
self.init_multispeaker(config)
# if is a multispeaker and c_in_channels is 0, set to 256
self.c_in_channels = 0
if self.num_speakers > 1:
if self.d_vector_dim:
self.c_in_channels = self.d_vector_dim
elif self.c_in_channels == 0 and not self.d_vector_dim:
# TODO: make this adjustable
self.c_in_channels = 256
self.run_data_dep_init = config.data_dep_init_steps > 0
self.encoder = Encoder(
self.num_chars,
out_channels=self.out_channels,
@ -98,28 +92,35 @@ class GlowTTS(BaseTTS):
c_in_channels=self.c_in_channels,
)
def init_multispeaker(self, config: "Coqpit", data: list = None) -> None:
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
def init_multispeaker(self, config: Coqpit):
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
self.embedded_speaker_dim = 0
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
self.num_speakers = self.speaker_manager.num_speakers
if config.use_d_vector_file:
self.external_d_vector_dim = config.d_vector_dim
else:
self.external_d_vector_dim = 0
if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
)
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = (
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = self.c_in_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
print(" > Init speaker_embedding layer.")
self.embedded_speaker_dim = self.hidden_channels_enc
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
# set conditioning dimensions
self.c_in_channels = self.embedded_speaker_dim
@staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
@ -146,6 +147,35 @@ class GlowTTS(BaseTTS):
if getattr(f, "set_ddi", False):
f.set_ddi(False)
def _set_speaker_input(self, aux_input: Dict):
if aux_input is None:
d_vectors = None
speaker_ids = None
else:
d_vectors = aux_input.get("d_vectors", None)
speaker_ids = aux_input.get("speaker_ids", None)
if d_vectors is not None and speaker_ids is not None:
raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
if speaker_ids is not None and not hasattr(self, "emb_g"):
raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
g = speaker_ids if speaker_ids is not None else d_vectors
return g
def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]:
g = self._set_speaker_input(aux_input)
# speaker embedding
if g is not None:
if hasattr(self, "emb_g"):
# use speaker embedding layer
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
else:
# use d-vector
g = F.normalize(g).unsqueeze(-1) # [b, h, 1]
return g
def forward(
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
@ -161,12 +191,7 @@ class GlowTTS(BaseTTS):
y = y.transpose(1, 2)
y_max_length = y.size(2)
# norm speaker embeddings
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
g = self._speaker_embedding(aux_input)
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop redisual frames wrt num_squeeze and set y_lengths.
@ -217,12 +242,7 @@ class GlowTTS(BaseTTS):
y = y.transpose(1, 2)
y_max_length = y.size(2)
# norm speaker embeddings
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
if self.use_speaker_embedding or self.use_d_vector_file:
if not self.use_d_vector_file:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
g = self._speaker_embedding(aux_input)
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop redisual frames wrt num_squeeze and set y_lengths.
@ -272,22 +292,12 @@ class GlowTTS(BaseTTS):
"""
y = y.transpose(1, 2)
y_max_length = y.size(2)
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
# norm speaker embeddings
if g is not None:
if self.external_d_vector_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
g = self._speaker_embedding(aux_input)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# reverse decoder and predict
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = {}
outputs["model_outputs"] = y.transpose(1, 2)
outputs["logdet"] = logdet
@ -298,14 +308,7 @@ class GlowTTS(BaseTTS):
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
if g is not None:
if self.d_vector_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
g = self._speaker_embedding(aux_input)
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
@ -389,15 +392,15 @@ class GlowTTS(BaseTTS):
def _create_logs(self, batch, outputs, ap):
alignments = outputs["alignments"]
text_input = batch["text_input"]
text_input = batch["text_input"][:1] if batch["text_input"] is not None else None
text_lengths = batch["text_lengths"]
mel_input = batch["mel_input"]
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None
speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None
# model runs reverse flow to predict spectrograms
pred_outputs = self.inference(
text_input[:1],
text_input,
aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids},
)
model_outputs = pred_outputs["model_outputs"]
@ -448,7 +451,7 @@ class GlowTTS(BaseTTS):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
aux_inputs = self._get_test_aux_input()
if len(test_sentences) == 0:
print(" | [!] No test sentences provided.")
else: