mirror of https://github.com/coqui-ai/TTS.git
Implement init_speaker_encoder and change arg names
This commit is contained in:
parent
1ddf245b08
commit
95ca2ef773
|
@ -1,5 +1,5 @@
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pysbd
|
import pysbd
|
||||||
|
@ -117,6 +117,7 @@ class Synthesizer(object):
|
||||||
|
|
||||||
speaker_manager = self._init_speaker_manager()
|
speaker_manager = self._init_speaker_manager()
|
||||||
language_manager = self._init_language_manager()
|
language_manager = self._init_language_manager()
|
||||||
|
speaker_manager = self._init_speaker_encoder(speaker_manager)
|
||||||
|
|
||||||
if language_manager is not None:
|
if language_manager is not None:
|
||||||
self.tts_model = setup_tts_model(
|
self.tts_model = setup_tts_model(
|
||||||
|
@ -130,23 +131,47 @@ class Synthesizer(object):
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.tts_model.cuda()
|
self.tts_model.cuda()
|
||||||
|
|
||||||
|
def _is_use_speaker_embedding(self):
|
||||||
|
"""Check if the speaker embedding is used in the model"""
|
||||||
|
# some models use model_args some don't
|
||||||
|
if hasattr(self.tts_config, "model_args"):
|
||||||
|
config = self.tts_config.model_args
|
||||||
|
else:
|
||||||
|
config = self.tts_config
|
||||||
|
return hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding is True
|
||||||
|
|
||||||
|
def _is_use_d_vector_file(self):
|
||||||
|
"""Check if the d-vector file is used in the model"""
|
||||||
|
# some models use model_args some don't
|
||||||
|
if hasattr(self.tts_config, "model_args"):
|
||||||
|
config = self.tts_config.model_args
|
||||||
|
else:
|
||||||
|
config = self.tts_config
|
||||||
|
return hasattr(config, "use_d_vector_file") and config.use_d_vector_file is True
|
||||||
|
|
||||||
def _init_speaker_manager(self):
|
def _init_speaker_manager(self):
|
||||||
"""Initialize the SpeakerManager"""
|
"""Initialize the SpeakerManager"""
|
||||||
# setup if multi-speaker settings are in the global model config
|
# setup if multi-speaker settings are in the global model config
|
||||||
speaker_manager = None
|
speaker_manager = None
|
||||||
if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
|
if self._is_use_speaker_embedding():
|
||||||
if self.tts_speakers_file:
|
if self.tts_speakers_file:
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file)
|
||||||
if self.tts_config.get("speakers_file", None):
|
if self.tts_config.get("speakers_file", None):
|
||||||
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
|
speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file)
|
||||||
|
|
||||||
if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True:
|
if self._is_use_d_vector_file():
|
||||||
if self.tts_speakers_file:
|
if self.tts_speakers_file:
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file)
|
||||||
if self.tts_config.get("d_vector_file", None):
|
if self.tts_config.get("d_vector_file", None):
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
|
def _init_speaker_encoder(self, speaker_manager):
|
||||||
|
"""Initialize the SpeakerEncoder"""
|
||||||
|
if self.encoder_checkpoint is not None:
|
||||||
|
speaker_manager.init_speaker_encoder(self.encoder_checkpoint, self.encoder_config)
|
||||||
|
return speaker_manager
|
||||||
|
|
||||||
def _init_language_manager(self):
|
def _init_language_manager(self):
|
||||||
"""Initialize the LanguageManager"""
|
"""Initialize the LanguageManager"""
|
||||||
# setup if multi-lingual settings are in the global model config
|
# setup if multi-lingual settings are in the global model config
|
||||||
|
@ -203,7 +228,12 @@ class Synthesizer(object):
|
||||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||||
|
|
||||||
def tts(
|
def tts(
|
||||||
self, text: str, speaker_name: str = "", language_name: str = "", speaker_wav=None, style_wav=None
|
self,
|
||||||
|
text: str,
|
||||||
|
speaker_name: str = "",
|
||||||
|
language_name: str = "",
|
||||||
|
speaker_wav: Union[str, List[str]] = None,
|
||||||
|
style_wav=None,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""🐸 TTS magic. Run all the models and generate speech.
|
"""🐸 TTS magic. Run all the models and generate speech.
|
||||||
|
|
||||||
|
@ -211,7 +241,7 @@ class Synthesizer(object):
|
||||||
text (str): input text.
|
text (str): input text.
|
||||||
speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
|
speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
|
||||||
language_name (str, optional): language id for multi-language models. Defaults to "".
|
language_name (str, optional): language id for multi-language models. Defaults to "".
|
||||||
speaker_wav ():
|
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
|
||||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
Loading…
Reference in New Issue