style fixes

manmay-nakhashi 2023-04-22 23:35:18 +05:30
parent 9853bf59e1
commit 9575e33470
1 changed file with 21 additions and 24 deletions


@@ -37,6 +37,20 @@ def pad_or_truncate(t, length):
     return tp
 
 
+def deterministic_state(seed=None):
+    """
+    Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
+    reproduced.
+    """
+    seed = int(time()) if seed is None else seed
+    torch.manual_seed(seed)
+    random.seed(seed)
+    # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
+    # torch.use_deterministic_algorithms(True)
+    return seed
+
+
 def load_discrete_vocoder_diffuser(
     trained_diffusion_steps=4000,
     desired_diffusion_steps=200,
@@ -93,15 +107,13 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
                 "try breaking up your input text."
             )
         return codes
-    else:
-        codes[stop_token_indices] = 83
+    codes[stop_token_indices] = 83
     stm = stop_token_indices.min().item()
     codes[stm:] = 83
     if stm - 3 < codes.shape[0]:
         codes[-3] = 45
         codes[-2] = 45
         codes[-1] = 248
     return codes
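
This hunk only drops an else that followed an early return; the stop-token clean-up itself is unchanged. As a quick illustration, here is a minimal sketch on a toy tensor (the values 83, 45 and 248 come from the hunk above; the sample sequence and stop-token value used here are made up):

import torch

# Toy code sequence; 83 stands in for the autoregressive stop token.
codes = torch.tensor([12, 7, 83, 99, 99, 99, 99, 99])
stop_token_indices = (codes == 83).nonzero()

codes[stop_token_indices] = 83
stm = stop_token_indices.min().item()
codes[stm:] = 83                 # everything from the first stop token onward becomes a stop token
if stm - 3 < codes.shape[0]:
    codes[-3] = 45               # fixed tail values, as in the hunk above
    codes[-2] = 45
    codes[-1] = 248

print(codes)  # tensor([ 12,   7,  83,  83,  83,  45,  45, 248])
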
@@ -170,13 +182,14 @@ def pick_best_batch_size_for_gpu():
     if torch.cuda.is_available():
         _, available = torch.cuda.mem_get_info()
         availableGb = available / (1024**3)
+        batch_size = 1
         if availableGb > 14:
-            return 16
+            batch_size = 16
         elif availableGb > 10:
-            return 8
+            batch_size = 8
         elif availableGb > 7:
-            return 4
-    return 1
+            batch_size = 4
+    return batch_size
 
 
 class TextToSpeech:
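
The hunk above replaces the early returns with a single batch_size variable and one exit point, without changing the thresholds. A minimal standalone sketch of that mapping (the estimate_batch_size helper name is made up here; only the thresholds and the mem_get_info call come from the hunk):

import torch

def estimate_batch_size(available_gb: float) -> int:
    # Same thresholds as in the hunk above.
    batch_size = 1
    if available_gb > 14:
        batch_size = 16
    elif available_gb > 10:
        batch_size = 8
    elif available_gb > 7:
        batch_size = 4
    return batch_size

if torch.cuda.is_available():
    # torch.cuda.mem_get_info() returns (free_bytes, total_bytes); the hunk keeps the second value.
    _, available = torch.cuda.mem_get_info()
    print(estimate_batch_size(available / (1024 ** 3)))
else:
    print(estimate_batch_size(0.0))  # no GPU: stays at batch size 1
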
@@ -184,9 +197,6 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """
 
-    def _config(self):
-        raise RuntimeError("This is depreciated")
-
     def __init__(
         self,
         autoregressive_batch_size=None,
@@ -344,19 +354,6 @@ class TextToSpeech:
         )
         self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
 
-    def deterministic_state(self, seed=None):
-        """
-        Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
-        reproduced.
-        """
-        seed = int(time()) if seed is None else seed
-        torch.manual_seed(seed)
-        random.seed(seed)
-        # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
-        # torch.use_deterministic_algorithms(True)
-        return seed
-
     def get_conditioning_latents(
         self,
         voice_samples,
@@ -590,7 +587,7 @@ class TextToSpeech:
         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
                  Sample rate is 24kHz.
         """
-        deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
+        deterministic_seed = deterministic_state(seed=use_deterministic_seed)
 
         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
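
The call site now uses the module-level deterministic_state from the first hunk instead of the removed method. A small sketch (function body copied from that hunk, minus the commented-out line; the reseeding demo is illustrative and not part of this diff) showing why returning the seed is useful:

import random
from time import time

import torch

def deterministic_state(seed=None):
    # Module-level version from the first hunk: seeds torch and random,
    # and returns the seed so a run can be reproduced later.
    seed = int(time()) if seed is None else seed
    torch.manual_seed(seed)
    random.seed(seed)
    return seed

# Reproduce a run by reusing the returned seed.
first_seed = deterministic_state()      # seeds from the current time
a = torch.rand(3)
deterministic_state(seed=first_seed)    # reseed with the same value
b = torch.rand(3)
assert torch.equal(a, b)
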