style fixes

This commit is contained in:
manmay-nakhashi 2023-04-22 23:35:18 +05:30
parent 9853bf59e1
commit 9575e33470
1 changed file with 21 additions and 24 deletions

View File

@ -37,6 +37,20 @@ def pad_or_truncate(t, length):
return tp return tp
def deterministic_state(seed=None):
    """
    Seed tortoise's random number generators so results can be reproduced.

    If no seed is supplied, the current time() (truncated to an int) is used.
    The seed actually applied is returned so callers can log or reuse it.
    """
    if seed is None:
        seed = int(time())
    torch.manual_seed(seed)
    random.seed(seed)
    # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
    # torch.use_deterministic_algorithms(True)
    return seed
def load_discrete_vocoder_diffuser( def load_discrete_vocoder_diffuser(
trained_diffusion_steps=4000, trained_diffusion_steps=4000,
desired_diffusion_steps=200, desired_diffusion_steps=200,
@ -93,7 +107,6 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
"try breaking up your input text." "try breaking up your input text."
) )
return codes return codes
else:
codes[stop_token_indices] = 83 codes[stop_token_indices] = 83
stm = stop_token_indices.min().item() stm = stop_token_indices.min().item()
codes[stm:] = 83 codes[stm:] = 83
@ -101,7 +114,6 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
codes[-3] = 45 codes[-3] = 45
codes[-2] = 45 codes[-2] = 45
codes[-1] = 248 codes[-1] = 248
return codes return codes
@ -170,13 +182,14 @@ def pick_best_batch_size_for_gpu():
if torch.cuda.is_available(): if torch.cuda.is_available():
_, available = torch.cuda.mem_get_info() _, available = torch.cuda.mem_get_info()
availableGb = available / (1024**3) availableGb = available / (1024**3)
batch_size = 1
if availableGb > 14: if availableGb > 14:
return 16 batch_size = 16
elif availableGb > 10: elif availableGb > 10:
return 8 batch_size = 8
elif availableGb > 7: elif availableGb > 7:
return 4 batch_size = 4
return 1 return batch_size
class TextToSpeech: class TextToSpeech:
@ -184,9 +197,6 @@ class TextToSpeech:
Main entry point into Tortoise. Main entry point into Tortoise.
""" """
def _config(self):
raise RuntimeError("This is depreciated")
def __init__( def __init__(
self, self,
autoregressive_batch_size=None, autoregressive_batch_size=None,
@ -344,19 +354,6 @@ class TextToSpeech:
) )
self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir))) self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
def deterministic_state(self, seed=None):
    """
    Seed tortoise's random number generators for reproducible results.

    When *seed* is None, the current time() (as an int) is used instead.
    Returns the seed that was applied.
    """
    if seed is None:
        seed = int(time())
    torch.manual_seed(seed)
    random.seed(seed)
    # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
    # torch.use_deterministic_algorithms(True)
    return seed
def get_conditioning_latents( def get_conditioning_latents(
self, self,
voice_samples, voice_samples,
@ -590,7 +587,7 @@ class TextToSpeech:
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz. Sample rate is 24kHz.
""" """
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) deterministic_seed = deterministic_state(seed=use_deterministic_seed)
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.