diff --git a/TTS/trainer.py b/TTS/trainer.py index 34d73874..d81132cf 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -351,6 +351,7 @@ class TrainerTTS: speaker_ids = None # compute durations from attention masks + durations = None if attn_mask is not None: durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) for idx, am in enumerate(attn_mask):