diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 7807efc1..613e4eae 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1050,7 +1050,15 @@ class Vits(BaseTTS): o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) - outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p} + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "y_mask": y_mask, + } return outputs @torch.no_grad()