refactor(synthesizer): set sample rate in loading methods

This commit is contained in:
Enno Hermann 2024-12-02 22:50:33 +01:00
parent 7d0416f99b
commit 3539e65d8e
1 changed file with 6 additions and 6 deletions

View File

@ -95,26 +95,20 @@ class Synthesizer(nn.Module):
if tts_checkpoint: if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda) self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint: if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
if vc_checkpoint and model_dir is None: if vc_checkpoint and model_dir is None:
self._load_vc(vc_checkpoint, vc_config, use_cuda) self._load_vc(vc_checkpoint, vc_config, use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir: if model_dir:
if "fairseq" in model_dir: if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda) self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
elif "openvoice" in model_dir: elif "openvoice" in model_dir:
self._load_openvoice_from_dir(Path(model_dir), use_cuda) self._load_openvoice_from_dir(Path(model_dir), use_cuda)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
else: else:
self._load_tts_from_dir(model_dir, use_cuda) self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod @staticmethod
def _get_segmenter(lang: str): def _get_segmenter(lang: str):
@ -143,6 +137,7 @@ class Synthesizer(nn.Module):
""" """
# pylint: disable=global-statement # pylint: disable=global-statement
self.vc_config = load_config(vc_config_path) self.vc_config = load_config(vc_config_path)
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
self.vc_model = setup_vc_model(config=self.vc_config) self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda: if use_cuda:
@ -157,6 +152,7 @@ class Synthesizer(nn.Module):
self.tts_model = Vits.init_from_config(self.tts_config) self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config self.tts_config = self.tts_model.config
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if use_cuda: if use_cuda:
self.tts_model.cuda() self.tts_model.cuda()
@ -170,6 +166,7 @@ class Synthesizer(nn.Module):
self.vc_model = OpenVoice.init_from_config(self.vc_config) self.vc_model = OpenVoice.init_from_config(self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True) self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True)
self.vc_config = self.vc_model.config self.vc_config = self.vc_model.config
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if use_cuda: if use_cuda:
self.vc_model.cuda() self.vc_model.cuda()
@ -180,6 +177,7 @@ class Synthesizer(nn.Module):
""" """
config = load_config(os.path.join(model_dir, "config.json")) config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config self.tts_config = config
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
self.tts_model = setup_tts_model(config) self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda: if use_cuda:
@ -201,6 +199,7 @@ class Synthesizer(nn.Module):
""" """
# pylint: disable=global-statement # pylint: disable=global-statement
self.tts_config = load_config(tts_config_path) self.tts_config = load_config(tts_config_path)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
raise ValueError("Phonemizer is not defined in the TTS config.") raise ValueError("Phonemizer is not defined in the TTS config.")
@ -238,6 +237,7 @@ class Synthesizer(nn.Module):
use_cuda (bool): enable/disable CUDA use. use_cuda (bool): enable/disable CUDA use.
""" """
self.vocoder_config = load_config(model_config) self.vocoder_config = load_config(model_config)
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)