refactoring to allow defining the speaker file externally

This commit is contained in:
Eren Gölge 2021-04-16 15:45:16 +02:00
parent 83aa415934
commit 25328aad00
1 changed files with 8 additions and 6 deletions

View File

@ -22,6 +22,7 @@ class Synthesizer(object):
self,
tts_checkpoint: str,
tts_config_path: str,
tts_speakers_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
use_cuda: bool = False,
@ -44,6 +45,7 @@ class Synthesizer(object):
"""
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.use_cuda = use_cuda
@ -67,9 +69,9 @@ class Synthesizer(object):
return pysbd.Segmenter(language=lang, clean=True)
def _load_speakers(self) -> None:
def _load_speakers(self, speaker_file: str) -> None:
print("Loading speakers ...")
self.tts_speakers = load_speaker_mapping(self.tts_config.external_speaker_embedding_file)
self.tts_speakers = load_speaker_mapping(speaker_file)
self.num_speakers = len(self.tts_speakers)
self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
"embedding"
@ -79,12 +81,12 @@ class Synthesizer(object):
speaker_embedding = None
if self.tts_config.get("use_external_speaker_embedding_file") and not speaker_json_key:
raise ValueError("While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
if not speaker_json_key:
raise ValueError(" [!] While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
if speaker_json_key != "":
assert self.tts_speakers
assert speaker_json_key in self.tts_speakers, f"speaker_json_key is not in self.tts_speakers keys : '{speaker_idx}'"
assert speaker_json_key in self.tts_speakers, f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
return speaker_embedding
@ -109,7 +111,7 @@ class Synthesizer(object):
self.input_size = len(symbols)
if self.tts_config.use_speaker_embedding is True:
self._load_speakers()
self._load_speakers(self.tts_config.get('external_speaker_embedding_file', self.tts_speakers_file))
self.tts_model = setup_model(
self.input_size,