mirror of https://github.com/coqui-ai/TTS.git
adapt align_tts and model name handling
This commit is contained in:
parent
aa29f5b199
commit
844e8e0ed4
|
@ -41,7 +41,9 @@ def sequence_mask(sequence_length, max_len=None):
|
||||||
|
|
||||||
def to_camel(text):
|
def to_camel(text):
|
||||||
text = text.capitalize()
|
text = text.capitalize()
|
||||||
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
|
||||||
|
text = text.replace('Tts', 'TTS')
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
|
@ -132,13 +134,23 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
decoder_type=c['decoder_type'],
|
decoder_type=c['decoder_type'],
|
||||||
decoder_params=c['decoder_params'],
|
decoder_params=c['decoder_params'],
|
||||||
c_in_channels=0)
|
c_in_channels=0)
|
||||||
|
elif c.model.lower() == "align_tts":
|
||||||
|
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||||
|
out_channels=c.audio['num_mels'],
|
||||||
|
hidden_channels=c['hidden_channels'],
|
||||||
|
positional_encoding=c['positional_encoding'],
|
||||||
|
encoder_type=c['encoder_type'],
|
||||||
|
encoder_params=c['encoder_params'],
|
||||||
|
decoder_type=c['decoder_type'],
|
||||||
|
decoder_params=c['decoder_params'],
|
||||||
|
c_in_channels=0)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def is_tacotron(c):
|
def is_tacotron(c):
|
||||||
return not c['model'] in ['speedy_speech', 'glow_tts']
|
return 'tacotron' in c['model'].lower()
|
||||||
|
|
||||||
def check_config_tts(c):
|
def check_config_tts(c):
|
||||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str)
|
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
|
||||||
check_argument('run_name', c, restricted=True, val_type=str)
|
check_argument('run_name', c, restricted=True, val_type=str)
|
||||||
check_argument('run_description', c, val_type=str)
|
check_argument('run_description', c, val_type=str)
|
||||||
|
|
||||||
|
@ -195,7 +207,7 @@ def check_config_tts(c):
|
||||||
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||||
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||||
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
|
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||||
if c['model'].lower == "speedy_speech":
|
if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||||
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||||
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
|
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||||
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
|
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
|
||||||
|
@ -239,7 +251,7 @@ def check_config_tts(c):
|
||||||
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
|
||||||
|
|
||||||
# Model Parameters for non-tacotron models
|
# Model Parameters for non-tacotron models
|
||||||
if c['model'].lower == "speedy_speech":
|
if c['model'].lower in ["speedy_speech", "align_tts"]:
|
||||||
check_argument('positional_encoding', c, restricted=True, val_type=type)
|
check_argument('positional_encoding', c, restricted=True, val_type=type)
|
||||||
check_argument('encoder_type', c, restricted=True, val_type=str)
|
check_argument('encoder_type', c, restricted=True, val_type=str)
|
||||||
check_argument('encoder_params', c, restricted=True, val_type=dict)
|
check_argument('encoder_params', c, restricted=True, val_type=dict)
|
||||||
|
@ -289,4 +301,4 @@ def check_config_tts(c):
|
||||||
check_argument('name', dataset_entry, restricted=True, val_type=str)
|
check_argument('name', dataset_entry, restricted=True, val_type=str)
|
||||||
check_argument('path', dataset_entry, restricted=True, val_type=str)
|
check_argument('path', dataset_entry, restricted=True, val_type=str)
|
||||||
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
|
||||||
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
|
Loading…
Reference in New Issue