mirror of https://github.com/coqui-ai/TTS.git
Implement `_set_speaker_embedding` in GlowTTS
This commit is contained in:
parent
3da79a4de4
commit
0ebc2a400e
|
@ -1,17 +1,18 @@
|
|||
import math
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.configs import GlowTTSConfig
|
||||
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||
from TTS.tts.layers.glow_tts.decoder import Decoder
|
||||
from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
@ -38,17 +39,19 @@ class GlowTTS(BaseTTS):
|
|||
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
|
||||
|
||||
Examples:
|
||||
>>> from TTS.tts.configs import GlowTTSConfig
|
||||
>>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
|
||||
>>> from TTS.tts.models.glow_tts import GlowTTS
|
||||
>>> config = GlowTTSConfig()
|
||||
>>> model = GlowTTS(config)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: GlowTTSConfig):
|
||||
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
self.config = config
|
||||
|
@ -58,19 +61,10 @@ class GlowTTS(BaseTTS):
|
|||
_, self.config, self.num_chars = self.get_characters(config)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# init multi-speaker layers if necessary
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# if is a multispeaker and c_in_channels is 0, set to 256
|
||||
self.c_in_channels = 0
|
||||
if self.num_speakers > 1:
|
||||
if self.d_vector_dim:
|
||||
self.c_in_channels = self.d_vector_dim
|
||||
elif self.c_in_channels == 0 and not self.d_vector_dim:
|
||||
# TODO: make this adjustable
|
||||
self.c_in_channels = 256
|
||||
|
||||
self.run_data_dep_init = config.data_dep_init_steps > 0
|
||||
|
||||
self.encoder = Encoder(
|
||||
self.num_chars,
|
||||
out_channels=self.out_channels,
|
||||
|
@ -98,28 +92,35 @@ class GlowTTS(BaseTTS):
|
|||
c_in_channels=self.c_in_channels,
|
||||
)
|
||||
|
||||
def init_multispeaker(self, config: "Coqpit", data: list = None) -> None:
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
def init_multispeaker(self, config: Coqpit):
|
||||
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
|
||||
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
self.embedded_speaker_dim = 0
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
if config.use_d_vector_file:
|
||||
self.external_d_vector_dim = config.d_vector_dim
|
||||
else:
|
||||
self.external_d_vector_dim = 0
|
||||
if self.speaker_manager is None and (self.use_speaker_embedding or self.use_d_vector_file):
|
||||
raise ValueError(
|
||||
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
||||
)
|
||||
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
|
||||
if self.speaker_manager is not None:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = self.c_in_channels
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
print(" > Init speaker_embedding layer.")
|
||||
self.embedded_speaker_dim = self.hidden_channels_enc
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
# set conditioning dimensions
|
||||
self.c_in_channels = self.embedded_speaker_dim
|
||||
|
||||
@staticmethod
|
||||
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
||||
|
@ -146,6 +147,35 @@ class GlowTTS(BaseTTS):
|
|||
if getattr(f, "set_ddi", False):
|
||||
f.set_ddi(False)
|
||||
|
||||
def _set_speaker_input(self, aux_input: Dict):
|
||||
if aux_input is None:
|
||||
d_vectors = None
|
||||
speaker_ids = None
|
||||
else:
|
||||
d_vectors = aux_input.get("d_vectors", None)
|
||||
speaker_ids = aux_input.get("speaker_ids", None)
|
||||
|
||||
if d_vectors is not None and speaker_ids is not None:
|
||||
raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
|
||||
|
||||
if speaker_ids is not None and not hasattr(self, "emb_g"):
|
||||
raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
|
||||
|
||||
g = speaker_ids if speaker_ids is not None else d_vectors
|
||||
return g
|
||||
|
||||
def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]:
|
||||
g = self._set_speaker_input(aux_input)
|
||||
# speaker embedding
|
||||
if g is not None:
|
||||
if hasattr(self, "emb_g"):
|
||||
# use speaker embedding layer
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
# use d-vector
|
||||
g = F.normalize(g).unsqueeze(-1) # [b, h, 1]
|
||||
return g
|
||||
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
|
@ -161,12 +191,7 @@ class GlowTTS(BaseTTS):
|
|||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
if not self.use_d_vector_file:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
g = self._speaker_embedding(aux_input)
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||
|
@ -217,12 +242,7 @@ class GlowTTS(BaseTTS):
|
|||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if self.use_speaker_embedding or self.use_d_vector_file:
|
||||
if not self.use_d_vector_file:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
g = self._speaker_embedding(aux_input)
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||
|
@ -272,22 +292,12 @@ class GlowTTS(BaseTTS):
|
|||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
# norm speaker embeddings
|
||||
if g is not None:
|
||||
if self.external_d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
g = self._speaker_embedding(aux_input)
|
||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
|
||||
|
||||
# decoder pass
|
||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||
|
||||
# reverse decoder and predict
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
|
||||
outputs = {}
|
||||
outputs["model_outputs"] = y.transpose(1, 2)
|
||||
outputs["logdet"] = logdet
|
||||
|
@ -298,14 +308,7 @@ class GlowTTS(BaseTTS):
|
|||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
x_lengths = aux_input["x_lengths"]
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
|
||||
if g is not None:
|
||||
if self.d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||
|
||||
g = self._speaker_embedding(aux_input)
|
||||
# embedding pass
|
||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
||||
# compute output durations
|
||||
|
@ -389,15 +392,15 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
alignments = outputs["alignments"]
|
||||
text_input = batch["text_input"]
|
||||
text_input = batch["text_input"][:1] if batch["text_input"] is not None else None
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None
|
||||
speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None
|
||||
|
||||
# model runs reverse flow to predict spectrograms
|
||||
pred_outputs = self.inference(
|
||||
text_input[:1],
|
||||
text_input,
|
||||
aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids},
|
||||
)
|
||||
model_outputs = pred_outputs["model_outputs"]
|
||||
|
@ -448,7 +451,7 @@ class GlowTTS(BaseTTS):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self.get_aux_input()
|
||||
aux_inputs = self._get_test_aux_input()
|
||||
if len(test_sentences) == 0:
|
||||
print(" | [!] No test sentences provided.")
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue