Update for tokenizer API

This commit is contained in:
Eren Gölge 2021-11-24 17:49:20 +01:00
parent e1b4c4ca43
commit acc6eef625
2 changed files with 6 additions and 11 deletions

View File

@@ -114,8 +114,7 @@ class Synthesizer(object):
self.tts_config = load_config(tts_config_path)
self.use_phonemes = self.tts_config.use_phonemes
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
self.tokenizer = TTSTokenizer.init_from_config(self.tts_config)
self.tts_model = setup_tts_model(config=self.tts_config)
speaker_manager = self._init_speaker_manager()
language_manager = self._init_language_manager()
@@ -245,7 +244,7 @@ class Synthesizer(object):
path (str): output path to save the waveform.
"""
wav = np.array(wav)
self.ap.save_wav(wav, path, self.output_sample_rate)
self.tts_model.ap.save_wav(wav, path, self.output_sample_rate)
def tts(
self,
@@ -333,13 +332,10 @@ class Synthesizer(object):
text=sen,
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
ap=self.ap,
tokenizer=self.tokenizer,
speaker_id=speaker_id,
language_id=language_id,
language_name=language_name,
style_wav=style_wav,
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
)
@@ -347,14 +343,14 @@ class Synthesizer(object):
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu"
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
scale_factor = [
1,
self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate,
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
@@ -372,7 +368,7 @@ class Synthesizer(object):
# trim silence
if self.tts_config.audio["do_trim_silence"] is True:
waveform = trim_silence(waveform, self.ap)
waveform = trim_silence(waveform, self.tts_model.ap)
wavs += list(waveform)
wavs += [0] * 10000

View File

@@ -28,8 +28,7 @@ def setup_model(config: Coqpit):
except ModuleNotFoundError as e:
raise ValueError(f"Model {config.model} not exist!") from e
print(" > Vocoder Model: {}".format(config.model))
model = MyModel(config)
return model
return MyModel.init_from_config(config)
def setup_generator(c):