mirror of https://github.com/coqui-ai/TTS.git
Add load from dir to synthesizer
This commit is contained in:
parent
ae8e26f084
commit
f29882c0ab
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
|
@ -8,7 +9,6 @@ import torch
|
|||
from TTS.config import load_config
|
||||
from TTS.tts.configs.tortoise_config import TortoiseConfig
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.tts.models.tortoise import init_from_config
|
||||
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
|
@ -98,8 +98,7 @@ class Synthesizer(object):
|
|||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||
|
||||
if model_dir:
|
||||
self.tts_model = init_from_config(TortoiseConfig(model_dir=model_dir))
|
||||
self.tts_config = TortoiseConfig(model_dir=model_dir)
|
||||
self._load_tts_from_dir(model_dir, use_cuda)
|
||||
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||
|
||||
@staticmethod
|
||||
|
@ -134,6 +133,19 @@ class Synthesizer(object):
|
|||
if use_cuda:
|
||||
self.vc_model.cuda()
|
||||
|
||||
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
|
||||
"""Load the TTS model from a directory.
|
||||
|
||||
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
|
||||
"""
|
||||
|
||||
config = load_config(os.path.join(model_dir, "config.json"))
|
||||
self.tts_config = config
|
||||
self.tts_model = setup_tts_model(config)
|
||||
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
|
||||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
||||
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
|
||||
"""Load the TTS model.
|
||||
|
||||
|
|
Loading…
Reference in New Issue