Multilingual tokenizer (#2229)

* Implement multilingual tokenizer

* Add multi_phonemizer recipe

* Fix lint

* Add TestMultiPhonemizer

* Fix lint

* make style
This commit is contained in:
Julian Weber 2023-01-02 10:03:19 +01:00 committed by GitHub
parent f814d52394
commit a07397733b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 230 additions and 31 deletions

View File

@ -212,6 +212,9 @@ class BaseDatasetConfig(Coqpit):
language (str): language (str):
Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`. Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.
phonemizer (str):
Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`.
meta_file_val (str): meta_file_val (str):
Name of the dataset meta file that defines the instances used at validation. Name of the dataset meta file that defines the instances used at validation.
@ -226,6 +229,7 @@ class BaseDatasetConfig(Coqpit):
meta_file_train: str = "" meta_file_train: str = ""
ignored_speakers: List[str] = None ignored_speakers: List[str] = None
language: str = "" language: str = ""
phonemizer: str = ""
meta_file_val: str = "" meta_file_val: str = ""
meta_file_attn_mask: str = "" meta_file_attn_mask: str = ""

View File

@ -569,14 +569,14 @@ class PhonemeDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
item = self.samples[index] item = self.samples[index]
ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"]) ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"])
ph_hat = self.tokenizer.ids_to_text(ids) ph_hat = self.tokenizer.ids_to_text(ids)
return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)}
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def compute_or_load(self, file_name, text): def compute_or_load(self, file_name, text, language):
"""Compute phonemes for the given text. """Compute phonemes for the given text.
If the phonemes are already cached, load them from cache. If the phonemes are already cached, load them from cache.
@ -586,7 +586,7 @@ class PhonemeDataset(Dataset):
try: try:
ids = np.load(cache_path) ids = np.load(cache_path)
except FileNotFoundError: except FileNotFoundError:
ids = self.tokenizer.text_to_ids(text) ids = self.tokenizer.text_to_ids(text, language=language)
np.save(cache_path, ids) np.save(cache_path, ids)
return ids return ids

View File

@ -175,9 +175,15 @@ def synthesis(
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = style_mel.transpose(1, 2) # [1, time, depth] style_mel = style_mel.transpose(1, 2) # [1, time, depth]
language_name = None
if language_id is not None:
language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id]
assert len(language) == 1, "language_id must be a valid language"
language_name = language[0]
# convert text to sequence of token IDs # convert text to sequence of token IDs
text_inputs = np.asarray( text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id), model.tokenizer.text_to_ids(text, language=language_name),
dtype=np.int32, dtype=np.int32,
) )
# pass tensors to backend # pass tensors to backend

View File

@ -114,7 +114,7 @@ class BasePhonemizer(abc.ABC):
return self._punctuator.restore(phonemized, punctuations)[0] return self._punctuator.restore(phonemized, punctuations)[0]
return phonemized[0] return phonemized[0]
def phonemize(self, text: str, separator="|") -> str: def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument
"""Returns the `text` phonemized for the given language """Returns the `text` phonemized for the given language
Args: Args:

View File

@ -43,7 +43,7 @@ class JA_JP_Phonemizer(BasePhonemizer):
return separator.join(ph) return separator.join(ph)
return ph return ph
def phonemize(self, text: str, separator="|") -> str: def phonemize(self, text: str, separator="|", language=None) -> str:
"""Custom phonemize for JP_JA """Custom phonemize for JP_JA
Skip pre-post processing steps used by the other phonemizers. Skip pre-post processing steps used by the other phonemizers.

View File

@ -40,7 +40,7 @@ class KO_KR_Phonemizer(BasePhonemizer):
return separator.join(ph) return separator.join(ph)
return ph return ph
def phonemize(self, text: str, separator: str = "", character: str = "hangeul") -> str: def phonemize(self, text: str, separator: str = "", character: str = "hangeul", language=None) -> str:
return self._phonemize(text, separator, character) return self._phonemize(text, separator, character)
@staticmethod @staticmethod

View File

