mirror of https://github.com/coqui-ai/TTS.git
Add support for multi-lingual models in CLI
This commit is contained in:
parent
7b81c16434
commit
4706583452
|
@ -148,12 +148,19 @@ def main():
|
|||
|
||||
# args for multi-speaker synthesis
|
||||
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
||||
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
|
||||
parser.add_argument(
|
||||
"--speaker_idx",
|
||||
type=str,
|
||||
help="Target speaker ID for a multi-speaker TTS model.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language_idx",
|
||||
type=str,
|
||||
help="Target language ID for a multi-lingual TTS model.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speaker_wav",
|
||||
nargs="+",
|
||||
|
@ -169,6 +176,14 @@ def main():
|
|||
const=True,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_language_idxs",
|
||||
help="List available language ids for the defined multi-lingual model.",
|
||||
type=str2bool,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
)
|
||||
# aux args
|
||||
parser.add_argument(
|
||||
"--save_spectogram",
|
||||
|
@ -180,7 +195,7 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
# print the description if no text is given and none of the list flags is set
|
||||
if args.text is None and not args.list_models and not args.list_speaker_idxs:
|
||||
if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
|
||||
parser.parse_args(["-h"])
|
||||
|
||||
# load model manager
|
||||
|
@ -190,6 +205,7 @@ def main():
|
|||
model_path = None
|
||||
config_path = None
|
||||
speakers_file_path = None
|
||||
language_ids_file_path = None
|
||||
vocoder_path = None
|
||||
vocoder_config_path = None
|
||||
encoder_path = None
|
||||
|
@ -213,6 +229,7 @@ def main():
|
|||
model_path = args.model_path
|
||||
config_path = args.config_path
|
||||
speakers_file_path = args.speakers_file_path
|
||||
language_ids_file_path = args.language_ids_file_path
|
||||
|
||||
if args.vocoder_path is not None:
|
||||
vocoder_path = args.vocoder_path
|
||||
|
@ -227,6 +244,7 @@ def main():
|
|||
model_path,
|
||||
config_path,
|
||||
speakers_file_path,
|
||||
language_ids_file_path,
|
||||
vocoder_path,
|
||||
vocoder_config_path,
|
||||
encoder_path,
|
||||
|
@ -242,6 +260,14 @@ def main():
|
|||
print(synthesizer.tts_model.speaker_manager.speaker_ids)
|
||||
return
|
||||
|
||||
# query language ids of a multi-lingual model.
|
||||
if args.list_language_idxs:
|
||||
print(
|
||||
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||
)
|
||||
print(synthesizer.tts_model.language_manager.language_id_mapping)
|
||||
return
|
||||
|
||||
# check the arguments against a multi-speaker model.
|
||||
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
||||
print(
|
||||
|
@ -254,7 +280,7 @@ def main():
|
|||
print(" > Text: {}".format(args.text))
|
||||
|
||||
# kick it
|
||||
wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
|
||||
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)
|
||||
|
||||
# save the results
|
||||
print(" > Saving output to {}".format(args.out_path))
|
||||
|
|
|
@ -31,6 +31,7 @@ class LanguageManager:
|
|||
language_ids_file_path: str = "",
|
||||
config: Coqpit = None,
|
||||
):
|
||||
self.language_id_mapping = {}
|
||||
if language_ids_file_path:
|
||||
self.set_language_ids_from_file(language_ids_file_path)
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch
|
|||
from TTS.config import load_config
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.languages import LanguageManager
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
|
@ -23,6 +24,7 @@ class Synthesizer(object):
|
|||
tts_checkpoint: str,
|
||||
tts_config_path: str,
|
||||
tts_speakers_file: str = "",
|
||||
tts_languages_file: str = "",
|
||||
vocoder_checkpoint: str = "",
|
||||
vocoder_config: str = "",
|
||||
encoder_checkpoint: str = "",
|
||||
|
@ -52,6 +54,7 @@ class Synthesizer(object):
|
|||
self.tts_checkpoint = tts_checkpoint
|
||||
self.tts_config_path = tts_config_path
|
||||
self.tts_speakers_file = tts_speakers_file
|
||||
self.tts_languages_file = tts_languages_file
|
||||
self.vocoder_checkpoint = vocoder_checkpoint
|
||||
self.vocoder_config = vocoder_config
|
||||
self.encoder_checkpoint = encoder_checkpoint
|
||||
|
@ -63,6 +66,9 @@ class Synthesizer(object):
|
|||
self.speaker_manager = None
|
||||
self.num_speakers = 0
|
||||
self.tts_speakers = {}
|
||||
self.language_manager = None
|
||||
self.num_languages = 0
|
||||
self.tts_languages = {}
|
||||
self.d_vector_dim = 0
|
||||
self.seg = self._get_segmenter("en")
|
||||
self.use_cuda = use_cuda
|
||||
|
@ -110,8 +116,13 @@ class Synthesizer(object):
|
|||
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
|
||||
|
||||
speaker_manager = self._init_speaker_manager()
|
||||
language_manager = self._init_language_manager()
|
||||
|
||||
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
|
||||
self.tts_model = setup_tts_model(
|
||||
config=self.tts_config,
|
||||
speaker_manager=speaker_manager,
|
||||
language_manager=language_manager,
|
||||
)
|
||||
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
||||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
@ -133,6 +144,17 @@ class Synthesizer(object):
|
|||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
||||
return speaker_manager
|
||||
|
||||
def _init_language_manager(self):
|
||||
"""Initialize the LanguageManager"""
|
||||
# setup if multi-lingual settings are in the global model config
|
||||
language_manager = None
|
||||
if hasattr(self.tts_config, "use_language_embedding") and self.tts_config.use_language_embedding is True:
|
||||
if self.tts_languages_file:
|
||||
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
|
||||
elif self.tts_config.get("language_ids_file", None):
|
||||
language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
|
||||
return language_manager
|
||||
|
||||
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
|
||||
"""Load the vocoder model.
|
||||
|
||||
|
@ -174,12 +196,20 @@ class Synthesizer(object):
|
|||
wav = np.array(wav)
|
||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||
|
||||
def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
|
||||
def tts(
|
||||
self,
|
||||
text: str,
|
||||
speaker_idx: str = "",
|
||||
language_idx: str = "",
|
||||
speaker_wav=None,
|
||||
style_wav=None
|
||||
) -> List[int]:
|
||||
"""🐸 TTS magic. Run all the models and generate speech.
|
||||
|
||||
Args:
|
||||
text (str): input text.
|
||||
speaker_idx (str, optional): speaker id for multi-speaker models. Defaults to "".
|
||||
language_idx (str, optional): language id for multi-language models. Defaults to "".
|
||||
speaker_wav ():
|
||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||
|
||||
|
@ -219,6 +249,24 @@ class Synthesizer(object):
|
|||
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
|
||||
)
|
||||
|
||||
# handle multi-lingual
|
||||
language_id = None
|
||||
if self.tts_languages_file or hasattr(self.tts_model.language_manager, "language_id_mapping"):
|
||||
if language_idx and isinstance(language_idx, str):
|
||||
language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
|
||||
|
||||
elif not language_idx:
|
||||
raise ValueError(
|
||||
" [!] Look like you use a multi-lingual model. "
|
||||
"You need to define either a `language_idx` or a `style_wav` to use a multi-lingual model."
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f" [!] Missing language_ids.json file path for selecting language {language_idx}."
|
||||
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
|
||||
)
|
||||
|
||||
# compute a new d_vector from the given clip.
|
||||
if speaker_wav is not None:
|
||||
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
|
||||
|
@ -234,6 +282,8 @@ class Synthesizer(object):
|
|||
use_cuda=self.use_cuda,
|
||||
ap=self.ap,
|
||||
speaker_id=speaker_id,
|
||||
language_id=language_id,
|
||||
language_name=language_idx,
|
||||
style_wav=style_wav,
|
||||
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
||||
use_griffin_lim=use_gl,
|
||||
|
|
Loading…
Reference in New Issue