Fix glow tts initialization

This commit is contained in:
Eren Gölge 2021-07-02 10:45:37 +02:00
parent 40b0b5365e
commit 95ad72f38f
1 changed files with 21 additions and 5 deletions

View File

@ -11,6 +11,7 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
@ -50,9 +51,9 @@ class GlowTTS(BaseTTS):
for key in config:
setattr(self, key, config[key])
chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
chars, self.config, self.num_chars = self.get_characters(config)
self.decoder_output_dim = config.out_channels
self.init_multispeaker(config)
# if is a multispeaker and c_in_channels is 0, set to 256
@ -91,9 +92,23 @@ class GlowTTS(BaseTTS):
c_in_channels=self.c_in_channels,
)
if self.num_speakers > 1 and not self.d_vector_dim:
# speaker embedding layer
self.emb_g = nn.Embedding(self.num_speakers, self.c_in_channels)
def init_multispeaker(self, config: "Coqpit", data: list = None) -> None:
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
or with external `d_vectors` computed from a speaker encoder model.
If you need a different behaviour, override this function for your model.
Args:
config (Coqpit): Model configuration.
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
"""
# init speaker manager
self.speaker_manager = get_speaker_manager(config, data=data)
self.num_speakers = self.speaker_manager.num_speakers
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
self.embedded_speaker_dim = self.c_in_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
@staticmethod
@ -260,6 +275,7 @@ class GlowTTS(BaseTTS):
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
if g is not None:
if self.d_vector_dim:
g = F.normalize(g).unsqueeze(-1)