From acc6eef625d0e65aa9c762aff3a3048898864688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 24 Nov 2021 17:49:20 +0100 Subject: [PATCH] Update for tokenizer API --- TTS/utils/synthesizer.py | 14 +++++--------- TTS/vocoder/models/__init__.py | 3 +-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 12b71ab6..2e4f4735 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -114,8 +114,7 @@ class Synthesizer(object): self.tts_config = load_config(tts_config_path) self.use_phonemes = self.tts_config.use_phonemes - self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) - self.tokenizer = TTSTokenizer.init_from_config(self.tts_config) + self.tts_model = setup_tts_model(config=self.tts_config) speaker_manager = self._init_speaker_manager() language_manager = self._init_language_manager() @@ -245,7 +244,7 @@ class Synthesizer(object): path (str): output path to save the waveform. """ wav = np.array(wav) - self.ap.save_wav(wav, path, self.output_sample_rate) + self.tts_model.ap.save_wav(wav, path, self.output_sample_rate) def tts( self, @@ -333,13 +332,10 @@ class Synthesizer(object): text=sen, CONFIG=self.tts_config, use_cuda=self.use_cuda, - ap=self.ap, - tokenizer=self.tokenizer, speaker_id=speaker_id, language_id=language_id, language_name=language_name, style_wav=style_wav, - enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars, use_griffin_lim=use_gl, d_vector=speaker_embedding, ) @@ -347,14 +343,14 @@ class Synthesizer(object): mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() if not use_gl: # denormalize tts output based on tts audio config - mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch scale_factor = [ 1, - self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: print(" > interpolating tts model output.") @@ -372,7 +368,7 @@ class Synthesizer(object): # trim silence if self.tts_config.audio["do_trim_silence"] is True: - waveform = trim_silence(waveform, self.ap) + waveform = trim_silence(waveform, self.tts_model.ap) wavs += list(waveform) wavs += [0] * 10000 diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index a70ebe40..65901617 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -28,8 +28,7 @@ def setup_model(config: Coqpit): except ModuleNotFoundError as e: raise ValueError(f"Model {config.model} not exist!") from e print(" > Vocoder Model: {}".format(config.model)) - model = MyModel(config) - return model + return MyModel.init_from_config(config) def setup_generator(c):