From a564eb9f5420cb8607bf43787f7ff77753f631a5 Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Wed, 8 Dec 2021 19:34:36 +0100
Subject: [PATCH] Add support for multi-lingual models in CLI

---
 TTS/bin/synthesize.py      | 30 +++++++++++++++++++--
 TTS/tts/utils/languages.py |  1 +
 TTS/utils/synthesizer.py   | 54 ++++++++++++++++++++++++++++++++++++--
 3 files changed, 81 insertions(+), 4 deletions(-)

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index bf7de798..509b3da6 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -152,12 +152,19 @@ If you don't specify any models, then it uses LJSpeech based English model.
 
     # args for multi-speaker synthesis
     parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
+    parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
     parser.add_argument(
         "--speaker_idx",
         type=str,
         help="Target speaker ID for a multi-speaker TTS model.",
         default=None,
     )
+    parser.add_argument(
+        "--language_idx",
+        type=str,
+        help="Target language ID for a multi-lingual TTS model.",
+        default=None,
+    )
     parser.add_argument(
         "--speaker_wav",
         nargs="+",
@@ -173,6 +180,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
         const=True,
         default=False,
     )
+    parser.add_argument(
+        "--list_language_idxs",
+        help="List available language ids for the defined multi-lingual model.",
+        type=str2bool,
+        nargs="?",
+        const=True,
+        default=False,
+    )
     # aux args
     parser.add_argument(
         "--save_spectogram",
@@ -184,7 +199,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
     args = parser.parse_args()
 
     # print the description if either text or list_models is not set
-    if args.text is None and not args.list_models and not args.list_speaker_idxs:
+    if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
         parser.parse_args(["-h"])
 
     # load model manager
@@ -194,6 +209,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
     model_path = None
     config_path = None
     speakers_file_path = None
+    language_ids_file_path = None
     vocoder_path = None
     vocoder_config_path = None
     encoder_path = None
@@ -217,6 +233,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         model_path = args.model_path
         config_path = args.config_path
         speakers_file_path = args.speakers_file_path
+        language_ids_file_path = args.language_ids_file_path
 
     if args.vocoder_path is not None:
         vocoder_path = args.vocoder_path
@@ -231,6 +248,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         model_path,
         config_path,
         speakers_file_path,
+        language_ids_file_path,
         vocoder_path,
         vocoder_config_path,
         encoder_path,
@@ -246,6 +264,14 @@ If you don't specify any models, then it uses LJSpeech based English model.
         print(synthesizer.tts_model.speaker_manager.speaker_ids)
         return
 
+    # query language ids of a multi-lingual model.
+    if args.list_language_idxs:
+        print(
+            " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model.)"
+        )
+        print(synthesizer.tts_model.language_manager.language_id_mapping)
+        return
+
     # check the arguments against a multi-speaker model.
     if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
         print(
@@ -258,7 +284,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
print(" > Text: {}".format(args.text)) # kick it - wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style) + wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav) # save the results print(" > Saving output to {}".format(args.out_path)) diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index 451b10f9..fc7eec57 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -31,6 +31,7 @@ class LanguageManager: language_ids_file_path: str = "", config: Coqpit = None, ): + self.language_id_mapping = {} if language_ids_file_path: self.set_language_ids_from_file(language_ids_file_path) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 043c4982..ea8ce6d1 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -8,6 +8,7 @@ import torch from TTS.config import load_config from TTS.tts.models import setup_model as setup_tts_model from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.languages import LanguageManager # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import @@ -23,6 +24,7 @@ class Synthesizer(object): tts_checkpoint: str, tts_config_path: str, tts_speakers_file: str = "", + tts_languages_file: str = "", vocoder_checkpoint: str = "", vocoder_config: str = "", encoder_checkpoint: str = "", @@ -52,6 +54,7 @@ class Synthesizer(object): self.tts_checkpoint = tts_checkpoint self.tts_config_path = tts_config_path self.tts_speakers_file = tts_speakers_file + self.tts_languages_file = tts_languages_file self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_config = vocoder_config self.encoder_checkpoint = encoder_checkpoint @@ -63,6 +66,9 @@ class Synthesizer(object): self.speaker_manager = None self.num_speakers = 0 self.tts_speakers = {} + self.language_manager = None + self.num_languages = 0 + self.tts_languages = {} self.d_vector_dim = 0 self.seg = self._get_segmenter("en") self.use_cuda = use_cuda @@ -110,8 +116,13 @@ class Synthesizer(object): self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) speaker_manager = self._init_speaker_manager() + language_manager = self._init_language_manager() - self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager) + self.tts_model = setup_tts_model( + config=self.tts_config, + speaker_manager=speaker_manager, + language_manager=language_manager, + ) self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() @@ -133,6 +144,17 @@ class Synthesizer(object): speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) return speaker_manager + def _init_language_manager(self): + """Initialize the LanguageManager""" + # setup if multi-lingual settings are in the global model config + language_manager = None + if hasattr(self.tts_config, "use_language_embedding") and self.tts_config.use_language_embedding is True: + if self.tts_languages_file: + language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file) + elif self.tts_config.get("language_ids_file", None): + language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file) + return language_manager + def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: """Load the vocoder model. 
@@ -174,12 +196,20 @@ class Synthesizer(object):
             wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)
 
-    def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
+    def tts(
+        self,
+        text: str,
+        speaker_idx: str = "",
+        language_idx: str = "",
+        speaker_wav=None,
+        style_wav=None
+    ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
 
         Args:
             text (str): input text.
             speaker_idx (str, optional): spekaer id for multi-speaker models. Defaults to "".
+            language_idx (str, optional): language id for multi-lingual models. Defaults to "".
             speaker_wav ():
             style_wav ([type], optional): style waveform for GST. Defaults to None.
@@ -219,6 +249,24 @@ class Synthesizer(object):
                     "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
                 )
 
+        # handle multi-lingual
+        language_id = None
+        if self.tts_languages_file or hasattr(self.tts_model.language_manager, "language_id_mapping"):
+            if language_idx and isinstance(language_idx, str):
+                language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
+
+            elif not language_idx:
+                raise ValueError(
+                    " [!] Looks like you are using a multi-lingual model. "
+                    "You need to define a `language_idx` to use a multi-lingual model."
+                )
+
+            else:
+                raise ValueError(
+                    f" [!] Missing language_ids.json file path for selecting language {language_idx}. "
+                    "Define the path for language_ids.json if it is a multi-lingual model or remove the defined language idx."
+                )
+
         # compute a new d_vector from the given clip.
         if speaker_wav is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
@@ -234,6 +282,8 @@ class Synthesizer(object):
                 use_cuda=self.use_cuda,
                 ap=self.ap,
                 speaker_id=speaker_id,
+                language_id=language_id,
+                language_name=language_idx,
                 style_wav=style_wav,
                 enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
                 use_griffin_lim=use_gl,
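
---

With the patch applied, a language can be selected from the CLI (`--language_ids_file_path`, `--language_idx`, and `--list_language_idxs` to inspect the mapping) or through the `Synthesizer` API directly. Below is a minimal usage sketch of the API path, not part of the patch itself; it assumes a multi-lingual checkpoint with a `language_ids.json`, and every file path plus the `"speaker_0"` and `"en"` ids are illustrative placeholders:

    # Usage sketch (not part of the patch). All paths and ids below are placeholders.
    from TTS.utils.synthesizer import Synthesizer

    synthesizer = Synthesizer(
        tts_checkpoint="path/to/checkpoint.pth.tar",      # multi-lingual TTS checkpoint
        tts_config_path="path/to/config.json",
        tts_speakers_file="path/to/speakers.json",
        tts_languages_file="path/to/language_ids.json",   # new argument added by this patch
        use_cuda=False,
    )

    # Valid values for language_idx are the keys of language_id_mapping,
    # i.e. the entries in language_ids.json (e.g. "en").
    wav = synthesizer.tts("Hello world!", speaker_idx="speaker_0", language_idx="en")
    synthesizer.save_wav(wav, "output.wav")

Passing an unknown `language_idx` raises a `KeyError` from `language_id_mapping`, and omitting it on a multi-lingual model raises the `ValueError` added above, mirroring the existing multi-speaker checks.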