mirror of https://github.com/coqui-ai/TTS.git
Fix glow tts initialization
This commit is contained in:
parent
40b0b5365e
commit
95ad72f38f
|
@ -11,6 +11,7 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
|||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -50,9 +51,9 @@ class GlowTTS(BaseTTS):
|
|||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
chars, self.config = self.get_characters(config)
|
||||
self.num_chars = len(chars)
|
||||
chars, self.config, self.num_chars = self.get_characters(config)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# if the model is multi-speaker and c_in_channels is 0, set it to 256
|
||||
|
@ -91,9 +92,23 @@ class GlowTTS(BaseTTS):
|
|||
c_in_channels=self.c_in_channels,
|
||||
)
|
||||
|
||||
if self.num_speakers > 1 and not self.d_vector_dim:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.c_in_channels)
|
||||
def init_multispeaker(self, config: "Coqpit", data: list = None) -> None:
    """Create the speaker manager and, when configured, a learned speaker embedding table.

    The speaker manager is always built and the speaker count is read from it.
    The embedding layer ``emb_g`` is only created when ``config.use_speaker_embedding``
    is set and ``config.use_d_vector_file`` is not; in every other case this method
    adds no embedding layer.

    Override this method in a subclass if a model needs different behaviour.

    Args:
        config (Coqpit): Model configuration.
        data (List, optional): Dataset items forwarded to ``get_speaker_manager``. Defaults to None.
    """
    manager = get_speaker_manager(config, data=data)
    self.speaker_manager = manager
    self.num_speakers = manager.num_speakers

    # Nothing more to do unless a trainable embedding table was requested
    # and no d-vector file is in use.
    needs_embedding_table = config.use_speaker_embedding and not config.use_d_vector_file
    if not needs_embedding_table:
        return

    # The embedding width mirrors the conditioning channel count.
    # NOTE(review): this reads self.c_in_channels as-is; confirm it is already
    # non-zero at the point this method is called.
    self.embedded_speaker_dim = self.c_in_channels
    self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
    # Small symmetric uniform init for the speaker table weights.
    nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
@staticmethod
|
||||
|
@ -260,6 +275,7 @@ class GlowTTS(BaseTTS):
|
|||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
|
||||
x_lengths = aux_input["x_lengths"]
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
|
||||
if g is not None:
|
||||
if self.d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
|
Loading…
Reference in New Issue