mirror of https://github.com/coqui-ai/TTS.git

Update `synthesizer.py`

Fixes and changes for multi-speaker model init and custom symbols made by make_symbols()

commit dd55960732 (parent 232a5abb6a)
@@ -12,7 +12,6 @@ from TTS.tts.utils.speakers import SpeakerManager
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import synthesis, trim_silence
-from TTS.tts.utils.text import make_symbols, phonemes, symbols
 from TTS.utils.audio import AudioProcessor
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
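
For context on the dropped import: `make_symbols()` builds the symbol and phoneme inventories from a model's `characters` config, and the synthesizer previously called it directly (see the third hunk below). A minimal sketch of that usage, assuming illustrative `characters` fields — the exact keys and values come from a model's config, not from this commit:

    from TTS.tts.utils.text import make_symbols

    # Illustrative characters config; real models define these fields under
    # "characters" in their JSON config. All values here are placeholders.
    characters_config = {
        "pad": "_",
        "eos": "~",
        "bos": "^",
        "characters": "abcdefghijklmnopqrstuvwxyz ",
        "punctuations": "!'(),-.:;? ",
        "phonemes": "iyueoa",  # heavily truncated for brevity
    }

    # make_symbols() returns the two inventories; their lengths set the
    # embedding-table input size that the old _load_tts computed by hand.
    symbols, phonemes = make_symbols(**characters_config)
    print(len(symbols), len(phonemes))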
@@ -103,6 +102,34 @@ class Synthesizer(object):
             self.num_speakers = self.speaker_manager.num_speakers
             self.d_vector_dim = self.speaker_manager.d_vector_dim
 
+    def _set_tts_speaker_file(self):
+        """Set the TTS speaker file used by a multi-speaker model."""
+        # setup if multi-speaker settings are in the global model config
+        if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True:
+            if self.tts_config.use_d_vector_file:
+                self.tts_speakers_file = (
+                    self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
+                )
+                self.tts_config["d_vector_file"] = self.tts_speakers_file
+            else:
+                self.tts_speakers_file = (
+                    self.tts_speakers_file if self.tts_speakers_file else self.tts_config["speakers_file"]
+                )
+
+        # setup if multi-speaker settings are in the model args config
+        if (
+            self.tts_speakers_file is None
+            and hasattr(self.tts_config, "model_args")
+            and hasattr(self.tts_config.model_args, "use_speaker_embedding")
+            and self.tts_config.model_args.use_speaker_embedding
+        ):
+            _args = self.tts_config.model_args
+            if _args.use_d_vector_file:
+                self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["d_vector_file"]
+                _args["d_vector_file"] = self.tts_speakers_file
+            else:
+                self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["speakers_file"]
+
     def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         """Load the TTS model.
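
The new method encodes a precedence: an explicitly supplied speakers file wins, then the global config's `d_vector_file`/`speakers_file`, then the same fields under `model_args`. A minimal sketch distilling that precedence into a free function — `AttrDict` and `resolve_speakers_file` are hypothetical names for illustration only; the commit itself keeps this logic as a method mutating `self` and the config:

    class AttrDict(dict):
        """Tiny stand-in for the Coqpit config: a dict with attribute access."""
        __getattr__ = dict.get

    def resolve_speakers_file(explicit_path, config):
        # Global config wins over model_args; an explicit path wins over both.
        if getattr(config, "use_speaker_embedding", False):
            key = "d_vector_file" if config.use_d_vector_file else "speakers_file"
            return explicit_path or config[key]
        args = getattr(config, "model_args", None)
        if args and getattr(args, "use_speaker_embedding", False):
            key = "d_vector_file" if args.use_d_vector_file else "speakers_file"
            return explicit_path or args[key]
        return explicit_path

    cfg = AttrDict(use_speaker_embedding=True, use_d_vector_file=True,
                   d_vector_file="speakers.json")
    assert resolve_speakers_file(None, cfg) == "speakers.json"
    assert resolve_speakers_file("mine.json", cfg) == "mine.json"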
@@ -113,29 +140,15 @@ class Synthesizer(object):
         """
         # pylint: disable=global-statement
-        global symbols, phonemes
-
         self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
-
-        if self.tts_config.has("characters") and self.tts_config.characters:
-            symbols, phonemes = make_symbols(**self.tts_config.characters)
-
-        if self.use_phonemes:
-            self.input_size = len(phonemes)
-        else:
-            self.input_size = len(symbols)
-
-        if self.tts_config.use_speaker_embedding is True:
-            self.tts_speakers_file = (
-                self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
-            )
-            self.tts_config["d_vector_file"] = self.tts_speakers_file
-
         self.tts_model = setup_tts_model(config=self.tts_config)
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
+        self._set_tts_speaker_file()
 
     def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         """Load the vocoder model.
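
Beyond delegating speaker-file setup to `_set_tts_speaker_file()`, this hunk removes the `global symbols, phonemes` mutation: module-level inventories patched at load time leak into every later load in the same process, while the models built by `setup_tts_model()` can handle their own symbol tables. A toy illustration of that hazard (illustrative only, not code from the commit):

    symbols = ["a", "b", "c"]  # module-level default inventory

    def load_model(custom_symbols=None):
        global symbols
        if custom_symbols:
            symbols = custom_symbols  # mutates state for every later caller
        return len(symbols)

    assert load_model(["a", "b"]) == 2
    assert load_model() == 2  # a second model silently inherits the first's symbols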
@@ -187,15 +200,22 @@ class Synthesizer(object):
         """
         start_time = time.time()
         wavs = []
-        speaker_embedding = None
         sens = self.split_into_sentences(text)
         print(" > Text splitted to sentences.")
         print(sens)
 
+        # handle multi-speaker
+        speaker_embedding = None
+        speaker_id = None
         if self.tts_speakers_file:
-            # get the speaker embedding from the saved d_vectors.
             if speaker_idx and isinstance(speaker_idx, str):
-                speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                if self.tts_config.use_d_vector_file:
+                    # get the speaker embedding from the saved d_vectors.
+                    speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
+                else:
+                    # get speaker idx from the speaker name
+                    speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
+
             elif not speaker_idx and not speaker_wav:
                 raise ValueError(
                     " [!] Look like you use a multi-speaker model. "
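
The new branch selects between the two multi-speaker input types: d-vector models consume a per-speaker embedding vector, while embedding-layer models consume an integer speaker id. A sketch of that selection as a standalone helper — `pick_speaker_inputs` is a hypothetical name; the two `speaker_manager` calls are the ones used in the hunk above:

    def pick_speaker_inputs(speaker_manager, use_d_vector_file, speaker_idx):
        speaker_embedding, speaker_id = None, None
        if use_d_vector_file:
            # first stored d-vector for the named speaker
            speaker_embedding = speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
        else:
            # map the speaker name to its embedding-table index
            speaker_id = speaker_manager.speaker_ids[speaker_idx]
        return speaker_embedding, speaker_id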
@@ -224,14 +244,14 @@ class Synthesizer(object):
                 CONFIG=self.tts_config,
                 use_cuda=self.use_cuda,
                 ap=self.ap,
-                speaker_id=None,
+                speaker_id=speaker_id,
                 style_wav=style_wav,
                 enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
                 use_griffin_lim=use_gl,
                 d_vector=speaker_embedding,
             )
             waveform = outputs["wav"]
-            mel_postnet_spec = outputs["model_outputs"]
+            mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
             if not use_gl:
                 # denormalize tts output based on tts audio config
                 mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
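
The `model_outputs` change tracks a new `synthesis()` return layout: the spectrogram now sits under a nested `outputs` dict as a batched torch tensor, so the code takes batch item 0, detaches it from the graph, and converts to numpy before the numpy-based denormalization. A shape sketch under that assumption (tensor shape is illustrative):

    import torch

    # Assumed layout of the synthesis() return value after this change;
    # [1, T, n_mels] stands in for a real batched mel spectrogram.
    outputs = {"wav": None, "outputs": {"model_outputs": torch.zeros(1, 250, 80)}}

    mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().numpy()
    print(mel_postnet_spec.shape)  # (250, 80) -> [T, n_mels]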