mirror of https://github.com/coqui-ai/TTS.git
fix tts commandline for tortoise
This commit is contained in: commit 4fa4defd64 (parent 771bff8c1f)
@@ -436,6 +436,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
             source_wav=args.source_wav,
             target_wav=args.target_wav,
         )
+    elif model_dir is not None:
+        wav = synthesizer.tts(args.text)
 
     # save the results
     print(" > Saving output to {}".format(args.out_path))
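For context, the hunk above is from the CLI entry point (presumably TTS/bin/synthesize.py): when only a model directory is supplied, plain text is now routed through synthesizer.tts() instead of falling through without producing audio. A rough Python equivalent of the fixed code path, a minimal sketch assuming the Synthesizer constructor accepts a model_dir keyword and using an illustrative directory:

    # Minimal sketch; the model_dir kwarg and the paths are assumptions, not part of this diff.
    from TTS.utils.synthesizer import Synthesizer

    synthesizer = Synthesizer(model_dir="/path/to/tortoise_models")  # assumed kwarg
    wav = synthesizer.tts("Hello from Tortoise.")  # the call added in this hunk
    synthesizer.save_wav(wav, "out.wav")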
@@ -1,3 +1,4 @@
+import os
 import random
 from contextlib import contextmanager
 from dataclasses import dataclass

@@ -18,7 +19,6 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
 from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
 from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
 from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.tortoise.utils import get_model_path
 from TTS.tts.layers.tortoise.vocoder import VocConf
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
 from TTS.tts.models.base_tts import BaseTTS
@@ -147,7 +147,7 @@ def do_spectrogram_diffusion(
     return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
 
 
-def classify_audio_clip(clip):
+def classify_audio_clip(clip, model_dir):
     """
     Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
     :param clip: torch tensor containing audio waveform data (get it from load_audio)

@@ -167,7 +167,7 @@ def classify_audio_clip(clip):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
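Net effect of the two hunks above: classify_audio_clip() no longer resolves classifier.pth through the removed get_model_path helper, so callers must pass the directory that holds the model file. A minimal usage sketch; the load_audio import path, the sample rate, and the model directory are assumptions, not taken from this diff:

    from TTS.tts.layers.tortoise.audio_utils import load_audio  # assumed module path

    clip = load_audio("speech.wav", 24000)  # assumed sample rate
    fake_prob = classify_audio_clip(clip, model_dir="/path/to/tortoise_models")
    print(f"Probability the clip came from Tortoise: {fake_prob:.3f}")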
@@ -315,11 +315,11 @@ class Tortoise(BaseTTS):
             .cpu()
             .eval()
         )
-        ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", self.models_dir)
+        ar_path = self.args.ar_checkpoint or os.path.join(self.models_dir, "autoregressive.pth")
         self.autoregressive.load_state_dict(torch.load(ar_path))
         self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
 
-        diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", self.models_dir)
+        diff_path = self.args.diff_checkpoint or os.path.join(self.models_dir, "diffusion_decoder.pth")
         self.diffusion = (
             DiffusionTts(
                 model_channels=self.args.diff_model_channels,

@@ -357,14 +357,14 @@ class Tortoise(BaseTTS):
             .cpu()
             .eval()
         )
-        clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", self.models_dir)
+        clvp_path = self.args.clvp_checkpoint or os.path.join(self.models_dir, "clvp2.pth")
         self.clvp.load_state_dict(torch.load(clvp_path))
 
         self.vocoder = vocoder.value.constructor().cpu()
         self.vocoder.load_state_dict(
             vocoder.value.optionally_index(
                 torch.load(
-                    get_model_path(vocoder.value.model_path, self.models_dir),
+                    os.path.join(self.models_dir, vocoder.value.model_path),
                     map_location=torch.device("cpu"),
                 )
             )

@@ -472,14 +472,14 @@ class Tortoise(BaseTTS):
             self.rlg_auto = RandomLatentConverter(1024).eval()
             self.rlg_auto.load_state_dict(
                 torch.load(
-                    get_model_path("rlg_auto.pth", self.models_dir),
+                    os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
             self.rlg_diffusion.load_state_dict(
                 torch.load(
-                    get_model_path("rlg_diffuser.pth", self.models_dir),
+                    os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
                 )
             )
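The hunks above (presumably TTS/tts/models/tortoise.py) all follow one pattern: an explicitly supplied checkpoint path wins, otherwise the file is looked up directly under self.models_dir with os.path.join, replacing the removed get_model_path helper. A small sketch of that fallback, with a hypothetical helper name and an illustrative directory:

    import os

    def resolve_checkpoint(explicit_path, models_dir, filename):
        # Hypothetical helper mirroring the pattern above: an explicit checkpoint
        # wins; otherwise fall back to <models_dir>/<filename>. `or` treats both
        # None and "" as "not provided".
        return explicit_path or os.path.join(models_dir, filename)

    print(resolve_checkpoint(None, "/path/to/tortoise_models", "autoregressive.pth"))
    # -> /path/to/tortoise_models/autoregressive.pth
    print(resolve_checkpoint("/custom/ar.pth", "/path/to/tortoise_models", "autoregressive.pth"))
    # -> /custom/ar.pth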
@@ -555,6 +555,7 @@ class Tortoise(BaseTTS):
                 "diffusion_iterations": 400,
             },
         }
+        if hasattr(kwargs, "preset"):
             settings.update(presets[kwargs["preset"]])
             kwargs.pop("preset")
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
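The final hunk gates the preset handling so inference no longer assumes a preset key is always passed in kwargs. A self-contained sketch of the update-then-override flow with illustrative values; note that kwargs is a plain dict, so key presence is normally tested with "preset" in kwargs, whereas hasattr() checks object attributes:

    # Illustrative values only; not the presets defined in tortoise.py.
    presets = {
        "fast": {"num_autoregressive_samples": 32, "diffusion_iterations": 50},
        "high_quality": {"num_autoregressive_samples": 256, "diffusion_iterations": 400},
    }
    settings = {"temperature": 0.2}
    kwargs = {"preset": "fast", "diffusion_iterations": 80}

    if "preset" in kwargs:  # dict membership test
        settings.update(presets[kwargs["preset"]])
        kwargs.pop("preset")
    settings.update(kwargs)  # explicit kwargs override the chosen preset
    print(settings)
    # -> {'temperature': 0.2, 'num_autoregressive_samples': 32, 'diffusion_iterations': 80}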