mirror of https://github.com/coqui-ai/TTS.git
style fixes
This commit is contained in:
parent
9853bf59e1
commit
9575e33470
|
@ -37,6 +37,20 @@ def pad_or_truncate(t, length):
|
|||
return tp
|
||||
|
||||
|
||||
def deterministic_state(seed=None):
|
||||
"""
|
||||
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
|
||||
reproduced.
|
||||
"""
|
||||
seed = int(time()) if seed is None else seed
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
def load_discrete_vocoder_diffuser(
|
||||
trained_diffusion_steps=4000,
|
||||
desired_diffusion_steps=200,
|
||||
|
@ -93,15 +107,13 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
|
|||
"try breaking up your input text."
|
||||
)
|
||||
return codes
|
||||
else:
|
||||
codes[stop_token_indices] = 83
|
||||
codes[stop_token_indices] = 83
|
||||
stm = stop_token_indices.min().item()
|
||||
codes[stm:] = 83
|
||||
if stm - 3 < codes.shape[0]:
|
||||
codes[-3] = 45
|
||||
codes[-2] = 45
|
||||
codes[-1] = 248
|
||||
|
||||
return codes
|
||||
|
||||
|
||||
|
@ -170,13 +182,14 @@ def pick_best_batch_size_for_gpu():
|
|||
if torch.cuda.is_available():
|
||||
_, available = torch.cuda.mem_get_info()
|
||||
availableGb = available / (1024**3)
|
||||
batch_size = 1
|
||||
if availableGb > 14:
|
||||
return 16
|
||||
batch_size = 16
|
||||
elif availableGb > 10:
|
||||
return 8
|
||||
batch_size = 8
|
||||
elif availableGb > 7:
|
||||
return 4
|
||||
return 1
|
||||
batch_size = 4
|
||||
return batch_size
|
||||
|
||||
|
||||
class TextToSpeech:
|
||||
|
@ -184,9 +197,6 @@ class TextToSpeech:
|
|||
Main entry point into Tortoise.
|
||||
"""
|
||||
|
||||
def _config(self):
|
||||
raise RuntimeError("This is depreciated")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
autoregressive_batch_size=None,
|
||||
|
@ -344,19 +354,6 @@ class TextToSpeech:
|
|||
)
|
||||
self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
|
||||
|
||||
def deterministic_state(self, seed=None):
|
||||
"""
|
||||
Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
|
||||
reproduced.
|
||||
"""
|
||||
seed = int(time()) if seed is None else seed
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
# Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
|
||||
return seed
|
||||
|
||||
def get_conditioning_latents(
|
||||
self,
|
||||
voice_samples,
|
||||
|
@ -590,7 +587,7 @@ class TextToSpeech:
|
|||
:return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
|
||||
deterministic_seed = deterministic_state(seed=use_deterministic_seed)
|
||||
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
|
||||
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
|
||||
|
|
Loading…
Reference in New Issue