From 844e8e0ed44f70030ba687ab94321497ecae58c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 3 Mar 2021 15:43:05 +0100 Subject: [PATCH] adapt align_tts and model name handling --- TTS/tts/utils/generic_utils.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 0d236fbc..c6a9c7ec 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -41,7 +41,9 @@ def sequence_mask(sequence_length, max_len=None): def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + text = text.replace('Tts', 'TTS') + return text def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): @@ -132,13 +134,23 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): decoder_type=c['decoder_type'], decoder_params=c['decoder_params'], c_in_channels=0) + elif c.model.lower() == "align_tts": + model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), + out_channels=c.audio['num_mels'], + hidden_channels=c['hidden_channels'], + positional_encoding=c['positional_encoding'], + encoder_type=c['encoder_type'], + encoder_params=c['encoder_params'], + decoder_type=c['decoder_type'], + decoder_params=c['decoder_params'], + c_in_channels=0) return model def is_tacotron(c): - return not c['model'] in ['speedy_speech', 'glow_tts'] + return 'tacotron' in c['model'].lower() def check_config_tts(c): - check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str) + check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) check_argument('run_description', c, val_type=str) @@ -195,7 +207,7 @@ def check_config_tts(c): check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) - if c['model'].lower == "speedy_speech": + if c['model'].lower in ["speedy_speech", "align_tts"]: check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) @@ -239,7 +251,7 @@ def check_config_tts(c): check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) # Model Parameters for non-tacotron models - if c['model'].lower == "speedy_speech": + if c['model'].lower in ["speedy_speech", "align_tts"]: check_argument('positional_encoding', c, restricted=True, val_type=type) check_argument('encoder_type', c, restricted=True, val_type=str) check_argument('encoder_params', c, restricted=True, val_type=dict) @@ -289,4 +301,4 @@ def check_config_tts(c): check_argument('name', dataset_entry, restricted=True, val_type=str) check_argument('path', dataset_entry, restricted=True, val_type=str) check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) + check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) \ No newline at end of file