From 1d752b7f48796f1b503c2e8bd99d09702d64a31a Mon Sep 17 00:00:00 2001
From: WeberJulian
Date: Wed, 4 Oct 2023 09:36:20 -0300
Subject: [PATCH] Add inference_mode

---
 TTS/tts/models/xtts.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 5b1bc3fd..130d0113 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -363,6 +363,7 @@ class Xtts(BaseTTS):
     def device(self):
         return next(self.parameters()).device
 
+    @torch.inference_mode()
     def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
 
@@ -377,6 +378,7 @@ class Xtts(BaseTTS):
         cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
         return cond_latent.transpose(1, 2)
 
+    @torch.inference_mode()
     def get_diffusion_cond_latents(
         self,
         audio_path,
@@ -399,6 +401,7 @@ class Xtts(BaseTTS):
         diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
         return diffusion_latent
 
+    @torch.inference_mode()
     def get_speaker_embedding(
         self,
         audio_path
@@ -468,7 +471,7 @@ class Xtts(BaseTTS):
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         return self.full_inference(text, ref_audio_path, language, **settings)
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def full_inference(
         self,
         text,
@@ -569,7 +572,7 @@ class Xtts(BaseTTS):
             **hf_generate_kwargs,
         )
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def inference(
         self,
         text,
@@ -675,6 +678,7 @@ class Xtts(BaseTTS):
             wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap
 
+    @torch.inference_mode()
     def inference_stream(
         self,
         text,