mirror of https://github.com/coqui-ai/TTS.git
Implement init_speaker_encoder and change arg names
This commit is contained in:
parent
1ddf245b08
commit
95ca2ef773
|
@ -1,5 +1,5 @@
|
|||
import time
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pysbd
|
||||
|
@ -117,6 +117,7 @@ class Synthesizer(object):
|
|||
|
||||
speaker_manager = self._init_speaker_manager()
|
||||
language_manager = self._init_language_manager()
|
||||
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
||||
|
||||
if language_manager is not None:
|
||||
self.tts_model = setup_tts_model(
|
||||
|
@ -130,23 +131,47 @@ class Synthesizer(object):
|
|||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
||||
def _is_use_speaker_embedding(self):
|
||||
"""Check if the speaker embedding is used in the model"""
|
||||
# some models use model_args some don't
|
||||
if hasattr(self.tts_config, "model_args"):
|
||||
config = self.tts_config.model_args
|
||||
else:
|
||||
config = self.tts_config
|
||||
return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True
|
||||
|
||||
def _is_use_d_vector_file(self):
|
||||
"""Check if the d-vector file is used in the model"""
|
||||
# some models use model_args some don't
|
||||
if hasattr(self.tts_config, "model_args"):
|
||||
config = self.tts_config.model_args
|
||||
else:
|
||||
config = self.tts_config
|
||||
return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True
|
||||
|
||||
def _init_speaker_manager(self):
|
||||
"""Initialize the SpeakerManager"""
|
||||
# setup if multi-speaker settings are in the global model config
|
||||
speaker_manager = None
|
||||
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
|
||||
if self._is_use_speaker_embedding():
|
||||
if self.tts_speakers_file:
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
||||
if self.tts_config.get("speakers_file", None):
|
||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
|
||||
|
||||
if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
|
||||
if self._is_use_d_vector_file():
|
||||
if self.tts_speakers_file:
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
||||
if self.tts_config.get("d_vector_file", None):
|
||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
||||
return speaker_manager
|
||||
|
||||
def _init_speaker_encoder(self, speaker_manager):
|
||||
"""Initialize the SpeakerEncoder"""
|
||||
if self.encoder_checkpoint is not None:
|
||||
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
|
||||
return speaker_manager
|
||||
|
||||
def _init_language_manager(self):
|
||||
"""Initialize the LanguageManager"""
|
||||
# setup if multi-lingual settings are in the global model config
|
||||
|
@ -203,7 +228,12 @@ class Synthesizer(object):
|
|||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||
|
||||
def tts(
|
||||
self, text: str, speaker_name: str = "", language_name: str = "", speaker_wav=None, style_wav=None
|
||||
self,
|
||||
text: str,
|
||||
speaker_name: str = "",
|
||||
language_name: str = "",
|
||||
speaker_wav: Union[str, List[str]] = None,
|
||||
style_wav=None,
|
||||
) -> List[int]:
|
||||
"""🐸 TTS magic. Run all the models and generate speech.
|
||||
|
||||
|
@ -211,7 +241,7 @@ class Synthesizer(object):
|
|||
text (str): input text.
|
||||
speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
|
||||
language_name (str, optional): language id for multi-language models. Defaults to "".
|
||||
speaker_wav ():
|
||||
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
|
||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||
|
||||
Returns:
|
||||
|
|
Loading…
Reference in New Issue