mirror of https://github.com/coqui-ai/TTS.git
adapt align_tts and model name handling
This commit is contained in:
parent
aa29f5b199
commit
844e8e0ed4
|
@ -41,7 +41,9 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||
text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||
text = text.replace('Tts', 'TTS')
|
||||
return text
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||
|
@ -132,13 +134,23 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
decoder_type=c['decoder_type'],
|
||||
decoder_params=c['decoder_params'],
|
||||
c_in_channels=0)
|
||||
elif c.model.lower() == "align_tts":
|
||||
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
out_channels=c.audio['num_mels'],
|
||||
hidden_channels=c['hidden_channels'],
|
||||
positional_encoding=c['positional_encoding'],
|
||||
encoder_type=c['encoder_type'],
|
||||
encoder_params=c['encoder_params'],
|
||||
decoder_type=c['decoder_type'],
|
||||
decoder_params=c['decoder_params'],
|
||||
c_in_channels=0)
|
||||
return model
|
||||
|
||||
def is_tacotron(c):
|
||||
return not c['model'] in ['speedy_speech', 'glow_tts']
|
||||
return 'tacotron' in c['model'].lower()
|
||||
|
||||
def check_config_tts(c):
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str)
|
||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
|
||||
check_argument('run_name', c, restricted=True, val_type=str)
|
||||
check_argument('run_description', c, val_type=str)
|
||||
|
||||
|
@ -195,7 +207,7 @@ def check_config_tts(c):
|
|||
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
if c['model'].lower == "speedy_speech":
|
||||
if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||
|
@ -239,7 +251,7 @@ def check_config_tts(c):
|
|||
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||
|
||||
# Model Parameters for non-tacotron models
|
||||
if c['model'].lower == "speedy_speech":
|
||||
if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||
check_argument('positional_encoding', c, restricted=True, val_type=type)
|
||||
check_argument('encoder_type', c, restricted=True, val_type=str)
|
||||
check_argument('encoder_params', c, restricted=True, val_type=dict)
|
||||
|
@ -289,4 +301,4 @@ def check_config_tts(c):
|
|||
check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
||||
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
Loading…
Reference in New Issue