mirror of https://github.com/coqui-ai/TTS.git
Add support for multi-lingual models in CLI
This commit is contained in:
parent
7b81c16434
commit
4706583452
|
@ -148,12 +148,19 @@ def main():
|
||||||
|
|
||||||
# args for multi-speaker synthesis
|
# args for multi-speaker synthesis
|
||||||
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
|
||||||
|
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speaker_idx",
|
"--speaker_idx",
|
||||||
type=str,
|
type=str,
|
||||||
help="Target speaker ID for a multi-speaker TTS model.",
|
help="Target speaker ID for a multi-speaker TTS model.",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language_idx",
|
||||||
|
type=str,
|
||||||
|
help="Target language ID for a multi-lingual TTS model.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speaker_wav",
|
"--speaker_wav",
|
||||||
nargs="+",
|
nargs="+",
|
||||||
|
@ -169,6 +176,14 @@ def main():
|
||||||
const=True,
|
const=True,
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--list_language_idxs",
|
||||||
|
help="List available language ids for the defined multi-lingual model.",
|
||||||
|
type=str2bool,
|
||||||
|
nargs="?",
|
||||||
|
const=True,
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
# aux args
|
# aux args
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_spectogram",
|
"--save_spectogram",
|
||||||
|
@ -180,7 +195,7 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# print the description if either text or list_models is not set
|
# print the description if either text or list_models is not set
|
||||||
if args.text is None and not args.list_models and not args.list_speaker_idxs:
|
if args.text is None and not args.list_models and not args.list_speaker_idxs and not args.list_language_idxs:
|
||||||
parser.parse_args(["-h"])
|
parser.parse_args(["-h"])
|
||||||
|
|
||||||
# load model manager
|
# load model manager
|
||||||
|
@ -190,6 +205,7 @@ def main():
|
||||||
model_path = None
|
model_path = None
|
||||||
config_path = None
|
config_path = None
|
||||||
speakers_file_path = None
|
speakers_file_path = None
|
||||||
|
language_ids_file_path = None
|
||||||
vocoder_path = None
|
vocoder_path = None
|
||||||
vocoder_config_path = None
|
vocoder_config_path = None
|
||||||
encoder_path = None
|
encoder_path = None
|
||||||
|
@ -213,6 +229,7 @@ def main():
|
||||||
model_path = args.model_path
|
model_path = args.model_path
|
||||||
config_path = args.config_path
|
config_path = args.config_path
|
||||||
speakers_file_path = args.speakers_file_path
|
speakers_file_path = args.speakers_file_path
|
||||||
|
language_ids_file_path = args.language_ids_file_path
|
||||||
|
|
||||||
if args.vocoder_path is not None:
|
if args.vocoder_path is not None:
|
||||||
vocoder_path = args.vocoder_path
|
vocoder_path = args.vocoder_path
|
||||||
|
@ -227,6 +244,7 @@ def main():
|
||||||
model_path,
|
model_path,
|
||||||
config_path,
|
config_path,
|
||||||
speakers_file_path,
|
speakers_file_path,
|
||||||
|
language_ids_file_path,
|
||||||
vocoder_path,
|
vocoder_path,
|
||||||
vocoder_config_path,
|
vocoder_config_path,
|
||||||
encoder_path,
|
encoder_path,
|
||||||
|
@ -242,6 +260,14 @@ def main():
|
||||||
print(synthesizer.tts_model.speaker_manager.speaker_ids)
|
print(synthesizer.tts_model.speaker_manager.speaker_ids)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# query language ids of a multi-lingual model.
|
||||||
|
if args.list_language_idxs:
|
||||||
|
print(
|
||||||
|
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
|
||||||
|
)
|
||||||
|
print(synthesizer.tts_model.language_manager.language_id_mapping)
|
||||||
|
return
|
||||||
|
|
||||||
# check the arguments against a multi-speaker model.
|
# check the arguments against a multi-speaker model.
|
||||||
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
|
||||||
print(
|
print(
|
||||||
|
@ -254,7 +280,7 @@ def main():
|
||||||
print(" > Text: {}".format(args.text))
|
print(" > Text: {}".format(args.text))
|
||||||
|
|
||||||
# kick it
|
# kick it
|
||||||
wav = synthesizer.tts(args.text, args.speaker_idx, args.speaker_wav, args.gst_style)
|
wav = synthesizer.tts(args.text, args.speaker_idx, args.language_idx, args.speaker_wav)
|
||||||
|
|
||||||
# save the results
|
# save the results
|
||||||
print(" > Saving output to {}".format(args.out_path))
|
print(" > Saving output to {}".format(args.out_path))
|
||||||
|
|
|
@ -31,6 +31,7 @@ class LanguageManager:
|
||||||
language_ids_file_path: str = "",
|
language_ids_file_path: str = "",
|
||||||
config: Coqpit = None,
|
config: Coqpit = None,
|
||||||
):
|
):
|
||||||
|
self.language_id_mapping = {}
|
||||||
if language_ids_file_path:
|
if language_ids_file_path:
|
||||||
self.set_language_ids_from_file(language_ids_file_path)
|
self.set_language_ids_from_file(language_ids_file_path)
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import torch
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.models import setup_model as setup_tts_model
|
from TTS.tts.models import setup_model as setup_tts_model
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
|
|
||||||
# pylint: disable=unused-wildcard-import
|
# pylint: disable=unused-wildcard-import
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
|
@ -23,6 +24,7 @@ class Synthesizer(object):
|
||||||
tts_checkpoint: str,
|
tts_checkpoint: str,
|
||||||
tts_config_path: str,
|
tts_config_path: str,
|
||||||
tts_speakers_file: str = "",
|
tts_speakers_file: str = "",
|
||||||
|
tts_languages_file: str = "",
|
||||||
vocoder_checkpoint: str = "",
|
vocoder_checkpoint: str = "",
|
||||||
vocoder_config: str = "",
|
vocoder_config: str = "",
|
||||||
encoder_checkpoint: str = "",
|
encoder_checkpoint: str = "",
|
||||||
|
@ -52,6 +54,7 @@ class Synthesizer(object):
|
||||||
self.tts_checkpoint = tts_checkpoint
|
self.tts_checkpoint = tts_checkpoint
|
||||||
self.tts_config_path = tts_config_path
|
self.tts_config_path = tts_config_path
|
||||||
self.tts_speakers_file = tts_speakers_file
|
self.tts_speakers_file = tts_speakers_file
|
||||||
|
self.tts_languages_file = tts_languages_file
|
||||||
self.vocoder_checkpoint = vocoder_checkpoint
|
self.vocoder_checkpoint = vocoder_checkpoint
|
||||||
self.vocoder_config = vocoder_config
|
self.vocoder_config = vocoder_config
|
||||||
self.encoder_checkpoint = encoder_checkpoint
|
self.encoder_checkpoint = encoder_checkpoint
|
||||||
|
@ -63,6 +66,9 @@ class Synthesizer(object):
|
||||||
self.speaker_manager = None
|
self.speaker_manager = None
|
||||||
self.num_speakers = 0
|
self.num_speakers = 0
|
||||||
self.tts_speakers = {}
|
self.tts_speakers = {}
|
||||||
|
self.language_manager = None
|
||||||
|
self.num_languages = 0
|
||||||
|
self.tts_languages = {}
|
||||||
self.d_vector_dim = 0
|
self.d_vector_dim = 0
|
||||||
self.seg = self._get_segmenter("en")
|
self.seg = self._get_segmenter("en")
|
||||||
self.use_cuda = use_cuda
|
self.use_cuda = use_cuda
|
||||||
|
@ -110,8 +116,13 @@ class Synthesizer(object):
|
||||||
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
|
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
|
||||||
|
|
||||||
speaker_manager = self._init_speaker_manager()
|
speaker_manager = self._init_speaker_manager()
|
||||||
|
language_manager = self._init_language_manager()
|
||||||
|
|
||||||
self.tts_model = setup_tts_model(config=self.tts_config, speaker_manager=speaker_manager)
|
self.tts_model = setup_tts_model(
|
||||||
|
config=self.tts_config,
|
||||||
|
speaker_manager=speaker_manager,
|
||||||
|
language_manager=language_manager,
|
||||||
|
)
|
||||||
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.tts_model.cuda()
|
self.tts_model.cuda()
|
||||||
|
@ -133,6 +144,17 @@ class Synthesizer(object):
|
||||||
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file)
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
|
def _init_language_manager(self):
|
||||||
|
"""Initialize the LanguageManager"""
|
||||||
|
# setup if multi-lingual settings are in the global model config
|
||||||
|
language_manager = None
|
||||||
|
if hasattr(self.tts_config, "use_language_embedding") and self.tts_config.use_language_embedding is True:
|
||||||
|
if self.tts_languages_file:
|
||||||
|
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
|
||||||
|
elif self.tts_config.get("language_ids_file", None):
|
||||||
|
language_manager = LanguageManager(language_ids_file_path=self.tts_config.language_ids_file)
|
||||||
|
return language_manager
|
||||||
|
|
||||||
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
|
def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
|
||||||
"""Load the vocoder model.
|
"""Load the vocoder model.
|
||||||
|
|
||||||
|
@ -174,12 +196,20 @@ class Synthesizer(object):
|
||||||
wav = np.array(wav)
|
wav = np.array(wav)
|
||||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||||
|
|
||||||
def tts(self, text: str, speaker_idx: str = "", speaker_wav=None, style_wav=None) -> List[int]:
|
def tts(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
speaker_idx: str = "",
|
||||||
|
language_idx: str = "",
|
||||||
|
speaker_wav=None,
|
||||||
|
style_wav=None
|
||||||
|
) -> List[int]:
|
||||||
"""🐸 TTS magic. Run all the models and generate speech.
|
"""🐸 TTS magic. Run all the models and generate speech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str): input text.
|
text (str): input text.
|
||||||
speaker_idx (str, optional): speaker id for multi-speaker models. Defaults to "".
|
speaker_idx (str, optional): speaker id for multi-speaker models. Defaults to "".
|
||||||
|
language_idx (str, optional): language id for multi-language models. Defaults to "".
|
||||||
speaker_wav ():
|
speaker_wav ():
|
||||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||||
|
|
||||||
|
@ -219,6 +249,24 @@ class Synthesizer(object):
|
||||||
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
|
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# handle multi-lingual
|
||||||
|
language_id = None
|
||||||
|
if self.tts_languages_file or hasattr(self.tts_model.language_manager, "language_id_mapping"):
|
||||||
|
if language_idx and isinstance(language_idx, str):
|
||||||
|
language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
|
||||||
|
|
||||||
|
elif not language_idx:
|
||||||
|
raise ValueError(
|
||||||
|
" [!] Look like you use a multi-lingual model. "
|
||||||
|
"You need to define either a `language_idx` or a `style_wav` to use a multi-lingual model."
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f" [!] Missing language_ids.json file path for selecting language {language_idx}."
|
||||||
|
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
|
||||||
|
)
|
||||||
|
|
||||||
# compute a new d_vector from the given clip.
|
# compute a new d_vector from the given clip.
|
||||||
if speaker_wav is not None:
|
if speaker_wav is not None:
|
||||||
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
|
speaker_embedding = self.tts_model.speaker_manager.compute_d_vector_from_clip(speaker_wav)
|
||||||
|
@ -234,6 +282,8 @@ class Synthesizer(object):
|
||||||
use_cuda=self.use_cuda,
|
use_cuda=self.use_cuda,
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
speaker_id=speaker_id,
|
speaker_id=speaker_id,
|
||||||
|
language_id=language_id,
|
||||||
|
language_name=language_idx,
|
||||||
style_wav=style_wav,
|
style_wav=style_wav,
|
||||||
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
||||||
use_griffin_lim=use_gl,
|
use_griffin_lim=use_gl,
|
||||||
|
|
Loading…
Reference in New Issue