mirror of https://github.com/coqui-ai/TTS.git
Update for tokenizer API
This commit is contained in:
Parent: e1b4c4ca43
Commit: acc6eef625
|
@ -114,8 +114,7 @@ class Synthesizer(object):
|
||||||
|
|
||||||
self.tts_config = load_config(tts_config_path)
|
self.tts_config = load_config(tts_config_path)
|
||||||
self.use_phonemes = self.tts_config.use_phonemes
|
self.use_phonemes = self.tts_config.use_phonemes
|
||||||
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
|
self.tts_model = setup_tts_model(config=self.tts_config)
|
||||||
self.tokenizer = TTSTokenizer.init_from_config(self.tts_config)
|
|
||||||
|
|
||||||
speaker_manager = self._init_speaker_manager()
|
speaker_manager = self._init_speaker_manager()
|
||||||
language_manager = self._init_language_manager()
|
language_manager = self._init_language_manager()
|
||||||
|
@ -245,7 +244,7 @@ class Synthesizer(object):
|
||||||
path (str): output path to save the waveform.
|
path (str): output path to save the waveform.
|
||||||
"""
|
"""
|
||||||
wav = np.array(wav)
|
wav = np.array(wav)
|
||||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
self.tts_model.ap.save_wav(wav, path, self.output_sample_rate)
|
||||||
|
|
||||||
def tts(
|
def tts(
|
||||||
self,
|
self,
|
||||||
|
@ -333,13 +332,10 @@ class Synthesizer(object):
|
||||||
text=sen,
|
text=sen,
|
||||||
CONFIG=self.tts_config,
|
CONFIG=self.tts_config,
|
||||||
use_cuda=self.use_cuda,
|
use_cuda=self.use_cuda,
|
||||||
ap=self.ap,
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
speaker_id=speaker_id,
|
speaker_id=speaker_id,
|
||||||
language_id=language_id,
|
language_id=language_id,
|
||||||
language_name=language_name,
|
language_name=language_name,
|
||||||
style_wav=style_wav,
|
style_wav=style_wav,
|
||||||
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
|
||||||
use_griffin_lim=use_gl,
|
use_griffin_lim=use_gl,
|
||||||
d_vector=speaker_embedding,
|
d_vector=speaker_embedding,
|
||||||
)
|
)
|
||||||
|
@ -347,14 +343,14 @@ class Synthesizer(object):
|
||||||
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
|
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
|
||||||
if not use_gl:
|
if not use_gl:
|
||||||
# denormalize tts output based on tts audio config
|
# denormalize tts output based on tts audio config
|
||||||
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
|
mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
|
||||||
device_type = "cuda" if self.use_cuda else "cpu"
|
device_type = "cuda" if self.use_cuda else "cpu"
|
||||||
# renormalize spectrogram based on vocoder config
|
# renormalize spectrogram based on vocoder config
|
||||||
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
|
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
|
||||||
# compute scale factor for possible sample rate mismatch
|
# compute scale factor for possible sample rate mismatch
|
||||||
scale_factor = [
|
scale_factor = [
|
||||||
1,
|
1,
|
||||||
self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate,
|
self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate,
|
||||||
]
|
]
|
||||||
if scale_factor[1] != 1:
|
if scale_factor[1] != 1:
|
||||||
print(" > interpolating tts model output.")
|
print(" > interpolating tts model output.")
|
||||||
|
@ -372,7 +368,7 @@ class Synthesizer(object):
|
||||||
|
|
||||||
# trim silence
|
# trim silence
|
||||||
if self.tts_config.audio["do_trim_silence"] is True:
|
if self.tts_config.audio["do_trim_silence"] is True:
|
||||||
waveform = trim_silence(waveform, self.ap)
|
waveform = trim_silence(waveform, self.tts_model.ap)
|
||||||
|
|
||||||
wavs += list(waveform)
|
wavs += list(waveform)
|
||||||
wavs += [0] * 10000
|
wavs += [0] * 10000
|
||||||
|
|
|
@ -28,8 +28,7 @@ def setup_model(config: Coqpit):
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
raise ValueError(f"Model {config.model} not exist!") from e
|
raise ValueError(f"Model {config.model} not exist!") from e
|
||||||
print(" > Vocoder Model: {}".format(config.model))
|
print(" > Vocoder Model: {}".format(config.model))
|
||||||
model = MyModel(config)
|
return MyModel.init_from_config(config)
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def setup_generator(c):
|
def setup_generator(c):
|
||||||
|
|
Loading…
Reference in New Issue