diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 2e488e5c..182e58fb 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -321,7 +321,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) ############################## # generator pass - y_hat = model_G(c_G) + y_hat = model_G(c_G)[:, :, :y_G.size(2)] y_hat_sub = None y_G_sub = None @@ -373,7 +373,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) if global_step >= c.steps_to_start_discriminator: # discriminator pass with torch.no_grad(): - y_hat = model_G(c_G) + y_hat = model_G(c_G)[:, :, :y_G.size(2)] # PQMF formatting if y_hat.shape[1] > 1: