mirror of https://github.com/coqui-ai/TTS.git

style fixes

This commit is contained in:
parent 9853bf59e1
commit 9575e33470
@@ -37,6 +37,20 @@ def pad_or_truncate(t, length):
     return tp


+def deterministic_state(seed=None):
+    """
+    Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
+    reproduced.
+    """
+    seed = int(time()) if seed is None else seed
+    torch.manual_seed(seed)
+    random.seed(seed)
+    # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
+    # torch.use_deterministic_algorithms(True)
+
+    return seed
+
+
 def load_discrete_vocoder_diffuser(
     trained_diffusion_steps=4000,
     desired_diffusion_steps=200,
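The new module-level `deterministic_state` seeds both torch and Python's `random` module and hands the chosen seed back, so a run can be replayed later. A minimal standalone sketch of that round trip (reimplemented here rather than imported, and omitting the CUBLAS-related determinism flag the comment defers):

```python
import random
from time import time

import torch


def deterministic_state(seed=None):
    # Mirror of the helper added above: fall back to the current time as the seed.
    seed = int(time()) if seed is None else seed
    torch.manual_seed(seed)
    random.seed(seed)
    return seed


# First run: keep whichever seed was picked.
seed = deterministic_state()
first_draw = torch.rand(3)

# Replay: feeding the recorded seed back in reproduces the same draw.
deterministic_state(seed=seed)
assert torch.equal(first_draw, torch.rand(3))
```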
@@ -93,15 +107,13 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
                 "try breaking up your input text."
             )
         return codes
-    else:
-        codes[stop_token_indices] = 83
-
+    codes[stop_token_indices] = 83
     stm = stop_token_indices.min().item()
     codes[stm:] = 83
     if stm - 3 < codes.shape[0]:
         codes[-3] = 45
         codes[-2] = 45
         codes[-1] = 248

     return codes

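The dedent is safe because the `if` branch above it ends in `return codes`, so the assignment already only ran when stop tokens existed; dropping the `else:` just flattens the happy path. A toy trace of the trimming logic it feeds into (the stop-token id here is made up for illustration; 83, 45 and 248 are the filler codes the source itself uses):

```python
import torch

stop_token = 8193  # hypothetical stop-token id, only for this demo
codes = torch.tensor([12, 7, 99, stop_token, 5, stop_token])

stop_token_indices = (codes == stop_token).nonzero()
codes[stop_token_indices] = 83          # overwrite every stop token
stm = stop_token_indices.min().item()
codes[stm:] = 83                        # silence everything after the first stop
if stm - 3 < codes.shape[0]:
    codes[-3] = 45                      # fixed trailing codes from the source
    codes[-2] = 45
    codes[-1] = 248

print(codes)  # tensor([ 12,   7,  99,  45,  45, 248])
```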
@@ -170,13 +182,14 @@ def pick_best_batch_size_for_gpu():
     if torch.cuda.is_available():
         _, available = torch.cuda.mem_get_info()
         availableGb = available / (1024**3)
+        batch_size = 1
         if availableGb > 14:
-            return 16
+            batch_size = 16
         elif availableGb > 10:
-            return 8
+            batch_size = 8
         elif availableGb > 7:
-            return 4
-    return 1
+            batch_size = 4
+    return batch_size


 class TextToSpeech:
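Replacing the early returns with a single exit is a pure style change on CUDA machines, but note that `batch_size = 1` is assigned inside the `torch.cuda.is_available()` branch; with the final `return` at function level, a CPU-only host could hit a `NameError`. A defensive sketch that keeps the single-exit shape while initialising up front (an assumption about intent, not the committed code):

```python
import torch


def pick_best_batch_size_for_gpu():
    # Defensive variant: batch_size is defined before the CUDA branch,
    # so CPU-only machines fall through to the conservative default.
    batch_size = 1
    if torch.cuda.is_available():
        _, available = torch.cuda.mem_get_info()  # (free_bytes, total_bytes)
        available_gb = available / (1024**3)
        if available_gb > 14:
            batch_size = 16
        elif available_gb > 10:
            batch_size = 8
        elif available_gb > 7:
            batch_size = 4
    return batch_size


print(pick_best_batch_size_for_gpu())
```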
@@ -184,9 +197,6 @@ class TextToSpeech:
     Main entry point into Tortoise.
     """

-    def _config(self):
-        raise RuntimeError("This is depreciated")
-
     def __init__(
         self,
         autoregressive_batch_size=None,
@@ -344,19 +354,6 @@ class TextToSpeech:
         )
         self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))

-    def deterministic_state(self, seed=None):
-        """
-        Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be
-        reproduced.
-        """
-        seed = int(time()) if seed is None else seed
-        torch.manual_seed(seed)
-        random.seed(seed)
-        # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary.
-        # torch.use_deterministic_algorithms(True)
-
-        return seed
-
     def get_conditioning_latents(
         self,
         voice_samples,
@@ -590,7 +587,7 @@ class TextToSpeech:
         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
                  Sample rate is 24kHz.
         """
-        deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
+        deterministic_seed = deterministic_state(seed=use_deterministic_seed)

         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
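The context lines at the bottom also show the input pipeline for `tts()`: the text is tokenised, given a batch dimension, moved to the model's device, and right-padded by one token. A standalone sketch of that tensor plumbing, with dummy token ids standing in for `self.tokenizer.encode(text)`:

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for self.tokenizer.encode(text): any list of int token ids.
encoded = [42, 7, 815, 1066]

text_tokens = torch.IntTensor(encoded).unsqueeze(0).to(device)  # shape (1, 4)
text_tokens = F.pad(text_tokens, (0, 1))  # one trailing zero -> shape (1, 5)

print(text_tokens.shape)  # torch.Size([1, 5])
```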