@ -14,30 +14,40 @@ class MultiPhonemizer:
TODO: find a way to pass custom kwargs to the phonemizers TODO: find a way to pass custom kwargs to the phonemizers
""" """
lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER lang_to_phonemizer = {}
language = "multi-lingual"
def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disable=dangerous-default-value
self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer) for k, v in lang_to_phonemizer_name.items():
if v == "" and k in DEF_LANG_TO_PHONEMIZER.keys():
lang_to_phonemizer_name[k] = DEF_LANG_TO_PHONEMIZER[k]
elif v == "":
raise ValueError(f"Phonemizer wasn't set for language {k} and doesn't have a default.")
self.lang_to_phonemizer_name = lang_to_phonemizer_name
self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name)
@staticmethod @staticmethod
def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict: def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict:
lang_to_phonemizer = {} lang_to_phonemizer = {}
for k, v in lang_to_phonemizer_name.items(): for k, v in lang_to_phonemizer_name.items():
phonemizer = get_phonemizer_by_name(v, language=k) lang_to_phonemizer[k] = get_phonemizer_by_name(v, language=k)
lang_to_phonemizer[k] = phonemizer
return lang_to_phonemizer return lang_to_phonemizer
@staticmethod @staticmethod
def name(): def name():
return "multi-phonemizer" return "multi-phonemizer"
def phonemize(self, text, language, separator="|"): def phonemize(self, text, separator="|", language=""):
if language == "":
raise ValueError("Language must be set for multi-phonemizer to phonemize.")
return self.lang_to_phonemizer[language].phonemize(text, separator) return self.lang_to_phonemizer[language].phonemize(text, separator)
def supported_languages(self) -> List: def supported_languages(self) -> List:
return list(self.lang_to_phonemizer_name.keys()) return list(self.lang_to_phonemizer.keys())
def print_logs(self, level: int = 0):
indent = "\t" * level
print(f"{indent}| > phoneme language: {self.supported_languages()}")
print(f"{indent}| > phoneme backend: {self.name()}")
# if __name__ == "__main__": # if __name__ == "__main__":
@ -48,7 +58,7 @@ class MultiPhonemizer:
# "zh-cn": "这是中国的例子", # "zh-cn": "这是中国的例子",
# } # }
# phonemes = {} # phonemes = {}
# ph = MultiPhonemizer() # ph = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""})
# for lang, text in texts.items(): # for lang, text in texts.items():
# phoneme = ph.phonemize(text, lang) # phoneme = ph.phonemize(text, lang)
# phonemes[lang] = phoneme # phonemes[lang] = phoneme

View File

