adapt align_tts and model name handling

This commit is contained in:
Eren Gölge 2021-03-03 15:43:05 +01:00 committed by Eren Gölge
parent aa29f5b199
commit 844e8e0ed4
1 changed files with 18 additions and 6 deletions

View File

@ -41,7 +41,9 @@ def sequence_mask(sequence_length, max_len=None):
def to_camel(text):
text = text.capitalize()
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
text = text.replace('Tts', 'TTS')
return text
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
@ -132,13 +134,23 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
decoder_type=c['decoder_type'],
decoder_params=c['decoder_params'],
c_in_channels=0)
elif c.model.lower() == "align_tts":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'],
hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'],
encoder_type=c['encoder_type'],
encoder_params=c['encoder_params'],
decoder_type=c['decoder_type'],
decoder_params=c['decoder_params'],
c_in_channels=0)
return model
def is_tacotron(c):
return not c['model'] in ['speedy_speech', 'glow_tts']
return 'tacotron' in c['model'].lower()
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str)
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
check_argument('run_name', c, restricted=True, val_type=str)
check_argument('run_description', c, val_type=str)
@ -195,7 +207,7 @@ def check_config_tts(c):
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
if c['model'].lower == "speedy_speech":
if c['model'].lower in ["speedy_speech", "align_tts"]:
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
@ -239,7 +251,7 @@ def check_config_tts(c):
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
# Model Parameters for non-tacotron models
if c['model'].lower == "speedy_speech":
if c['model'].lower in ["speedy_speech", "align_tts"]:
check_argument('positional_encoding', c, restricted=True, val_type=type)
check_argument('encoder_type', c, restricted=True, val_type=str)
check_argument('encoder_params', c, restricted=True, val_type=dict)
@ -289,4 +301,4 @@ def check_config_tts(c):
check_argument('name', dataset_entry, restricted=True, val_type=str)
check_argument('path', dataset_entry, restricted=True, val_type=str)
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)