Fixes small compat. issues

This commit is contained in:
Eren Gölge 2022-01-07 15:38:08 +00:00
parent 131bc0cfc0
commit 5176ae9e53
10 changed files with 44 additions and 25 deletions

View File

@ -111,8 +111,8 @@ def load_tts_samples(
meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train
# load attention masks for the duration predictor training
if d.meta_file_attn_mask:
meta_data = dict(load_attention_mask_meta_data(d["meta_file_attn_mask"]))
if dataset.meta_file_attn_mask:
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins["audio_file"]].strip()
meta_data_train_all[idx].update({"alignment_file": attn_file})

View File

@ -13,7 +13,7 @@ from TTS.utils.audio import AudioProcessor
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
torch.multiprocessing.set_sharing_strategy('file_system')
torch.multiprocessing.set_sharing_strategy("file_system")
def _parse_sample(item):

View File

@ -291,7 +291,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic2", ignored_speakers=None):
"""https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"""
file_ext = 'flac'
file_ext = "flac"
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)

View File

@ -261,7 +261,7 @@ class BaseTTS(BaseModel):
speaker_id_mapping = None
d_vector_mapping = None
# setup custom symbols if needed
# setup multi-lingual attributes
if hasattr(self, "language_manager"):
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
@ -290,6 +290,7 @@ class BaseTTS(BaseModel):
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer,
language_id_mapping=language_id_mapping,
)
# wait all the DDP process to be ready
@ -303,6 +304,7 @@ class BaseTTS(BaseModel):
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# Weighted samplers
# TODO: make this DDP amenable
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
@ -313,10 +315,10 @@ class BaseTTS(BaseModel):
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_language_weighted_sampler(dataset.items)
sampler = get_language_weighted_sampler(dataset.samples)
elif getattr(config, "use_speaker_weighted_sampler", False):
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)
sampler = get_speaker_weighted_sampler(dataset.samples)
loader = DataLoader(
dataset,

View File

@ -98,6 +98,15 @@ class LanguageManager:
"""
self._save_json(file_path, self.language_id_mapping)
@staticmethod
def init_from_config(config: Coqpit) -> "LanguageManager":
"""Initialize the language manager from a Coqpit config.
Args:
config (Coqpit): Coqpit config.
"""
return LanguageManager(config=config)
def _set_file_path(path):
"""Find the language_ids.json under the given path or the above it.

View File

@ -9,7 +9,7 @@ import torch
from coqpit import Coqpit
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.config import load_config
from TTS.config import get_from_config_or_model_args_with_default, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_speaker_encoder_model
from TTS.utils.audio import AudioProcessor
@ -331,19 +331,27 @@ class SpeakerManager:
SpeakerEncoder: Speaker encoder object.
"""
speaker_manager = None
if hasattr(config, "use_speaker_embedding") and config.use_speaker_embedding:
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False):
if samples:
speaker_manager = SpeakerManager(data_items=samples)
if config.get("speaker_file", None):
speaker_manager = SpeakerManager(speaker_id_file_path=config.speaker_file)
if config.get("speakers_file", None):
speaker_manager = SpeakerManager(speaker_id_file_path=config.speakers_file)
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
speaker_manager = SpeakerManager(
speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
)
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
speaker_manager = SpeakerManager(
speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None)
)
if hasattr(config, "use_d_vector_file") and config.use_d_vector_file:
if config.get("speakers_file", None):
speaker_manager = SpeakerManager(d_vectors_file_path=config.speaker_file)
if config.get("d_vector_file", None):
speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file)
if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False):
if get_from_config_or_model_args_with_default(config, "speakers_file", None):
speaker_manager = SpeakerManager(
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None)
)
if get_from_config_or_model_args_with_default(config, "d_vector_file", None):
speaker_manager = SpeakerManager(
d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None)
)
return speaker_manager

View File

@ -90,4 +90,3 @@ trainer = Trainer(
# AND... 3,2,1... 🚀
trainer.fit()

View File

@ -87,4 +87,4 @@ trainer = Trainer(
)
# AND... 3,2,1... 🚀
trainer.fit()
trainer.fit()

View File

@ -87,4 +87,4 @@ trainer = Trainer(
)
# AND... 3,2,1... 🚀
trainer.fit()
trainer.fit()

View File

@ -109,10 +109,11 @@ class TestEspeakNgPhonemizer(unittest.TestCase):
class TestGruutPhonemizer(unittest.TestCase):
def setUp(self):
self.phonemizer = Gruut(language="en-us", use_espeak_phonemes=True, keep_stress=False)
self.EXPECTED_PHONEMES = ["ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ",
"f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ",
"ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l",
"f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!"
self.EXPECTED_PHONEMES = [
"ɹ|i|ː|s|ə|n|t| ɹ|ᵻ|s|ɜ|ː|t|ʃ| æ|ɾ| h|ɑ|ː|ɹ|v|ɚ|d| h|ɐ|z| ʃ|o|ʊ|n| m|ɛ|d|ᵻ|t|e|ɪ|ɾ|ɪ",
"f|ɔ|ː|ɹ| æ|z| l|ɪ|ɾ|ə|l| æ|z| e|ɪ|t| w|i|ː|k|s| k|æ|ŋ| æ|k|t|ʃ|u|ː|ə|l|i| ɪ|ŋ|k|ɹ|i|ː|s, ð|ə| ɡ|ɹ|e|ɪ| m|æ|ɾ|ɚ",
"ɪ|n| ð|ə| p|ɑ|ː|ɹ|t|s| ʌ|v| ð|ə| b|ɹ|e|ɪ|n| ɹ|ᵻ|s|p|ɑ|ː|n|s|ᵻ|b|ə|l",
"f|ɔ|ː|ɹ| ɪ|m|o|ʊ|ʃ|ə|n|ə|l| ɹ|ɛ|ɡ|j|ʊ|l|e|ɪ|ʃ|ə|n| æ|n|d| l|ɜ|ː|n|ɪ|ŋ!",
]
def test_phonemize(self):