Fix find unique phonemes script (#1928)

* Fix find unique phonemes script

* Fix unit tests
This commit is contained in:
Edresson Casanova 2022-09-08 05:17:35 -03:00 committed by GitHub
parent 3b7dff568a
commit 159eeeef64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 14 deletions

View File

@@ -7,30 +7,25 @@ from tqdm.contrib.concurrent import process_map
from TTS.config import load_config from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut from TTS.tts.utils.text.phonemizers import Gruut
phonemizer = Gruut(language="en-us")
def compute_phonemes(item): def compute_phonemes(item):
try: text = item["text"]
text = item[0] ph = phonemizer.phonemize(text).replace("|", "")
ph = phonemizer.phonemize(text).split("|") return set(list(ph))
except:
return []
return list(set(ph))
def main(): def main():
# pylint: disable=W0601 # pylint: disable=W0601
global c global c, phonemizer
# pylint: disable=bad-option-value # pylint: disable=bad-option-value
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n""" description="""Find all the unique characters or phonemes in a dataset.\n\n"""
""" """
Example runs: Example runs:
python TTS/bin/find_unique_chars.py --config_path config.json python TTS/bin/find_unique_phonemes.py --config_path config.json
""", """,
formatter_class=RawTextHelpFormatter, formatter_class=RawTextHelpFormatter,
) )
@@ -46,15 +41,24 @@ def main():
items = train_items + eval_items items = train_items + eval_items
print("Num items:", len(items)) print("Num items:", len(items))
is_lang_def = all(item["language"] for item in items) language_list = [item["language"] for item in items]
is_lang_def = all(language_list)
if not c.phoneme_language or not is_lang_def: if not c.phoneme_language or not is_lang_def:
raise ValueError("Phoneme language must be defined in config.") raise ValueError("Phoneme language must be defined in config.")
if not language_list.count(language_list[0]) == len(language_list):
raise ValueError(
"Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!"
)
phonemizer = Gruut(language=language_list[0], keep_puncs=True)
phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
phones = [] phones = []
for ph in phonemes: for ph in phonemes:
phones.extend(ph) phones.extend(ph)
phones = set(phones) phones = set(phones)
lower_phones = filter(lambda c: c.islower(), phones) lower_phones = filter(lambda c: c.islower(), phones)
phones_force_lower = [c.lower() for c in phones] phones_force_lower = [c.lower() for c in phones]

View File

@@ -19,6 +19,7 @@ dataset_config_en = BaseDatasetConfig(
language="en", language="en",
) )
"""
dataset_config_pt = BaseDatasetConfig( dataset_config_pt = BaseDatasetConfig(
name="ljspeech", name="ljspeech",
meta_file_train="metadata.csv", meta_file_train="metadata.csv",
@@ -26,6 +27,7 @@ dataset_config_pt = BaseDatasetConfig(
path="tests/data/ljspeech", path="tests/data/ljspeech",
language="pt-br", language="pt-br",
) )
"""
# pylint: disable=protected-access # pylint: disable=protected-access
class TestFindUniquePhonemes(unittest.TestCase): class TestFindUniquePhonemes(unittest.TestCase):
@@ -46,7 +48,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
epochs=1, epochs=1,
print_step=1, print_step=1,
print_eval=True, print_eval=True,
datasets=[dataset_config_en, dataset_config_pt], datasets=[dataset_config_en],
) )
config.save_json(config_path) config.save_json(config_path)
@@ -70,7 +72,7 @@ class TestFindUniquePhonemes(unittest.TestCase):
epochs=1, epochs=1,
print_step=1, print_step=1,
print_eval=True, print_eval=True,
datasets=[dataset_config_en, dataset_config_pt], datasets=[dataset_config_en],
) )
config.save_json(config_path) config.save_json(config_path)