fix tts commandline for tortoise

This commit is contained in:
manmay-nakhashi 2023-05-02 02:02:05 +05:30
parent 771bff8c1f
commit 4fa4defd64
2 changed files with 14 additions and 11 deletions

View File

@ -436,6 +436,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
source_wav=args.source_wav,
target_wav=args.target_wav,
)
elif model_dir is not None:
wav = synthesizer.tts(args.text)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@ -1,3 +1,4 @@
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
@ -18,7 +19,6 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.models.base_tts import BaseTTS
@ -147,7 +147,7 @@ def do_spectrogram_diffusion(
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
def classify_audio_clip(clip):
def classify_audio_clip(clip, model_dir):
"""
Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
:param clip: torch tensor containing audio waveform data (get it from load_audio)
@ -167,7 +167,7 @@ def classify_audio_clip(clip):
kernel_size=5,
distribute_zero_label=False,
)
classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
@ -315,11 +315,11 @@ class Tortoise(BaseTTS):
.cpu()
.eval()
)
ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", self.models_dir)
ar_path = self.args.ar_checkpoint or os.path.join(self.models_dir, "autoregressive.pth")
self.autoregressive.load_state_dict(torch.load(ar_path))
self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", self.models_dir)
diff_path = self.args.diff_checkpoint or os.path.join(self.models_dir, "diffusion_decoder.pth")
self.diffusion = (
DiffusionTts(
model_channels=self.args.diff_model_channels,
@ -357,14 +357,14 @@ class Tortoise(BaseTTS):
.cpu()
.eval()
)
clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", self.models_dir)
clvp_path = self.args.clvp_checkpoint or os.path.join(self.models_dir, "clvp2.pth")
self.clvp.load_state_dict(torch.load(clvp_path))
self.vocoder = vocoder.value.constructor().cpu()
self.vocoder.load_state_dict(
vocoder.value.optionally_index(
torch.load(
get_model_path(vocoder.value.model_path, self.models_dir),
os.path.join(self.models_dir, vocoder.value.model_path),
map_location=torch.device("cpu"),
)
)
@ -472,14 +472,14 @@ class Tortoise(BaseTTS):
self.rlg_auto = RandomLatentConverter(1024).eval()
self.rlg_auto.load_state_dict(
torch.load(
get_model_path("rlg_auto.pth", self.models_dir),
os.path.join(self.models_dir, "rlg_auto.pth"),
map_location=torch.device("cpu"),
)
)
self.rlg_diffusion = RandomLatentConverter(2048).eval()
self.rlg_diffusion.load_state_dict(
torch.load(
get_model_path("rlg_diffuser.pth", self.models_dir),
os.path.join(self.models_dir, "rlg_diffuser.pth"),
map_location=torch.device("cpu"),
)
)
@ -555,8 +555,9 @@ class Tortoise(BaseTTS):
"diffusion_iterations": 400,
},
}
settings.update(presets[kwargs["preset"]])
kwargs.pop("preset")
if hasattr(kwargs, "preset"):
settings.update(presets[kwargs["preset"]])
kwargs.pop("preset")
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.inference(text, **settings)