Reviewer notes (free text before the first "diff --git" header; not part of any hunk).

What this patch does, as shown by the hunks below:
  * TTS/bin/find_unique_phonemes.py
      - Imports Gruut from TTS.tts.utils.text.phonemizers and drops the
        module-level `Gruut(language="en-us")` instance.
      - compute_phonemes() reads item["text"], strips the "|" separators and
        returns the set of remaining characters; the bare `except` that
        returned [] is removed, so phonemizer errors now propagate.
      - main() declares `phonemizer` global, raises ValueError when the
        samples carry more than one language, and builds the phonemizer from
        the first sample's language with keep_puncs=True.
      - The example command in the help text now names this script.
  * tests/aux_tests/test_find_unique_phonemes.py
      - The pt-br dataset config is disabled by wrapping it in a string
        literal, and both test configs use only the English dataset.

NOTE(review): the patch had been collapsed onto two physical lines. Line
breaks, blank context lines and indentation were reconstructed so that every
hunk matches its declared old/new line counts; confirm whitespace against the
target tree (or apply with --ignore-whitespace) before relying on it.

diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py
index 0ae74bd4..4bd7a78e 100644
--- a/TTS/bin/find_unique_phonemes.py
+++ b/TTS/bin/find_unique_phonemes.py
@@ -7,30 +7,25 @@ from tqdm.contrib.concurrent import process_map
 
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
-
-phonemizer = Gruut(language="en-us")
+from TTS.tts.utils.text.phonemizers import Gruut
 
 
 def compute_phonemes(item):
-    try:
-        text = item[0]
-        ph = phonemizer.phonemize(text).split("|")
-    except:
-        return []
-    return list(set(ph))
+    text = item["text"]
+    ph = phonemizer.phonemize(text).replace("|", "")
+    return set(list(ph))
 
 
 def main():
     # pylint: disable=W0601
-    global c
+    global c, phonemizer
     # pylint: disable=bad-option-value
     parser = argparse.ArgumentParser(
         description="""Find all the unique characters or phonemes in a dataset.\n\n"""
         """
     Example runs:
 
-    python TTS/bin/find_unique_chars.py --config_path config.json
+    python TTS/bin/find_unique_phonemes.py --config_path config.json
     """,
         formatter_class=RawTextHelpFormatter,
     )
@@ -46,15 +41,24 @@ def main():
     items = train_items + eval_items
     print("Num items:", len(items))
 
-    is_lang_def = all(item["language"] for item in items)
+    language_list = [item["language"] for item in items]
+    is_lang_def = all(language_list)
 
     if not c.phoneme_language or not is_lang_def:
         raise ValueError("Phoneme language must be defined in config.")
 
+    if not language_list.count(language_list[0]) == len(language_list):
+        raise ValueError(
+            "Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!"
+        )
+
+    phonemizer = Gruut(language=language_list[0], keep_puncs=True)
+
     phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
     phones = []
     for ph in phonemes:
         phones.extend(ph)
+
     phones = set(phones)
     lower_phones = filter(lambda c: c.islower(), phones)
     phones_force_lower = [c.lower() for c in phones]
diff --git a/tests/aux_tests/test_find_unique_phonemes.py b/tests/aux_tests/test_find_unique_phonemes.py
index fa740ba3..e9f8e2e0 100644
--- a/tests/aux_tests/test_find_unique_phonemes.py
+++ b/tests/aux_tests/test_find_unique_phonemes.py
@@ -19,6 +19,7 @@ dataset_config_en = BaseDatasetConfig(
     language="en",
 )
 
+"""
 dataset_config_pt = BaseDatasetConfig(
     name="ljspeech",
     meta_file_train="metadata.csv",
@@ -26,6 +27,7 @@ dataset_config_pt = BaseDatasetConfig(
     path="tests/data/ljspeech",
     language="pt-br",
 )
+"""
 
 # pylint: disable=protected-access
 class TestFindUniquePhonemes(unittest.TestCase):
@@ -46,7 +48,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
             epochs=1,
             print_step=1,
             print_eval=True,
-            datasets=[dataset_config_en, dataset_config_pt],
+            datasets=[dataset_config_en],
         )
         config.save_json(config_path)
 
@@ -70,7 +72,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
             epochs=1,
             print_step=1,
             print_eval=True,
-            datasets=[dataset_config_en, dataset_config_pt],
+            datasets=[dataset_config_en],
         )
         config.save_json(config_path)
 