diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 208a76d5..98beb97a 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1727,7 +1727,8 @@ class Vits(BaseTTS): attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) - + + pred_avg_pitch_emb = None if self.args.use_pitch: _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g, pitch_transform=pitch_transform) x = x + pred_avg_pitch_emb