Implement init_speaker_encoder and change arg names

This commit is contained in:
Eren Gölge 2021-12-16 14:57:24 +00:00
parent 1ddf245b08
commit 95ca2ef773
1 changed files with 35 additions and 5 deletions

View File

@ -1,5 +1,5 @@
import time import time
from typing import List from typing import List, Union
import numpy as np import numpy as np
import pysbd import pysbd
@ -117,6 +117,7 @@ class Synthesizer(object):
speaker_manager = self._init_speaker_manager() speaker_manager = self._init_speaker_manager()
language_manager = self._init_language_manager() language_manager = self._init_language_manager()
speaker_manager = self._init_speaker_encoder(speaker_manager)
if language_manager is not None: if language_manager is not None:
self.tts_model = setup_tts_model( self.tts_model = setup_tts_model(
@ -130,23 +131,47 @@ class Synthesizer(object):
if use_cuda: if use_cuda:
self.tts_model.cuda() self.tts_model.cuda()
def _is_use_speaker_embedding(self):
"""Check if the speaker embedding is used in the model"""
# some models use model_args some don't
if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args
else:
config = self.tts_config
return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True
def _is_use_d_vector_file(self):
"""Check if the d-vector file is used in the model"""
# some models use model_args some don't
if hasattr(self.tts_config, "model_args"):
config = self.tts_config.model_args
else:
config = self.tts_config
return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True
def _init_speaker_manager(self): def _init_speaker_manager(self):
"""Initialize the SpeakerManager""" """Initialize the SpeakerManager"""
# setup if multi-speaker settings are in the global model config # setup if multi-speaker settings are in the global model config
speaker_manager = None speaker_manager = None
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True: if self._is_use_speaker_embedding():
if self.tts_speakers_file: if self.tts_speakers_file:
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
if self.tts_config.get("speakers_file", None): if self.tts_config.get("speakers_file", None):
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file) speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True: if self._is_use_d_vector_file():
if self.tts_speakers_file: if self.tts_speakers_file:
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
if self.tts_config.get("d_vector_file", None): if self.tts_config.get("d_vector_file", None):
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
return speaker_manager return speaker_manager
def _init_speaker_encoder(self, speaker_manager):
"""Initialize the SpeakerEncoder"""
if self.encoder_checkpoint is not None:
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
return speaker_manager
def _init_language_manager(self): def _init_language_manager(self):
"""Initialize the LanguageManager""" """Initialize the LanguageManager"""
# setup if multi-lingual settings are in the global model config # setup if multi-lingual settings are in the global model config
@ -203,7 +228,12 @@ class Synthesizer(object):
self.ap.save_wav(wav, path, self.output_sample_rate) self.ap.save_wav(wav, path, self.output_sample_rate)
def tts( def tts(
self, text: str, speaker_name: str = "", language_name: str = "", speaker_wav=None, style_wav=None self,
text: str,
speaker_name: str = "",
language_name: str = "",
speaker_wav: Union[str, List[str]] = None,
style_wav=None,
) -> List[int]: ) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech. """🐸 TTS magic. Run all the models and generate speech.
@ -211,7 +241,7 @@ class Synthesizer(object):
text (str): input text. text (str): input text.
speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "". speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
language_name (str, optional): language id for multi-language models. Defaults to "". language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav (): speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None. style_wav ([type], optional): style waveform for GST. Defaults to None.
Returns: Returns: