style fixes

This commit is contained in:
manmay-nakhashi 2023-04-22 22:16:54 +05:30
parent 0542d18a88
commit 9853bf59e1
2 changed files with 24 additions and 38 deletions

View File

@ -252,12 +252,7 @@ class BaseTacotron(BaseTTS):
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
"""Capacitron Variational Autoencoder"""
(
VAE_outputs,
posterior_distribution,
prior_distribution,
capacitron_beta,
) = self.capacitron_vae_layer(
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
reference_mel_info,
text_info,
speaker_embedding, # pylint: disable=not-callable
@ -302,4 +297,4 @@ class BaseTacotron(BaseTTS):
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")
print(f"\n > Number of output frames: {self.decoder.r}")

View File

@ -29,12 +29,12 @@ def pad_or_truncate(t, length):
"""
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
"""
tp = t[..., :length]
if t.shape[-1] == length:
return t
tp = t
elif t.shape[-1] < length:
return F.pad(t, (0, length - t.shape[-1]))
else:
return t[..., :length]
tp = F.pad(t, (0, length - t.shape[-1]))
return tp
def load_discrete_vocoder_diffuser(
@ -186,19 +186,12 @@ class TextToSpeech:
def _config(self):
raise RuntimeError("This is depreciated")
return {
"high_vram": self.high_vram,
"models_dir": self.models_dir,
"kv_cache": self.autoregressive.inference_model.kv_cache,
"ar_checkpoint": self.ar_checkpoint,
}
def __init__(
self,
autoregressive_batch_size=None,
models_dir=MODELS_DIR,
enable_redaction=True,
device=None,
high_vram=False,
kv_cache=True,
ar_checkpoint=None,
@ -351,6 +344,19 @@ class TextToSpeech:
)
self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
def deterministic_state(self, seed=None):
"""
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
reproduced.
"""
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
# torch.use_deterministic_algorithms(True)
return seed
def get_conditioning_latents(
self,
voice_samples,
@ -424,8 +430,7 @@ class TextToSpeech:
if return_mels:
return auto_latent, diffusion_latent, auto_conds, diffusion_conds
else:
return auto_latent, diffusion_latent
return auto_latent, diffusion_latent
def get_random_conditioning_latents(self):
# Lazy-load the RLG models.
@ -723,13 +728,13 @@ class TextToSpeech:
# Find the first occurrence of the "calm" token and trim the codes to that.
ctokens = 0
for k in range(codes.shape[-1]):
if codes[0, k] == calm_token:
for code in range(codes.shape[-1]):
if codes[0, code] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
latents = latents[:, :code]
break
with self.temporary_cuda(self.diffusion) as diffusion:
mel = do_spectrogram_diffusion(
@ -763,18 +768,4 @@ class TextToSpeech:
voice_samples,
conditioning_latents,
)
else:
return res
def deterministic_state(self, seed=None):
"""
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
reproduced.
"""
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
# torch.use_deterministic_algorithms(True)
return seed
return res