mirror of https://github.com/coqui-ai/TTS.git
refactor(synthesizer): set sample rate in loading methods
This commit is contained in:
parent
7d0416f99b
commit
3539e65d8e
|
@ -95,26 +95,20 @@ class Synthesizer(nn.Module):
|
|||
|
||||
if tts_checkpoint:
|
||||
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
|
||||
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
||||
|
||||
if vocoder_checkpoint:
|
||||
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
|
||||
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
|
||||
|
||||
if vc_checkpoint and model_dir is None:
|
||||
self._load_vc(vc_checkpoint, vc_config, use_cuda)
|
||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||
|
||||
if model_dir:
|
||||
if "fairseq" in model_dir:
|
||||
self._load_fairseq_from_dir(model_dir, use_cuda)
|
||||
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
||||
elif "openvoice" in model_dir:
|
||||
self._load_openvoice_from_dir(Path(model_dir), use_cuda)
|
||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||
else:
|
||||
self._load_tts_from_dir(model_dir, use_cuda)
|
||||
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||
|
||||
@staticmethod
|
||||
def _get_segmenter(lang: str):
|
||||
|
@ -143,6 +137,7 @@ class Synthesizer(nn.Module):
|
|||
"""
|
||||
# pylint: disable=global-statement
|
||||
self.vc_config = load_config(vc_config_path)
|
||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||
self.vc_model = setup_vc_model(config=self.vc_config)
|
||||
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
|
||||
if use_cuda:
|
||||
|
@ -157,6 +152,7 @@ class Synthesizer(nn.Module):
|
|||
self.tts_model = Vits.init_from_config(self.tts_config)
|
||||
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
|
||||
self.tts_config = self.tts_model.config
|
||||
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
||||
if use_cuda:
|
||||
self.tts_model.cuda()
|
||||
|
||||
|
@ -170,6 +166,7 @@ class Synthesizer(nn.Module):
|
|||
self.vc_model = OpenVoice.init_from_config(self.vc_config)
|
||||
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
|
||||
self.vc_config = self.vc_model.config
|
||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||
if use_cuda:
|
||||
self.vc_model.cuda()
|
||||
|
||||
|
@ -180,6 +177,7 @@ class Synthesizer(nn.Module):
|
|||
"""
|
||||
config = load_config(os.path.join(model_dir, "config.json"))
|
||||
self.tts_config = config
|
||||
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||
self.tts_model = setup_tts_model(config)
|
||||
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
|
||||
if use_cuda:
|
||||
|
@ -201,6 +199,7 @@ class Synthesizer(nn.Module):
|
|||
"""
|
||||
# pylint: disable=global-statement
|
||||
self.tts_config = load_config(tts_config_path)
|
||||
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
||||
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
|
||||
raise ValueError("Phonemizer is not defined in the TTS config.")
|
||||
|
||||
|
@ -238,6 +237,7 @@ class Synthesizer(nn.Module):
|
|||
use_cuda (bool): enable/disable CUDA use.
|
||||
"""
|
||||
self.vocoder_config = load_config(model_config)
|
||||
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
|
||||
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
|
||||
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
|
||||
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
|
||||
|
|
Loading…
Reference in New Issue