mirror of https://github.com/coqui-ai/TTS.git
Add inference_mode
This commit is contained in:
parent
0d36dcfd81
commit
1d752b7f48
|
@ -363,6 +363,7 @@ class Xtts(BaseTTS):
|
||||||
def device(self):
|
def device(self):
|
||||||
return next(self.parameters()).device
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
|
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
|
||||||
"""Compute the conditioning latents for the GPT model from the given audio.
|
"""Compute the conditioning latents for the GPT model from the given audio.
|
||||||
|
|
||||||
|
@ -377,6 +378,7 @@ class Xtts(BaseTTS):
|
||||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
||||||
return cond_latent.transpose(1, 2)
|
return cond_latent.transpose(1, 2)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def get_diffusion_cond_latents(
|
def get_diffusion_cond_latents(
|
||||||
self,
|
self,
|
||||||
audio_path,
|
audio_path,
|
||||||
|
@ -399,6 +401,7 @@ class Xtts(BaseTTS):
|
||||||
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
|
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
|
||||||
return diffusion_latent
|
return diffusion_latent
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def get_speaker_embedding(
|
def get_speaker_embedding(
|
||||||
self,
|
self,
|
||||||
audio_path
|
audio_path
|
||||||
|
@ -468,7 +471,7 @@ class Xtts(BaseTTS):
|
||||||
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
settings.update(kwargs) # allow overriding of preset settings with kwargs
|
||||||
return self.full_inference(text, ref_audio_path, language, **settings)
|
return self.full_inference(text, ref_audio_path, language, **settings)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.inference_mode()
|
||||||
def full_inference(
|
def full_inference(
|
||||||
self,
|
self,
|
||||||
text,
|
text,
|
||||||
|
@ -569,7 +572,7 @@ class Xtts(BaseTTS):
|
||||||
**hf_generate_kwargs,
|
**hf_generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.inference_mode()
|
||||||
def inference(
|
def inference(
|
||||||
self,
|
self,
|
||||||
text,
|
text,
|
||||||
|
@ -675,6 +678,7 @@ class Xtts(BaseTTS):
|
||||||
wav_gen_prev = wav_gen
|
wav_gen_prev = wav_gen
|
||||||
return wav_chunk, wav_gen_prev, wav_overlap
|
return wav_chunk, wav_gen_prev, wav_overlap
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def inference_stream(
|
def inference_stream(
|
||||||
self,
|
self,
|
||||||
text,
|
text,
|
||||||
|
|
Loading…
Reference in New Issue