From 5176ae9e53cae7c71a2f19f240bde8752d4b038e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 7 Jan 2022 15:38:08 +0000 Subject: [PATCH] Fixes small compat. issues --- TTS/tts/datasets/__init__.py | 4 +-- TTS/tts/datasets/dataset.py | 2 +- TTS/tts/datasets/formatters.py | 2 +- TTS/tts/models/base_tts.py | 8 +++-- TTS/tts/utils/languages.py | 9 ++++++ TTS/tts/utils/speakers.py | 30 ++++++++++++------- recipes/vctk/fast_pitch/train_fast_pitch.py | 1 - recipes/vctk/fast_speech/train_fast_speech.py | 2 +- recipes/vctk/glow_tts/train_glow_tts.py | 2 +- tests/text_tests/test_phonemizer.py | 9 +++--- 10 files changed, 44 insertions(+), 25 deletions(-) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index f0a6ea95..d80e92c9 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -111,8 +111,8 @@ def load_tts_samples( meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for the duration predictor training - if d.meta_file_attn_mask: - meta_data = dict(load_attention_mask_meta_data(d["meta_file_attn_mask"])) + if dataset.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): attn_file = meta_data[ins["audio_file"]].strip() meta_data_train_all[idx].update({"alignment_file": attn_file}) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index dee719ef..a98afc95 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -13,7 +13,7 @@ from TTS.utils.audio import AudioProcessor # to prevent too many open files error as suggested here # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 -torch.multiprocessing.set_sharing_strategy('file_system') +torch.multiprocessing.set_sharing_strategy("file_system") def _parse_sample(item): diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 546c3cc3..5168dd06 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -291,7 +291,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None): def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None): """https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip""" - file_ext = 'flac' + file_ext = "flac" test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 59862322..9a6a56df 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -261,7 +261,7 @@ class BaseTTS(BaseModel): speaker_id_mapping = None d_vector_mapping = None - # setup custom symbols if needed + # setup multi-lingual attributes if hasattr(self, "language_manager"): language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None @@ -290,6 +290,7 @@ class BaseTTS(BaseModel): speaker_id_mapping=speaker_id_mapping, d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, tokenizer=self.tokenizer, + language_id_mapping=language_id_mapping, ) # wait all the DDP process to be ready @@ -303,6 +304,7 @@ class BaseTTS(BaseModel): sampler = DistributedSampler(dataset) if num_gpus > 1 else None # Weighted samplers + # TODO: make this DDP amenable assert not ( num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False) ), "language_weighted_sampler is not supported with DistributedSampler" @@ -313,10 +315,10 @@ class BaseTTS(BaseModel): if sampler is None: if getattr(config, "use_language_weighted_sampler", False): print(" > Using Language weighted sampler") - sampler = get_language_weighted_sampler(dataset.items) + sampler = get_language_weighted_sampler(dataset.samples) elif getattr(config, "use_speaker_weighted_sampler", False): print(" > Using Language weighted sampler") - sampler = get_speaker_weighted_sampler(dataset.items) + sampler = get_speaker_weighted_sampler(dataset.samples) loader = DataLoader( dataset, diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py index a4f41be5..78b535a0 100644 --- a/TTS/tts/utils/languages.py +++ b/TTS/tts/utils/languages.py @@ -98,6 +98,15 @@ class LanguageManager: """ self._save_json(file_path, self.language_id_mapping) + @staticmethod + def init_from_config(config: Coqpit) -> "LanguageManager": + """Initialize the language manager from a Coqpit config. + + Args: + config (Coqpit): Coqpit config. + """ + return LanguageManager(config=config) + def _set_file_path(path): """Find the language_ids.json under the given path or the above it. diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index ba48f27c..99d653e6 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -9,7 +9,7 @@ import torch from coqpit import Coqpit from torch.utils.data.sampler import WeightedRandomSampler -from TTS.config import load_config +from TTS.config import get_from_config_or_model_args_with_default, load_config from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model from TTS.utils.audio import AudioProcessor @@ -331,19 +331,27 @@ class SpeakerManager: SpeakerEncoder: Speaker encoder object. """ speaker_manager = None - if hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding: + if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False): if samples: speaker_manager = SpeakerManager(data_items=samples) - if config.get("speaker_file", None): - speaker_manager = SpeakerManager(speaker_id_file_path=config.speaker_file) - if config.get("speakers_file", None): - speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file) + if get_from_config_or_model_args_with_default(config, "speaker_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None) + ) - if hasattr(config, "use_d_vector_file") and config.use_d_vector_file: - if config.get("speakers_file", None): - speaker_manager = SpeakerManager(d_vectors_file_path=config.speaker_file) - if config.get("d_vector_file", None): - speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) + if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False): + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "d_vector_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None) + ) return speaker_manager diff --git a/recipes/vctk/fast_pitch/train_fast_pitch.py b/recipes/vctk/fast_pitch/train_fast_pitch.py index f7a2ef06..4d9cc10d 100644 --- a/recipes/vctk/fast_pitch/train_fast_pitch.py +++ b/recipes/vctk/fast_pitch/train_fast_pitch.py @@ -90,4 +90,3 @@ trainer = Trainer( # AND... 3,2,1... 🚀 trainer.fit() - diff --git a/recipes/vctk/fast_speech/train_fast_speech.py b/recipes/vctk/fast_speech/train_fast_speech.py index 853bbb54..1dcab982 100644 --- a/recipes/vctk/fast_speech/train_fast_speech.py +++ b/recipes/vctk/fast_speech/train_fast_speech.py @@ -87,4 +87,4 @@ trainer = Trainer( ) # AND... 3,2,1... 🚀 -trainer.fit() \ No newline at end of file +trainer.fit() diff --git a/recipes/vctk/glow_tts/train_glow_tts.py b/recipes/vctk/glow_tts/train_glow_tts.py index 30050ef5..e35e552d 100644 --- a/recipes/vctk/glow_tts/train_glow_tts.py +++ b/recipes/vctk/glow_tts/train_glow_tts.py @@ -87,4 +87,4 @@ trainer = Trainer( ) # AND... 3,2,1... 🚀 -trainer.fit() \ No newline at end of file +trainer.fit() diff --git a/tests/text_tests/test_phonemizer.py b/tests/text_tests/test_phonemizer.py index 512cc195..9b619f6e 100644 --- a/tests/text_tests/test_phonemizer.py +++ b/tests/text_tests/test_phonemizer.py @@ -109,10 +109,11 @@ class TestEspeakNgPhonemizer(unittest.TestCase): class TestGruutPhonemizer(unittest.TestCase): def setUp(self): self.phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False) - self.EXPECTED_PHONEMES = ["ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ", - "f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ", - "ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l", - "f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!" + self.EXPECTED_PHONEMES = [ + "ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ|ŋ", + "f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ", + "ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l", + "f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!", ] def test_phonemize(self):