mirror of https://github.com/coqui-ai/TTS.git
Add load from dir to synthesizer
This commit is contained in:
parent
ae8e26f084
commit
f29882c0ab
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
@ -8,7 +9,6 @@ import torch
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.configs.tortoise_config import TortoiseConfig
|
from TTS.tts.configs.tortoise_config import TortoiseConfig
|
||||||
from TTS.tts.models import setup_model as setup_tts_model
|
from TTS.tts.models import setup_model as setup_tts_model
|
||||||
from TTS.tts.models.tortoise import init_from_config
|
|
||||||
|
|
||||||
# pylint: disable=unused-wildcard-import
|
# pylint: disable=unused-wildcard-import
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
|
@ -98,8 +98,7 @@ class Synthesizer(object):
|
||||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||||
|
|
||||||
if model_dir:
|
if model_dir:
|
||||||
self.tts_model = init_from_config(TortoiseConfig(model_dir=model_dir))
|
self._load_tts_from_dir(model_dir, use_cuda)
|
||||||
self.tts_config = TortoiseConfig(model_dir=model_dir)
|
|
||||||
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -134,6 +133,19 @@ class Synthesizer(object):
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.vc_model.cuda()
|
self.vc_model.cuda()
|
||||||
|
|
||||||
|
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
|
||||||
|
"""Load the TTS model from a directory.
|
||||||
|
|
||||||
|
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = load_config(os.path.join(model_dir, "config.json"))
|
||||||
|
self.tts_config = config
|
||||||
|
self.tts_model = setup_tts_model(config)
|
||||||
|
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
|
||||||
|
if use_cuda:
|
||||||
|
self.tts_model.cuda()
|
||||||
|
|
||||||
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
|
def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
|
||||||
"""Load the TTS model.
|
"""Load the TTS model.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue