Add inference_mode

WeberJulian 2023-10-04 09:36:20 -03:00
parent 0d36dcfd81
commit 1d752b7f48
1 changed file with 6 additions and 2 deletions


@@ -363,6 +363,7 @@ class Xtts(BaseTTS):
     def device(self):
         return next(self.parameters()).device
 
+    @torch.inference_mode()
     def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
 
@@ -377,6 +378,7 @@ class Xtts(BaseTTS):
         cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
         return cond_latent.transpose(1, 2)
 
+    @torch.inference_mode()
     def get_diffusion_cond_latents(
         self,
         audio_path,
@@ -399,6 +401,7 @@ class Xtts(BaseTTS):
         diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
         return diffusion_latent
 
+    @torch.inference_mode()
     def get_speaker_embedding(
         self,
         audio_path
@@ -468,7 +471,7 @@ class Xtts(BaseTTS):
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         return self.full_inference(text, ref_audio_path, language, **settings)
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def full_inference(
         self,
         text,
@@ -569,7 +572,7 @@ class Xtts(BaseTTS):
             **hf_generate_kwargs,
         )
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def inference(
         self,
         text,
@@ -675,6 +678,7 @@ class Xtts(BaseTTS):
         wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap
 
+    @torch.inference_mode()
     def inference_stream(
         self,
         text,
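
For context on what this commit changes: torch.inference_mode() behaves like torch.no_grad() in that it disables gradient recording, but it additionally skips autograd bookkeeping (version counters, view tracking) and marks outputs as inference tensors, which makes pure inference paths slightly faster. The standalone sketch below is not part of this repository; the function names (embed_no_grad, embed_inference) are illustrative only and assume a recent PyTorch version where inference_mode is available.

# Minimal sketch of the difference between the two decorators.
import torch


@torch.no_grad()
def embed_no_grad(x: torch.Tensor) -> torch.Tensor:
    # Gradients are not recorded, but the result is an ordinary tensor.
    return x * 2


@torch.inference_mode()
def embed_inference(x: torch.Tensor) -> torch.Tensor:
    # Gradients are not recorded, and the result is an "inference tensor"
    # that can never be used in autograd later.
    return x * 2


x = torch.randn(4, requires_grad=True)

y1 = embed_no_grad(x)
y2 = embed_inference(x)

print(y1.requires_grad)   # False
print(y2.requires_grad)   # False
print(y2.is_inference())  # True: created under inference_mode

Because the decorated XTTS methods (conditioning-latent extraction, inference, streaming inference) never need to backpropagate, swapping @torch.no_grad() for @torch.inference_mode() is a drop-in change that only removes autograd overhead.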