mirror of https://github.com/coqui-ai/TTS.git
Add inference_mode
This commit is contained in:
parent
0d36dcfd81
commit
1d752b7f48
|
@ -363,6 +363,7 @@ class Xtts(BaseTTS):
|
|||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
|
||||
"""Compute the conditioning latents for the GPT model from the given audio.
|
||||
|
||||
|
@ -377,6 +378,7 @@ class Xtts(BaseTTS):
|
|||
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_diffusion_cond_latents(
|
||||
self,
|
||||
audio_path,
|
||||
|
@ -399,6 +401,7 @@ class Xtts(BaseTTS):
|
|||
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
|
||||
return diffusion_latent
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_speaker_embedding(
|
||||
self,
|
||||
audio_path
|
||||
|
@ -468,7 +471,7 @@ class Xtts(BaseTTS):
|
|||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def full_inference(
|
||||
self,
|
||||
text,
|
||||
|
@ -569,7 +572,7 @@ class Xtts(BaseTTS):
|
|||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text,
|
||||
|
@ -675,6 +678,7 @@ class Xtts(BaseTTS):
|
|||
wav_gen_prev = wav_gen
|
||||
return wav_chunk, wav_gen_prev, wav_overlap
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_stream(
|
||||
self,
|
||||
text,
|
||||
|
|
Loading…
Reference in New Issue