refactor(synthesizer): set sample rate in loading methods

This commit is contained in:
Enno Hermann 2024-12-02 22:50:33 +01:00
parent 7d0416f99b
commit 3539e65d8e
1 changed files with 6 additions and 6 deletions

View File

@ -95,26 +95,20 @@ class Synthesizer(nn.Module):
if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
if vc_checkpoint and model_dir is None:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
elif "openvoice" in model_dir:
self._load_openvoice_from_dir(Path(model_dir), use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
else:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
@ -143,6 +137,7 @@ class Synthesizer(nn.Module):
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
@ -157,6 +152,7 @@ class Synthesizer(nn.Module):
self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if use_cuda:
self.tts_model.cuda()
@ -170,6 +166,7 @@ class Synthesizer(nn.Module):
self.vc_model = OpenVoice.init_from_config(self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
self.vc_config = self.vc_model.config
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if use_cuda:
self.vc_model.cuda()
@ -180,6 +177,7 @@ class Synthesizer(nn.Module):
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
@ -201,6 +199,7 @@ class Synthesizer(nn.Module):
"""
# pylint: disable=global-statement
self.tts_config = load_config(tts_config_path)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
raise ValueError("Phonemizer is not defined in the TTS config.")
@ -238,6 +237,7 @@ class Synthesizer(nn.Module):
use_cuda (bool): enable/disable CUDA use.
"""
self.vocoder_config = load_config(model_config)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)