mirror of https://github.com/coqui-ai/TTS.git
fix tts commandline for tortoise
This commit is contained in:
parent 771bff8c1f
commit 4fa4defd64
@@ -436,6 +436,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
             source_wav=args.source_wav,
             target_wav=args.target_wav,
         )
+    elif model_dir is not None:
+        wav = synthesizer.tts(args.text)
 
     # save the results
     print(" > Saving output to {}".format(args.out_path))
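For context, a minimal sketch of how the added branch fits the CLI dispatch; apart from the two new lines and the voice-conversion call shown in the hunk above, the surrounding branches are assumptions, not part of this diff:

# Hedged sketch of the dispatch in the CLI entry point (surrounding branches are assumptions).
if tts_path is not None:
    wav = synthesizer.tts(args.text)          # assumed: model loaded from explicit checkpoint paths
elif vc_path is not None:
    wav = synthesizer.voice_conversion(       # voice conversion branch shown in the hunk above
        source_wav=args.source_wav,
        target_wav=args.target_wav,
    )
elif model_dir is not None:
    wav = synthesizer.tts(args.text)          # new: model loaded from a local directory (e.g. Tortoise)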
@@ -1,3 +1,4 @@
+import os
 import random
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -18,7 +19,6 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
 from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
 from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
 from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.tortoise.utils import get_model_path
 from TTS.tts.layers.tortoise.vocoder import VocConf
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
 from TTS.tts.models.base_tts import BaseTTS
@@ -147,7 +147,7 @@ def do_spectrogram_diffusion(
     return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
 
 
-def classify_audio_clip(clip):
+def classify_audio_clip(clip, model_dir):
     """
     Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
     :param clip: torch tensor containing audio waveform data (get it from load_audio)
@@ -167,7 +167,7 @@ def classify_audio_clip(clip):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
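A hedged usage example for the updated signature; the file path, model directory, and sampling rate below are placeholders, and load_audio is the helper the docstring above refers to:

# Placeholder paths and sampling rate; only the extra model_dir argument comes from this commit.
clip = load_audio("some_clip.wav", 24000)                      # waveform tensor, as the docstring suggests
score = classify_audio_clip(clip, "/path/to/tortoise/models")  # directory must contain classifier.pth
print(score)                                                   # score for the "came from Tortoise" class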
@@ -315,11 +315,11 @@ class Tortoise(BaseTTS):
             .cpu()
             .eval()
         )
-        ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", self.models_dir)
+        ar_path = self.args.ar_checkpoint or os.path.join(self.models_dir, "autoregressive.pth")
         self.autoregressive.load_state_dict(torch.load(ar_path))
         self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
 
-        diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", self.models_dir)
+        diff_path = self.args.diff_checkpoint or os.path.join(self.models_dir, "diffusion_decoder.pth")
         self.diffusion = (
             DiffusionTts(
                 model_channels=self.args.diff_model_channels,
@@ -357,14 +357,14 @@ class Tortoise(BaseTTS):
             .cpu()
             .eval()
         )
-        clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", self.models_dir)
+        clvp_path = self.args.clvp_checkpoint or os.path.join(self.models_dir, "clvp2.pth")
         self.clvp.load_state_dict(torch.load(clvp_path))
 
         self.vocoder = vocoder.value.constructor().cpu()
         self.vocoder.load_state_dict(
             vocoder.value.optionally_index(
                 torch.load(
-                    get_model_path(vocoder.value.model_path, self.models_dir),
+                    os.path.join(self.models_dir, vocoder.value.model_path),
                     map_location=torch.device("cpu"),
                 )
             )
@@ -472,14 +472,14 @@ class Tortoise(BaseTTS):
             self.rlg_auto = RandomLatentConverter(1024).eval()
             self.rlg_auto.load_state_dict(
                 torch.load(
-                    get_model_path("rlg_auto.pth", self.models_dir),
+                    os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
             self.rlg_diffusion.load_state_dict(
                 torch.load(
-                    get_model_path("rlg_diffuser.pth", self.models_dir),
+                    os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
                 )
             )
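The recurring edit in the hunks above replaces the removed get_model_path helper with a direct join against the configured model directory. A hypothetical helper illustrating that pattern (not part of the commit; name and signature are made up for illustration):

import os
import torch

def load_local_checkpoint(module, models_dir, filename):
    # Hypothetical illustration of the pattern: checkpoints are resolved directly inside
    # models_dir, e.g. autoregressive.pth, diffusion_decoder.pth, clvp2.pth,
    # rlg_auto.pth, rlg_diffuser.pth, classifier.pth.
    checkpoint_path = os.path.join(models_dir, filename)
    module.load_state_dict(torch.load(checkpoint_path, map_location=torch.device("cpu")))
    return module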
@@ -555,8 +555,9 @@ class Tortoise(BaseTTS):
                 "diffusion_iterations": 400,
             },
         }
-        settings.update(presets[kwargs["preset"]])
-        kwargs.pop("preset")
+        if hasattr(kwargs, "preset"):
+            settings.update(presets[kwargs["preset"]])
+            kwargs.pop("preset")
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         return self.inference(text, **settings)
 
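A hedged, self-contained illustration of the settings flow in this hunk; the default values and preset names below are examples only, not the model's real configuration:

# Standalone illustration; values are examples, not the model's real defaults.
settings = {"num_autoregressive_samples": 16, "diffusion_iterations": 30}
presets = {
    "ultra_fast": {"num_autoregressive_samples": 16, "diffusion_iterations": 30},
    "high_quality": {"num_autoregressive_samples": 256, "diffusion_iterations": 400},
}

kwargs = {"diffusion_iterations": 80}      # caller overrides passed through to inference
if hasattr(kwargs, "preset"):              # guard taken from the hunk (hasattr checks object attributes)
    settings.update(presets[kwargs["preset"]])
    kwargs.pop("preset")
settings.update(kwargs)                    # explicit kwargs always win over defaults and presets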