Add load from dir to synthesizer

This commit is contained in:
Eren G??lge 2023-05-06 11:33:48 +02:00
parent ae8e26f084
commit f29882c0ab
1 changed files with 15 additions and 3 deletions

View File

@ -1,3 +1,4 @@
import os
import time
from typing import List
@ -8,7 +9,6 @@ import torch
from TTS.config import load_config
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.tortoise import init_from_config
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@ -98,8 +98,7 @@ class Synthesizer(object):
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
self.tts_model = init_from_config(TortoiseConfig(model_dir=model_dir))
self.tts_config = TortoiseConfig(model_dir=model_dir)
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
@ -134,6 +133,19 @@ class Synthesizer(object):
if use_cuda:
self.vc_model.cuda()
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
self.tts_model.cuda()
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
"""Load the TTS model.