From 57ef53bef315e05e686c9de5d42d462aa45f4e79 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 21 Dec 2020 12:29:07 +0100 Subject: [PATCH] update argumnet check for non tacotron models --- TTS/tts/utils/generic_utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 741e8e5c..c5f966c6 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -126,15 +126,21 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): hidden_channels_enc=192, hidden_channels_dec=192, use_encoder_prenet=True, - rel_attn_window_size=4, external_speaker_embedding_dim=speaker_embedding_dim) + elif c.model.lower() == "speedy_speech": + model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), + out_channels=c.audio['num_mels'], + hidden_channels=128, + encoder_type=c['encoder_type'], + decoder_residual_conv_bn_params=c['decoder_residual_conv_bn_params'], + c_in_channels=0) return model def is_tacotron(c): - return False if 'glow_tts' in c['model'] else True + return False if c['model'] in ['speedy_speech', 'glow_tts'] else True def check_config_tts(c): - check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str) + check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech'], restricted=True, val_type=str) check_argument('run_name', c, restricted=True, val_type=str) check_argument('run_description', c, val_type=str)