Implement init_speaker_encoder and change arg names

This commit is contained in:
Eren Gölge 2021-12-16 14:57:24 +00:00
parent 1ddf245b08
commit 95ca2ef773
1 changed files with 35 additions and 5 deletions

View File

@ -1,5 +1,5 @@
import time
from typing import List
from typing import List, Union
import numpy as np
import pysbd
@ -117,6 +117,7 @@ class Synthesizer(object):
speaker_manager = self._init_speaker_manager()
language_manager = self._init_language_manager()
speaker_manager = self._init_speaker_encoder(speaker_manager)
if language_manager is not None:
self.tts_model = setup_tts_model(
@ -130,23 +131,47 @@ class Synthesizer(object):
if use_cuda:
self.tts_model.cuda()
def _is_use_speaker_embedding(self):
"""Check if the speaker embedding is used in the model"""
# some models use model_args some don't
if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args
else:
config = self.tts_config
return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True
def _is_use_d_vector_file(self):
"""Check if the d-vector file is used in the model"""
# some models use model_args some don't
if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args
else:
config = self.tts_config
return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True
def _init_speaker_manager(self):
"""Initialize the SpeakerManager"""
# setup if multi-speaker settings are in the global model config
speaker_manager = None
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
if self._is_use_speaker_embedding():
if self.tts_speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
if self.tts_config.get("speakers_file", None):
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
if self._is_use_d_vector_file():
if self.tts_speakers_file:
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
if self.tts_config.get("d_vector_file", None):
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
return speaker_manager
def _init_speaker_encoder(self, speaker_manager):
"""Initialize the SpeakerEncoder"""
if self.encoder_checkpoint is not None:
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
return speaker_manager
def _init_language_manager(self):
"""Initialize the LanguageManager"""
# setup if multi-lingual settings are in the global model config
@ -203,7 +228,12 @@ class Synthesizer(object):
self.ap.save_wav(wav, path, self.output_sample_rate)
def tts(
self, text: str, speaker_name: str = "", language_name: str = "", speaker_wav=None, style_wav=None
self,
text: str,
speaker_name: str = "",
language_name: str = "",
speaker_wav: Union[str, List[str]] = None,
style_wav=None,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@ -211,7 +241,7 @@ class Synthesizer(object):
text (str): input text.
speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav ():
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
Returns: