From d0c27a9661f5b1b19d207391e13edc5ea742dda1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sun, 20 Feb 2022 11:50:42 +0100 Subject: [PATCH] Update synthesis.py --- TTS/tts/utils/synthesis.py | 7 ++++--- TTS/utils/training.py | 28 ---------------------------- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index e2d9c113..377f32de 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -193,14 +193,15 @@ def synthesis( # convert outputs to numpy # plot results wav = None - if hasattr(model, "END2END") and model.END2END: - wav = model_outputs.squeeze(0) - else: + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] if use_griffin_lim: wav = inv_spectrogram(model_outputs, model.ap, CONFIG) # trim silence if do_trim_silence: wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs return_dict = { "wav": wav, "alignments": alignments, diff --git a/TTS/utils/training.py b/TTS/utils/training.py index e69fb2b4..b51f55e9 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -42,31 +42,3 @@ def gradual_training_scheduler(global_step, config): if global_step * num_gpus >= values[0]: new_values = values return new_values[1], new_values[2] - - -def lr_decay(init_lr, global_step, warmup_steps): - r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py - It is only being used by the Speaker Encoder trainer.""" - warmup_steps = float(warmup_steps) - step = global_step + 1.0 - lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5) - return lr - - -# pylint: disable=dangerous-default-value -def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}): - """ - Skip biases, BatchNorm parameters, rnns. - and attention projection layer v - """ - decay = [] - no_decay = [] - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - - if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)): - no_decay.append(param) - else: - decay.append(param) - return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]