From f01502a9dbe65eaa5d6b9f796e9673bd08ff929d Mon Sep 17 00:00:00 2001 From: Edresson Date: Tue, 27 Oct 2020 16:30:16 -0300 Subject: [PATCH] bug fix in glowTTS sythesize --- TTS/bin/synthesize.py | 10 +++++++--- TTS/tts/utils/generic_utils.py | 9 ++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index bb257548..64993754 100644 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -10,7 +10,7 @@ import time import torch -from TTS.tts.utils.generic_utils import setup_model +from TTS.tts.utils.generic_utils import setup_model, is_tacotron from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor @@ -125,7 +125,8 @@ if __name__ == "__main__": model.eval() if args.use_cuda: model.cuda() - model.decoder.set_r(cp['r']) + if is_tacotron(C): + model.decoder.set_r(cp['r']) # load vocoder model if args.vocoder_path != "": @@ -153,7 +154,10 @@ if __name__ == "__main__": args.speaker_fileid = None if args.gst_style is None: - gst_style = C.gst['gst_style_input'] + if is_tacotron(C): + gst_style = C.gst['gst_style_input'] + else: + gst_style = None else: # check if gst_style string is a dict, if is dict convert else use string try: diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 2361fa85..6f7949b2 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -28,7 +28,6 @@ def split_dataset(items): return items_eval, items return items[:eval_split_size], items[eval_split_size:] - # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 def sequence_mask(sequence_length, max_len=None): if max_len is None: @@ -50,7 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower()) MyModel = getattr(MyModel, to_camel(c.model)) if c.model.lower() in "tacotron": - model = MyModel(num_chars=num_chars, + model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), num_speakers=num_speakers, r=c.r, postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), @@ -77,7 +76,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): ddc_r=c.ddc_r, speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "tacotron2": - model = MyModel(num_chars=num_chars, + model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), num_speakers=num_speakers, r=c.r, postnet_output_dim=c.audio['num_mels'], @@ -103,7 +102,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): ddc_r=c.ddc_r, speaker_embedding_dim=speaker_embedding_dim) elif c.model.lower() == "glow_tts": - model = MyModel(num_chars=num_chars, + model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), hidden_channels=192, filter_channels=768, filter_channels_dp=256, @@ -131,7 +130,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): return model def is_tacotron(c): - return False if c['model'] == 'glow_tts' else True + return False if 'glow_tts' in c['model'] else True def check_config_tts(c): check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)