diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index b5ab409c..f93f062c 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -436,6 +436,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
             source_wav=args.source_wav,
             target_wav=args.target_wav,
         )
+    elif model_dir is not None:
+        wav = synthesizer.tts(args.text)
 
     # save the results
     print(" > Saving output to {}".format(args.out_path))
diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py
index c44c7ec5..5566d720 100644
--- a/TTS/tts/models/tortoise.py
+++ b/TTS/tts/models/tortoise.py
@@ -1,3 +1,4 @@
+import os
 import random
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -18,7 +19,6 @@ from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_sc
 from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
 from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
 from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.tortoise.utils import get_model_path
 from TTS.tts.layers.tortoise.vocoder import VocConf
 from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
 from TTS.tts.models.base_tts import BaseTTS
@@ -147,7 +147,7 @@ def do_spectrogram_diffusion(
     return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
 
 
-def classify_audio_clip(clip):
+def classify_audio_clip(clip, model_dir):
     """
     Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
     :param clip: torch tensor containing audio waveform data (get it from load_audio)
@@ -167,7 +167,7 @@ def classify_audio_clip(clip):
         kernel_size=5,
         distribute_zero_label=False,
     )
-    classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
+    classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu")))
     clip = clip.cpu().unsqueeze(0)
     results = F.softmax(classifier(clip), dim=-1)
     return results[0][0]
@@ -315,11 +315,11 @@ class Tortoise(BaseTTS):
             .cpu()
             .eval()
         )
-        ar_path = self.args.ar_checkpoint or get_model_path("autoregressive.pth", self.models_dir)
+        ar_path = self.args.ar_checkpoint or os.path.join(self.models_dir, "autoregressive.pth")
         self.autoregressive.load_state_dict(torch.load(ar_path))
         self.autoregressive.post_init_gpt2_config(self.args.kv_cache)
 
-        diff_path = self.args.diff_checkpoint or get_model_path("diffusion_decoder.pth", self.models_dir)
+        diff_path = self.args.diff_checkpoint or os.path.join(self.models_dir, "diffusion_decoder.pth")
         self.diffusion = (
             DiffusionTts(
                 model_channels=self.args.diff_model_channels,
@@ -357,14 +357,14 @@ class Tortoise(BaseTTS):
             .cpu()
             .eval()
         )
-        clvp_path = self.args.clvp_checkpoint or get_model_path("clvp2.pth", self.models_dir)
+        clvp_path = self.args.clvp_checkpoint or os.path.join(self.models_dir, "clvp2.pth")
         self.clvp.load_state_dict(torch.load(clvp_path))
 
         self.vocoder = vocoder.value.constructor().cpu()
         self.vocoder.load_state_dict(
             vocoder.value.optionally_index(
                 torch.load(
-                    get_model_path(vocoder.value.model_path, self.models_dir),
+                    os.path.join(self.models_dir, vocoder.value.model_path),
                     map_location=torch.device("cpu"),
                 )
             )
@@ -472,14 +472,14 @@ class Tortoise(BaseTTS):
             self.rlg_auto = RandomLatentConverter(1024).eval()
             self.rlg_auto.load_state_dict(
                 torch.load(
-                    get_model_path("rlg_auto.pth", self.models_dir),
+                    os.path.join(self.models_dir, "rlg_auto.pth"),
                     map_location=torch.device("cpu"),
                 )
             )
             self.rlg_diffusion = RandomLatentConverter(2048).eval()
             self.rlg_diffusion.load_state_dict(
                 torch.load(
-                    get_model_path("rlg_diffuser.pth", self.models_dir),
+                    os.path.join(self.models_dir, "rlg_diffuser.pth"),
                     map_location=torch.device("cpu"),
                 )
             )
@@ -555,8 +555,9 @@ class Tortoise(BaseTTS):
                 "diffusion_iterations": 400,
             },
         }
-        settings.update(presets[kwargs["preset"]])
-        kwargs.pop("preset")
+        if "preset" in kwargs:
+            settings.update(presets[kwargs["preset"]])
+            kwargs.pop("preset")
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
         return self.inference(text, **settings)
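
Below is a minimal standalone sketch (not part of the patch) of the preset handling that the final hunk makes optional. Note that `hasattr(kwargs, "preset")` would always be False for a dict, since dict keys are not attributes; a membership test is needed. The preset contents and values here are stand-ins for illustration, not the actual numbers defined in `Tortoise.inference_with_config`:

    # stand-in values; the real presets/settings live in inference_with_config
    presets = {"fast": {"num_autoregressive_samples": 32, "diffusion_iterations": 50}}
    settings = {"temperature": 0.8}
    kwargs = {"preset": "fast", "diffusion_iterations": 80}

    if "preset" in kwargs:  # dict key check; hasattr() would never see dict keys
        settings.update(presets[kwargs["preset"]])
        kwargs.pop("preset")
    settings.update(kwargs)  # explicit kwargs override preset values

    print(settings)
    # {'temperature': 0.8, 'num_autoregressive_samples': 32, 'diffusion_iterations': 80}

With the key check in place, callers that pass no preset at all (e.g. when loading Tortoise from a local model directory via the new `elif model_dir is not None` branch in synthesize.py) no longer raise a KeyError.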