mirror of https://github.com/coqui-ai/TTS.git

style fixes

commit 9853bf59e1 (parent 0542d18a88)
@@ -252,12 +252,7 @@ class BaseTacotron(BaseTTS):
     def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
         """Capacitron Variational Autoencoder"""
-        (
-            VAE_outputs,
-            posterior_distribution,
-            prior_distribution,
-            capacitron_beta,
-        ) = self.capacitron_vae_layer(
+        (VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
             reference_mel_info,
             text_info,
             speaker_embedding,  # pylint: disable=not-callable
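Note: the reformatting collapses the four-name tuple unpacking onto the call line; behavior is identical. For readers unfamiliar with the layer, here is a minimal, hypothetical sketch of a VAE-style reference encoder that returns the same four values (the real CapacitronVAE in the repo is considerably more involved; all names and shapes below are illustrative):

    import torch
    import torch.nn as nn
    from torch.distributions.normal import Normal

    class ToyCapacitronVAE(nn.Module):
        """Heavily simplified stand-in for the Capacitron VAE layer."""

        def __init__(self, mel_dim=80, latent_dim=32):
            super().__init__()
            self.encoder = nn.Linear(mel_dim, 2 * latent_dim)  # predicts mu and log-variance
            self.beta = nn.Parameter(torch.tensor(1.0))  # learned KL weight ("capacitron beta")

        def forward(self, reference_mel):
            stats = self.encoder(reference_mel.mean(dim=1))  # average over time frames
            mu, log_var = stats.chunk(2, dim=-1)
            posterior = Normal(mu, (0.5 * log_var).exp())
            prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
            vae_outputs = posterior.rsample()  # reparameterized sample used as the embedding
            return vae_outputs, posterior, prior, self.beta

    layer = ToyCapacitronVAE()
    (VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta) = layer(
        torch.randn(4, 100, 80)  # (batch, frames, mel bins)
    )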
@@ -302,4 +297,4 @@ class BaseTacotron(BaseTTS):
         self.decoder.set_r(r)
         if trainer.config.bidirectional_decoder:
             trainer.model.decoder_backward.set_r(r)
         print(f"\n > Number of output frames: {self.decoder.r}")
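Note: `r` is Tacotron's reduction factor: the decoder emits r mel frames per step, and gradual training lowers it as training progresses, hence set_r on both the forward and backward decoders. A minimal sketch of such a schedule, with entirely hypothetical step thresholds (the real schedule comes from the trainer config):

    # Hypothetical gradual-training schedule: (global_step_threshold, r)
    GRADUAL_SCHEDULE = [(0, 7), (10_000, 5), (50_000, 3), (130_000, 2)]

    def r_for_step(step, schedule=GRADUAL_SCHEDULE):
        """Return the reduction factor active at a given global step."""
        r = schedule[0][1]
        for threshold, value in schedule:
            if step >= threshold:
                r = value
        return r

    assert r_for_step(0) == 7
    assert r_for_step(60_000) == 3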
@@ -29,12 +29,12 @@ def pad_or_truncate(t, length):
     """
     Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
     """
+    tp = t[..., :length]
     if t.shape[-1] == length:
-        return t
+        tp = t
     elif t.shape[-1] < length:
-        return F.pad(t, (0, length - t.shape[-1]))
-    else:
-        return t[..., :length]
+        tp = F.pad(t, (0, length - t.shape[-1]))
+    return tp


 def load_discrete_vocoder_diffuser(
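Note: the rewrite converts three early returns into a single exit point: `tp` starts as the truncated view and is overwritten in the equal-length and padding cases. Behavior is unchanged, as this self-contained check shows (a standalone copy of the new function, for illustration):

    import torch
    import torch.nn.functional as F

    def pad_or_truncate(t, length):
        """Force t's last dimension to `length` by truncating or zero-padding."""
        tp = t[..., :length]
        if t.shape[-1] == length:
            tp = t
        elif t.shape[-1] < length:
            tp = F.pad(t, (0, length - t.shape[-1]))
        return tp

    x = torch.ones(2, 5)
    assert pad_or_truncate(x, 3).shape == (2, 3)  # truncated
    assert pad_or_truncate(x, 5) is x             # returned unchanged
    padded = pad_or_truncate(x, 8)
    assert padded.shape == (2, 8) and padded[:, 5:].eq(0).all()  # zero-padded on the right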
@@ -186,19 +186,12 @@ class TextToSpeech:
     def _config(self):
         raise RuntimeError("This is depreciated")
-        return {
-            "high_vram": self.high_vram,
-            "models_dir": self.models_dir,
-            "kv_cache": self.autoregressive.inference_model.kv_cache,
-            "ar_checkpoint": self.ar_checkpoint,
-        }

     def __init__(
         self,
         autoregressive_batch_size=None,
         models_dir=MODELS_DIR,
         enable_redaction=True,
-        device=None,
         high_vram=False,
         kv_cache=True,
         ar_checkpoint=None,
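Note: the deleted `return {...}` was unreachable: nothing after an unconditional `raise` ever executes, which pylint reports as W0101 (unreachable). The "depreciated" in the message is the original source's typo for "deprecated", preserved verbatim above. A short demonstration:

    def _config():
        raise RuntimeError("This is deprecated")
        return {"never": "reached"}  # dead code: pylint W0101, the kind of thing a style pass deletes

    try:
        _config()
    except RuntimeError as exc:
        print(exc)  # always raises; the dict literal is never evaluated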
@@ -351,6 +344,19 @@ class TextToSpeech:
         )
         self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))

+    def deterministic_state(self, seed=None):
+        """
+        Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
+        reproduced.
+        """
+        seed = int(time()) if seed is None else seed
+        torch.manual_seed(seed)
+        random.seed(seed)
+        # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
+        # torch.use_deterministic_algorithms(True)
+
+        return seed
+
     def get_conditioning_latents(
         self,
         voice_samples,
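Note: deterministic_state is moved up here from the bottom of the class (the last hunk below removes the old copy); the body is unchanged. A usage sketch, as a standalone approximation of what the method does: seed once, record the returned seed, then pass it back in to replay a run.

    import random
    from time import time

    import torch

    def deterministic_state(seed=None):
        """Seed torch and random; return the seed so a run can be replayed."""
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    seed = deterministic_state()   # first run: seeded from the clock
    a = torch.rand(3)
    deterministic_state(seed)      # replay: reuse the recorded seed
    b = torch.rand(3)
    assert torch.equal(a, b)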
@@ -424,8 +430,7 @@ class TextToSpeech:
         if return_mels:
             return auto_latent, diffusion_latent, auto_conds, diffusion_conds
-        else:
-            return auto_latent, diffusion_latent
+        return auto_latent, diffusion_latent

     def get_random_conditioning_latents(self):
         # Lazy-load the RLG models.
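Note: a standard no-else-return cleanup (pylint R1705): once the `if` branch returns, the `else` indentation is redundant. The pattern in miniature (names hypothetical):

    def latents(return_mels, auto, diff, conds=None):
        if return_mels:
            return auto, diff, conds
        return auto, diff  # the early return above makes an `else` unnecessary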
@@ -723,13 +728,13 @@ class TextToSpeech:
             # Find the first occurrence of the "calm" token and trim the codes to that.
             ctokens = 0
-            for k in range(codes.shape[-1]):
-                if codes[0, k] == calm_token:
+            for code in range(codes.shape[-1]):
+                if codes[0, code] == calm_token:
                     ctokens += 1
                 else:
                     ctokens = 0
                 if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
-                    latents = latents[:, :k]
+                    latents = latents[:, :code]
                     break
         with self.temporary_cuda(self.diffusion) as diffusion:
             mel = do_spectrogram_diffusion(
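Note: renaming the loop index from `k` to `code` is purely cosmetic; the logic still trims the latents at the point where more than eight consecutive "calm" tokens appear. A toy reproduction of the loop (the token value and tensor shapes are illustrative, not taken from the repo):

    import torch

    calm_token = 83  # illustrative value for the silence/"calm" code
    codes = torch.tensor([[5, 7] + [83] * 12 + [9]])  # toy code sequence, shape (1, 15)
    latents = torch.randn(1, codes.shape[-1], 4)

    ctokens = 0
    for code in range(codes.shape[-1]):
        if codes[0, code] == calm_token:
            ctokens += 1
        else:
            ctokens = 0
        if ctokens > 8:  # leave the diffusion model some room to terminate speech
            latents = latents[:, :code]
            break

    assert latents.shape[1] == 10  # trimmed at the ninth consecutive calm token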
@@ -763,18 +768,4 @@ class TextToSpeech:
                 voice_samples,
                 conditioning_latents,
             )
-        else:
-            return res
-
-    def deterministic_state(self, seed=None):
-        """
-        Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
-        reproduced.
-        """
-        seed = int(time()) if seed is None else seed
-        torch.manual_seed(seed)
-        random.seed(seed)
-        # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
-        # torch.use_deterministic_algorithms(True)
-
-        return seed
+        return res
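Note: two changes land here: another no-else-return cleanup, and removal of the old deterministic_state definition that was relocated earlier in the class. The context lines suggest the caller receives either the result alone or the result plus the state needed to replay it; a hypothetical sketch of that pattern (the parameter and names are assumed, not confirmed by the hunk):

    def generate(text, return_deterministic_state=False, seed=None):
        seed = 1234 if seed is None else seed         # stand-in for deterministic_state()
        res = f"audio for {text!r} with seed {seed}"  # stand-in for synthesized audio
        if return_deterministic_state:
            return res, (seed, text)
        return res  # style fix: no `else` needed after the early return

    audio, state = generate("hello", return_deterministic_state=True)
    assert generate("hello", seed=state[0]) == audio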