Add support for multi-lingual models in CLI

This commit is contained in:
WeberJulian 2021-12-08 19:34:36 +01:00 committed by Eren Gölge
parent 7b81c16434
commit 4706583452
3 changed files with 81 additions and 4 deletions

View File

@ -148,12 +148,19 @@ def main():
# args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
parser.add_argument(
"--speaker_idx",
type=str,
help="Target speaker ID for a multi-speaker TTS model.",
default=None,
)
parser.add_argument(
"--language_idx",
type=str,
help="Target language ID for a multi-lingual TTS model.",
default=None,
)
parser.add_argument(
"--speaker_wav",
nargs="+",
@ -169,6 +176,14 @@ def main():
const=True,
default=False,
)
parser.add_argument(
"--list_language_idxs",
help="List available language ids for the defined multi-lingual model.",
type=str2bool,
nargs="?",
const=True,
default=False,
)
# aux args
parser.add_argument(
"--save_spectogram",
@ -180,7 +195,7 @@ def main():
args = parser.parse_args()
# print the description if either text or list_models is not set
if args.text is None and not args.list_models and not args.list_speaker_idxs:
if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
parser.parse_args(["-h"])
# load model manager
@ -190,6 +205,7 @@ def main():
model_path = None
config_path = None
speakers_file_path = None
language_ids_file_path = None
vocoder_path = None
vocoder_config_path = None
encoder_path = None
@ -213,6 +229,7 @@ def main():
model_path = args.model_path
config_path = args.config_path
speakers_file_path = args.speakers_file_path
language_ids_file_path = args.language_ids_file_path
if args.vocoder_path is not None:
vocoder_path = args.vocoder_path
@ -227,6 +244,7 @@ def main():
model_path,
config_path,
speakers_file_path,
language_ids_file_path,
vocoder_path,
vocoder_config_path,
encoder_path,
@ -242,6 +260,14 @@ def main():
print(synthesizer.tts_model.speaker_manager.speaker_ids)
return
# query language ids of a multi-lingual model.
if args.list_language_idxs:
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.language_id_mapping)
return
# check the arguments against a multi-speaker model.
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
print(
@ -254,7 +280,7 @@ def main():
print(" > Text: {}".format(args.text))
# kick it
wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@ -31,6 +31,7 @@ class LanguageManager:
language_ids_file_path: str = "",
config: Coqpit = None,
):
self.language_id_mapping = {}
if language_ids_file_path:
self.set_language_ids_from_file(language_ids_file_path)

View File

@ -8,6 +8,7 @@ import torch
from TTS.config import load_config
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@ -23,6 +24,7 @@ class Synthesizer(object):
tts_checkpoint: str,
tts_config_path: str,
tts_speakers_file: str = "",
tts_languages_file: str = "",
vocoder_checkpoint: str = "",
vocoder_config: str = "",
encoder_checkpoint: str = "",
@ -52,6 +54,7 @@ class Synthesizer(object):
self.tts_checkpoint = tts_checkpoint
self.tts_config_path = tts_config_path
self.tts_speakers_file = tts_speakers_file
self.tts_languages_file = tts_languages_file
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.encoder_checkpoint = encoder_checkpoint
@ -63,6 +66,9 @@ class Synthesizer(object):
self.speaker_manager = None
self.num_speakers = 0
self.tts_speakers = {}
self.language_manager = None
self.num_languages = 0
self.tts_languages = {}
self.d_vector_dim = 0
self.seg = self._get_segmenter("en")
self.use_cuda = use_cuda
@ -110,8 +116,13 @@ class Synthesizer(object):
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
speaker_manager = self._init_speaker_manager()
language_manager = self._init_language_manager()
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
self.tts_model = setup_tts_model(
config=self.tts_config,
speaker_manager=speaker_manager,
language_manager=language_manager,
)
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
if use_cuda:
self.tts_model.cuda()
@ -133,6 +144,17 @@ class Synthesizer(object):
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
return speaker_manager
def _init_language_manager(self):
"""Initialize the LanguageManager"""
# setup if multi-lingual settings are in the global model config
language_manager = None
if hasattr(self.tts_config, "use_language_embedding") and self.tts_config.use_language_embedding is True:
if self.tts_languages_file:
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
elif self.tts_config.get("language_ids_file", None):
language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
return language_manager
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
"""Load the vocoder model.
@ -174,12 +196,20 @@ class Synthesizer(object):
wav = np.array(wav)
self.ap.save_wav(wav, path, self.output_sample_rate)
def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
def tts(
self,
text: str,
speaker_idx: str = "",
language_idx: str = "",
speaker_wav=None,
style_wav=None
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Args:
text (str): input text.
speaker_idx (str, optional): speaker id for multi-speaker models. Defaults to "".
language_idx (str, optional): language id for multi-language models. Defaults to "".
speaker_wav ():
style_wav ([type], optional): style waveform for GST. Defaults to None.
@ -219,6 +249,24 @@ class Synthesizer(object):
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
)
# handle multi-lingual
language_id = None
if self.tts_languages_file or hasattr(self.tts_model.language_manager, "language_id_mapping"):
if language_idx and isinstance(language_idx, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
elif not language_idx:
raise ValueError(
" [!] Look like you use a multi-lingual model. "
"You need to define either a `language_idx` or a `style_wav` to use a multi-lingual model."
)
else:
raise ValueError(
f" [!] Missing language_ids.json file path for selecting language {language_idx}."
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
)
# compute a new d_vector from the given clip.
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
@ -234,6 +282,8 @@ class Synthesizer(object):
use_cuda=self.use_cuda,
ap=self.ap,
speaker_id=speaker_id,
language_id=language_id,
language_name=language_idx,
style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,