Multilingual tokenizer (#2229)

* Implement multilingual tokenizer

* Add multi_phonemizer recipe

* Fix lint

* Add TestMultiPhonemizer

* Fix lint

* make style
Julian Weber 2023-01-02 10:03:19 +01:00 committed by GitHub
parent f814d52394
commit a07397733b
10 changed files with 230 additions and 31 deletions


@@ -212,6 +212,9 @@ class BaseDatasetConfig(Coqpit):
language (str):
Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.
phonemizer (str):
Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`.
meta_file_val (str):
Name of the dataset meta file that defines the instances used at validation.
@@ -226,6 +229,7 @@ class BaseDatasetConfig(Coqpit):
meta_file_train: str = ""
ignored_speakers: List[str] = None
language: str = ""
phonemizer: str = ""
meta_file_val: str = ""
meta_file_attn_mask: str = ""

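For context, a minimal sketch of how the new per-dataset `phonemizer` field can be combined with `language`; the formatter names and paths below are hypothetical:

from TTS.tts.configs.shared_configs import BaseDatasetConfig

# An empty `phonemizer` falls back to the default backend for that `language`
# (DEF_LANG_TO_PHONEMIZER); an explicit name overrides it.
dataset_en = BaseDatasetConfig(
    formatter="ljspeech",   # assumed formatter
    path="/data/en",        # hypothetical path
    language="en-us",
    phonemizer="",          # resolved to the default phonemizer for en-us
)
dataset_de = BaseDatasetConfig(
    formatter="thorsten",   # assumed formatter
    path="/data/de",        # hypothetical path
    language="de",
    phonemizer="gruut",     # explicit backend choice for German
)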

@@ -569,14 +569,14 @@ class PhonemeDataset(Dataset):
def __getitem__(self, index):
item = self.samples[index]
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"])
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"])
ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self):
return len(self.samples)
def compute_or_load(self, file_name, text):
def compute_or_load(self, file_name, text, language):
"""Compute phonemes for the given text.
If the phonemes are already cached, load them from cache.
@@ -586,7 +586,7 @@ class PhonemeDataset(Dataset):
try:
ids = np.load(cache_path)
except FileNotFoundError:
ids = self.tokenizer.text_to_ids(text)
ids = self.tokenizer.text_to_ids(text, language=language)
np.save(cache_path, ids)
return ids


@@ -175,9 +175,15 @@ def synthesis(
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = style_mel.transpose(1, 2) # [1, time, depth]
language_name = None
if language_id is not None:
language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id]
assert len(language) == 1, "language_id must be a valid language"
language_name = language[0]
# convert text to sequence of token IDs
text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id),
model.tokenizer.text_to_ids(text, language=language_name),
dtype=np.int32,
)
# pass tensors to backend

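A tiny illustration of the id-to-name lookup added above, using a hypothetical `name_to_id` mapping of the kind the language manager holds:

# Hypothetical mapping as stored by the model's language manager.
name_to_id = {"en-us": 0, "fr-fr": 1}
language_id = 1
matches = [name for name, idx in name_to_id.items() if idx == language_id]
assert len(matches) == 1, "language_id must be a valid language"
language_name = matches[0]  # "fr-fr"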

@@ -114,7 +114,7 @@ class BasePhonemizer(abc.ABC):
return self._punctuator.restore(phonemized, punctuations)[0]
return phonemized[0]
def phonemize(self, text: str, separator="|") -> str:
def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
"""Returns the `text` phonemized for the given language
Args:


@@ -43,7 +43,7 @@ class JA_JP_Phonemizer(BasePhonemizer):
return separator.join(ph)
return ph
def phonemize(self, text: str, separator="|") -> str:
def phonemize(self, text: str, separator="|", language=None) -> str:
"""Custom phonemize for JP_JA
Skip pre-post processing steps used by the other phonemizers.


@@ -40,7 +40,7 @@ class KO_KR_Phonemizer(BasePhonemizer):
return separator.join(ph)
return ph
def phonemize(self, text: str, separator: str = "", character: str = "hangeul") -> str:
def phonemize(self, text: str, separator: str = "", character: str = "hangeul", language=None) -> str:
return self._phonemize(text, separator, character)
@staticmethod


@@ -14,30 +14,40 @@ class MultiPhonemizer:
TODO: find a way to pass custom kwargs to the phonemizers
"""
lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER
language = "multi-lingual"
lang_to_phonemizer = {}
def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value
self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer)
def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disable=dangerous-default-value
for k, v in lang_to_phonemizer_name.items():
if v == "" and k in DEF_LANG_TO_PHONEMIZER.keys():
lang_to_phonemizer_name[k] = DEF_LANG_TO_PHONEMIZER[k]
elif v == "":
raise ValueError(f"Phonemizer wasn't set for language {k} and doesn't have a default.")
self.lang_to_phonemizer_name = lang_to_phonemizer_name
self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)
@staticmethod
def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
lang_to_phonemizer = {}
for k, v in lang_to_phonemizer_name.items():
phonemizer = get_phonemizer_by_name(v, language=k)
lang_to_phonemizer[k] = phonemizer
lang_to_phonemizer[k] = get_phonemizer_by_name(v, language=k)
return lang_to_phonemizer
@staticmethod
def name():
return "multi-phonemizer"
def phonemize(self, text, language, separator="|"):
def phonemize(self, text, separator="|", language=""):
if language == "":
raise ValueError("Language must be set for multi-phonemizer to phonemize.")
return self.lang_to_phonemizer[language].phonemize(text, separator)
def supported_languages(self) -> List:
return list(self.lang_to_phonemizer_name.keys())
return list(self.lang_to_phonemizer.keys())
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}")
print(f"{indent}| > phoneme backend: {self.name()}")
# if __name__ == "__main__":
@@ -48,7 +58,7 @@ class MultiPhonemizer:
# "zh-cn": "这是中国的例子",
# }
# phonemes = {}
# ph = MultiPhonemizer()
# ph = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""})
# for lang, text in texts.items():
# phoneme = ph.phonemize(text, lang)
# phonemes[lang] = phoneme
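A minimal usage sketch against the reworked constructor and `phonemize()` signature; it assumes the espeak and gruut backends are installed, and the language codes are only examples:

from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer

# Empty strings fall back to DEF_LANG_TO_PHONEMIZER; explicit names override it.
phonemizer = MultiPhonemizer({"en-us": "", "de": "gruut"})
print(phonemizer.supported_languages())  # ["en-us", "de"]

# `language` is now a keyword argument and must be set explicitly.
print(phonemizer.phonemize("Be a voice, not an echo.", separator="|", language="en-us"))
print(phonemizer.phonemize("Hallo Welt!", separator="|", language="de"))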


@@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class
@@ -106,7 +107,7 @@ class TTSTokenizer:
if self.text_cleaner is not None:
text = self.text_cleaner(text)
if self.use_phonemes:
text = self.phonemizer.phonemize(text, separator="")
text = self.phonemizer.phonemize(text, separator="", language=language)
if self.add_blank:
text = self.intersperse_blank_char(text, True)
if self.use_eos_bos:
@@ -182,8 +183,16 @@ class TTSTokenizer:
# init phonemizer
phonemizer = None
if config.use_phonemes:
if "phonemizer" in config and config.phonemizer == "multi_phonemizer":
lang_to_phonemizer_name = {}
for dataset in config.datasets:
if dataset.language != "":
lang_to_phonemizer_name[dataset.language] = dataset.phonemizer
else:
raise ValueError("Multi phonemizer requires language to be set for each dataset.")
phonemizer = MultiPhonemizer(lang_to_phonemizer_name)
else:
phonemizer_kwargs = {"language": config.phoneme_language}
if "phonemizer" in config and config.phonemizer:
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
else:

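Putting the pieces together, a hedged sketch of a config that triggers the multi-phonemizer branch above; formatter names and paths are hypothetical, and every dataset must set a non-empty `language`:

from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer

datasets = [
    BaseDatasetConfig(formatter="ljspeech", path="/data/en", language="en-us", phonemizer=""),
    BaseDatasetConfig(formatter="thorsten", path="/data/de", language="de", phonemizer="gruut"),
]
config = VitsConfig(
    use_phonemes=True,
    phonemizer="multi_phonemizer",
    datasets=datasets,
)
# Builds a MultiPhonemizer from the per-dataset (language, phonemizer) pairs.
tokenizer, config = TTSTokenizer.init_from_config(config)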

@@ -0,0 +1,126 @@
import os
from glob import glob
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = "/media/julian/Workdisk/train"
mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [
BaseDatasetConfig(
formatter="mailabs",
meta_file_train=None,
path=path,
language=path.split("/")[-1], # language code is the folder name
)
for path in dataset_paths
]
audio_config = VitsAudioConfig(
sample_rate=16000,
win_length=1024,
hop_length=256,
num_mels=80,
mel_fmin=0,
mel_fmax=None,
)
vitsArgs = VitsArgs(
use_language_embedding=True,
embedded_language_dim=4,
use_speaker_embedding=True,
use_sdp=False,
)
config = VitsConfig(
model_args=vitsArgs,
audio=audio_config,
run_name="vits_vctk",
use_speaker_embedding=True,
batch_size=32,
eval_batch_size=16,
batch_group_size=0,
num_loader_workers=12,
num_eval_loader_workers=12,
precompute_num_workers=12,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="multilingual_cleaners",
use_phonemes=True,
phoneme_language=None,
phonemizer="multi_phonemizer",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
min_audio_len=audio_config.sample_rate,
max_audio_len=audio_config.sample_rate * 10,
output_path=output_path,
datasets=dataset_config,
test_sentences=[
[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"mary_ann",
None,
"en-us",
],
[
"Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
"ezwa",
None,
"fr-fr",
],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de-de"],
["Я думаю, что этот стартап действительно удивительный.", "nikolaev", None, "ru"],
],
)
# force the conversion of the custom characters to a config attribute
config.from_dict(config.to_dict())
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
train_samples, eval_samples = load_tts_samples(
dataset_config,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config)
config.model_args.num_languages = language_manager.num_languages
# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# config is updated with the default characters if not defined in the config.
tokenizer, config = TTSTokenizer.init_from_config(config)
# init model
model = Vits(config, ap, tokenizer, speaker_manager, language_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()


@@ -2,6 +2,7 @@ import unittest
from distutils.version import LooseVersion
from TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
EXAMPLE_TEXTs = [
"Recent research at Harvard has shown meditating",
@@ -226,3 +227,46 @@ class TestZH_CN_Phonemizer(unittest.TestCase):
def test_is_available(self):
self.assertTrue(self.phonemizer.is_available())
class TestMultiPhonemizer(unittest.TestCase):
def setUp(self):
self.phonemizer = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""})
def test_phonemize(self):
# English espeak
text = "Be a voice, not an! echo?"
gt = "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ?"
output = self.phonemizer.phonemize(text, separator="|", language="en-us")
output = output.replace("|", "")
self.assertEqual(output, gt)
# German gruut
text = "Hallo, das ist ein Deutches Beipiel!"
gt = "haloː, das ɪst aeːn dɔɔʏ̯tçəs bəʔiːpiːl!"
output = self.phonemizer.phonemize(text, separator="|", language="de")
output = output.replace("|", "")
self.assertEqual(output, gt)
def test_phonemizer_initialization(self):
# test with unsupported language
with self.assertRaises(ValueError):
MultiPhonemizer({"tr": "espeak", "xx": ""})
# test with unsupported phonemizer
with self.assertRaises(ValueError):
MultiPhonemizer({"tr": "espeak", "fr": "xx"})
def test_sub_phonemizers(self):
for lang in self.phonemizer.lang_to_phonemizer_name.keys():
self.assertEqual(lang, self.phonemizer.lang_to_phonemizer[lang].language)
self.assertEqual(
self.phonemizer.lang_to_phonemizer_name[lang], self.phonemizer.lang_to_phonemizer[lang].name()
)
def test_name(self):
self.assertEqual(self.phonemizer.name(), "multi-phonemizer")
def test_get_supported_languages(self):
self.assertIsInstance(self.phonemizer.supported_languages(), list)