diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 3bcd59a1..1c623f50 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -982,6 +982,7 @@ class Vits(BaseTTS): return aux_input["x_lengths"] return torch.tensor(x.shape[1:2]).to(x.device) + @torch.no_grad() def inference( self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} ): # pylint: disable=dangerous-default-value