From f29882c0ab300c2a3f142b295abaaed7117ca620 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Sat, 6 May 2023 11:33:48 +0200 Subject: [PATCH] Add load from dir to synthesizer --- TTS/utils/synthesizer.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 8d143180..8b50e11d 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,3 +1,4 @@ +import os import time from typing import List @@ -8,7 +9,6 @@ import torch from TTS.config import load_config from TTS.tts.configs.tortoise_config import TortoiseConfig from TTS.tts.models import setup_model as setup_tts_model -from TTS.tts.models.tortoise import init_from_config # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import @@ -98,8 +98,7 @@ class Synthesizer(object): self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if model_dir: - self.tts_model = init_from_config(TortoiseConfig(model_dir=model_dir)) - self.tts_config = TortoiseConfig(model_dir=model_dir) + self._load_tts_from_dir(model_dir, use_cuda) self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @staticmethod @@ -134,6 +133,19 @@ class Synthesizer(object): if use_cuda: self.vc_model.cuda() + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the TTS model from a directory. + + We assume the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + + config = load_config(os.path.join(model_dir, "config.json")) + self.tts_config = config + self.tts_model = setup_tts_model(config) + self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) + if use_cuda: + self.tts_model.cuda() + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: """Load the TTS model.