mirror of https://github.com/coqui-ai/TTS.git
style fixes
This commit is contained in:
parent
0542d18a88
commit
9853bf59e1
|
@ -252,12 +252,7 @@ class BaseTacotron(BaseTTS):
|
|||
|
||||
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
|
||||
"""Capacitron Variational Autoencoder"""
|
||||
(
|
||||
VAE_outputs,
|
||||
posterior_distribution,
|
||||
prior_distribution,
|
||||
capacitron_beta,
|
||||
) = self.capacitron_vae_layer(
|
||||
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
|
||||
reference_mel_info,
|
||||
text_info,
|
||||
speaker_embedding, # pylint: disable=not-callable
|
||||
|
@ -302,4 +297,4 @@ class BaseTacotron(BaseTTS):
|
|||
self.decoder.set_r(r)
|
||||
if trainer.config.bidirectional_decoder:
|
||||
trainer.model.decoder_backward.set_r(r)
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
||||
|
|
|
@ -29,12 +29,12 @@ def pad_or_truncate(t, length):
|
|||
"""
|
||||
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
|
||||
"""
|
||||
tp = t[..., :length]
|
||||
if t.shape[-1] == length:
|
||||
return t
|
||||
tp = t
|
||||
elif t.shape[-1] < length:
|
||||
return F.pad(t, (0, length - t.shape[-1]))
|
||||
else:
|
||||
return t[..., :length]
|
||||
tp = F.pad(t, (0, length - t.shape[-1]))
|
||||
return tp
|
||||
|
||||
|
||||
def load_discrete_vocoder_diffuser(
|
||||
|
@ -186,19 +186,12 @@ class TextToSpeech:
|
|||
|
||||
def _config(self):
|
||||
raise RuntimeError("This is depreciated")
|
||||
return {
|
||||
"high_vram": self.high_vram,
|
||||
"models_dir": self.models_dir,
|
||||
"kv_cache": self.autoregressive.inference_model.kv_cache,
|
||||
"ar_checkpoint": self.ar_checkpoint,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
autoregressive_batch_size=None,
|
||||
models_dir=MODELS_DIR,
|
||||
enable_redaction=True,
|
||||
device=None,
|
||||
high_vram=False,
|
||||
kv_cache=True,
|
||||
ar_checkpoint=None,
|
||||
|
@ -351,6 +344,19 @@ class TextToSpeech:
|
|||
)
|
||||
self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
|
||||
|
||||
def deterministic_state(self, seed=None):
|
||||
"""
|
||||
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
|
||||
reproduced.
|
||||
"""
|
||||
seed = int(time()) if seed is None else seed
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
|
||||
return seed
|
||||
|
||||
def get_conditioning_latents(
|
||||
self,
|
||||
voice_samples,
|
||||
|
@ -424,8 +430,7 @@ class TextToSpeech:
|
|||
|
||||
if return_mels:
|
||||
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
|
||||
else:
|
||||
return auto_latent, diffusion_latent
|
||||
return auto_latent, diffusion_latent
|
||||
|
||||
def get_random_conditioning_latents(self):
|
||||
# Lazy-load the RLG models.
|
||||
|
@ -723,13 +728,13 @@ class TextToSpeech:
|
|||
|
||||
# Find the first occurrence of the "calm" token and trim the codes to that.
|
||||
ctokens = 0
|
||||
for k in range(codes.shape[-1]):
|
||||
if codes[0, k] == calm_token:
|
||||
for code in range(codes.shape[-1]):
|
||||
if codes[0, code] == calm_token:
|
||||
ctokens += 1
|
||||
else:
|
||||
ctokens = 0
|
||||
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
|
||||
latents = latents[:, :k]
|
||||
latents = latents[:, :code]
|
||||
break
|
||||
with self.temporary_cuda(self.diffusion) as diffusion:
|
||||
mel = do_spectrogram_diffusion(
|
||||
|
@ -763,18 +768,4 @@ class TextToSpeech:
|
|||
voice_samples,
|
||||
conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return res
|
||||
|
||||
def deterministic_state(self, seed=None):
|
||||
"""
|
||||
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
|
||||
reproduced.
|
||||
"""
|
||||
seed = int(time()) if seed is None else seed
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
|
||||
return seed
|
||||
return res
|
||||
|
|
Loading…
Reference in New Issue