@ -3,6 +3,7 @@ from typing import Callable, Dict, List, Union
from TTS.tts.utils.text import cleaners from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
from TTS.utils.generic_utils import get_import_path, import_class from TTS.utils.generic_utils import get_import_path, import_class
@ -106,7 +107,7 @@ class TTSTokenizer:
if self.text_cleaner is not None: if self.text_cleaner is not None:
text = self.text_cleaner(text) text = self.text_cleaner(text)
if self.use_phonemes: if self.use_phonemes:
text = self.phonemizer.phonemize(text, separator="") text = self.phonemizer.phonemize(text, separator="", language=language)
if self.add_blank: if self.add_blank:
text = self.intersperse_blank_char(text, True) text = self.intersperse_blank_char(text, True)
if self.use_eos_bos: if self.use_eos_bos:
@ -182,21 +183,29 @@ class TTSTokenizer:
# init phonemizer # init phonemizer
phonemizer = None phonemizer = None
if config.use_phonemes: if config.use_phonemes:
phonemizer_kwargs = {"language": config.phoneme_language} if "phonemizer" in config and config.phonemizer == "multi_phonemizer":
lang_to_phonemizer_name = {}
if "phonemizer" in config and config.phonemizer: for dataset in config.datasets:
phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) if dataset.language != "":
lang_to_phonemizer_name[dataset.language] = dataset.phonemizer
else:
raise ValueError("Multi phonemizer requires language to be set for each dataset.")
phonemizer = MultiPhonemizer(lang_to_phonemizer_name)
else: else:
try: phonemizer_kwargs = {"language": config.phoneme_language}
phonemizer = get_phonemizer_by_name( if "phonemizer" in config and config.phonemizer:
DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
) else:
new_config.phonemizer = phonemizer.name() try:
except KeyError as e: phonemizer = get_phonemizer_by_name(
raise ValueError( DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
f"""No phonemizer found for language {config.phoneme_language}. )
You may need to install a third party library for this language.""" new_config.phonemizer = phonemizer.name()
) from e except KeyError as e:
raise ValueError(
f"""No phonemizer found for language {config.phoneme_language}.
You may need to install a third party library for this language."""
) from e
return ( return (
TTSTokenizer( TTSTokenizer(

View File

@ -0,0 +1,126 @@
import os
from glob import glob
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
# Recipe: train a multi-speaker, multilingual VITS model with the
# "multi_phonemizer" backend, building one dataset config per language folder.

# NOTE(review): machine-specific paths - adjust them to your environment.
output_path = "/media/julian/Workdisk/train"
mailabs_path = "/home/julian/workspace/mailabs/**"

# Without `recursive=True`, `**` behaves like `*` and therefore also matches
# plain files. Only directories are datasets, so drop everything else; sort to
# make the dataset order independent of the file system's listing order.
dataset_paths = sorted(path for path in glob(mailabs_path) if os.path.isdir(path))

dataset_config = [
    BaseDatasetConfig(
        formatter="mailabs",
        meta_file_train=None,
        path=path,
        # language code is the folder name; normpath/basename keeps this correct
        # with a trailing separator and with non-POSIX path separators
        language=os.path.basename(os.path.normpath(path)),
    )
    for path in dataset_paths
]

audio_config = VitsAudioConfig(
    sample_rate=16000,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    mel_fmin=0,
    mel_fmax=None,
)

vitsArgs = VitsArgs(
    use_language_embedding=True,
    embedded_language_dim=4,
    use_speaker_embedding=True,
    use_sdp=False,
)

config = VitsConfig(
    model_args=vitsArgs,
    audio=audio_config,
    # NOTE(review): the run name mentions VCTK although the datasets above use
    # the "mailabs" formatter - confirm whether it should be renamed.
    run_name="vits_vctk",
    use_speaker_embedding=True,
    batch_size=32,
    eval_batch_size=16,
    batch_group_size=0,
    num_loader_workers=12,
    num_eval_loader_workers=12,
    precompute_num_workers=12,
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="multilingual_cleaners",
    use_phonemes=True,
    # no single phoneme language: the multi-phonemizer takes the language of
    # each dataset instead
    phoneme_language=None,
    phonemizer="multi_phonemizer",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    compute_input_seq_cache=True,
    print_step=25,
    use_language_weighted_sampler=True,
    print_eval=False,
    mixed_precision=False,
    # keep clips between 1 and 10 seconds (expressed in samples)
    min_audio_len=audio_config.sample_rate,
    max_audio_len=audio_config.sample_rate * 10,
    output_path=output_path,
    datasets=dataset_config,
    # each entry: [text, speaker name, None, language code]
    test_sentences=[
        [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "mary_ann",
            None,
            "en-us",
        ],
        [
            "Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
            "ezwa",
            None,
            "fr-fr",
        ],
        ["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de-de"],
        ["Я думаю, что этот стартап действительно удивительный.", "nikolaev", None, "ru"],
    ],
)

# force the conversion of the custom characters to a config attribute
config.from_dict(config.to_dict())

# init audio processor
ap = AudioProcessor(**config.audio.to_dict())

# load training samples
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)

# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers

# init language manager from the config and record the number of languages
language_manager = LanguageManager(config=config)
config.model_args.num_languages = language_manager.num_languages

# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# config is updated with the default characters if not defined in the config.
tokenizer, config = TTSTokenizer.init_from_config(config)

# init model
model = Vits(config, ap, tokenizer, speaker_manager, language_manager)

# init the trainer and 🚀
trainer = Trainer(
    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@ -2,6 +2,7 @@ import unittest
from distutils.version import LooseVersion from distutils.version import LooseVersion
from TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer from TTS.tts.utils.text.phonemizers import ESpeak, Gruut, JA_JP_Phonemizer, ZH_CN_Phonemizer
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
EXAMPLE_TEXTs = [ EXAMPLE_TEXTs = [
"Recent research at Harvard has shown meditating", "Recent research at Harvard has shown meditating",
@ -226,3 +227,46 @@ class TestZH_CN_Phonemizer(unittest.TestCase):
def test_is_available(self): def test_is_available(self):
self.assertTrue(self.phonemizer.is_available()) self.assertTrue(self.phonemizer.is_available())
class TestMultiPhonemizer(unittest.TestCase):
    """Tests for `MultiPhonemizer`, constructed from a language -> phonemizer-name mapping."""

    def setUp(self):
        # "en-us" and "zh-cn" are given an empty phonemizer name on purpose.
        lang_to_backend = {"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""}
        self.phonemizer = MultiPhonemizer(lang_to_backend)

    def test_phonemize(self):
        # (language, input text, expected phonemes once the separator is removed)
        cases = (
            # English espeak
            ("en-us", "Be a voice, not an! echo?", "biː ɐ vˈɔɪs, nˈɑːt æn! ˈɛkoʊ?"),
            # German gruut
            ("de", "Hallo, das ist ein Deutches Beipiel!", "haloː, das ɪst aeːn dɔɔʏ̯tçəs bəʔiːpiːl!"),
        )
        for language, text, expected in cases:
            phonemes = self.phonemizer.phonemize(text, separator="|", language=language)
            phonemes = phonemes.replace("|", "")
            self.assertEqual(phonemes, expected)

    def test_phonemizer_initialization(self):
        invalid_mappings = (
            # unsupported language
            {"tr": "espeak", "xx": ""},
            # unsupported phonemizer
            {"tr": "espeak", "fr": "xx"},
        )
        for mapping in invalid_mappings:
            with self.assertRaises(ValueError):
                MultiPhonemizer(mapping)

    def test_sub_phonemizers(self):
        # every sub-phonemizer must report the language and name it is registered under
        backends = self.phonemizer.lang_to_phonemizer
        for lang, backend_name in self.phonemizer.lang_to_phonemizer_name.items():
            sub_phonemizer = backends[lang]
            self.assertEqual(lang, sub_phonemizer.language)
            self.assertEqual(backend_name, sub_phonemizer.name())

    def test_name(self):
        expected_name = "multi-phonemizer"
        self.assertEqual(self.phonemizer.name(), expected_name)

    def test_get_supported_languages(self):
        languages = self.phonemizer.supported_languages()
        self.assertIsInstance(languages, list)