diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 64993754..1756b82e 100644 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -9,6 +9,7 @@ import string import time import torch +import numpy as np from TTS.tts.utils.generic_utils import setup_model, is_tacotron from TTS.tts.utils.synthesis import synthesis @@ -21,6 +22,18 @@ from TTS.vocoder.utils.generic_utils import setup_generator def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_fileid, speaker_embedding=None, gst_style=None): t_1 = time.time() waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_fileid, gst_style, False, CONFIG.enable_eos_bos_chars, use_gl, speaker_embedding=speaker_embedding) + + # grab spectrogram (thx to the nice guys at mozilla discourse for codesnipplet) + if args.save_spectogram: + spec_file_name = args.text.replace(" ", "_") + spec_file_name = spec_file_name.translate( + str.maketrans('', '', string.punctuation.replace('_', ''))) + '.npy' + spec_file_name = os.path.join(args.out_path, spec_file_name) + spectrogram = torch.FloatTensor(mel_postnet_spec.T) + spectrogram = spectrogram.unsqueeze(0) + np.save(spec_file_name, spectrogram) + print(" > Saving raw spectogram to " + spec_file_name) + if CONFIG.model == "Tacotron" and not use_gl: mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T if not use_gl: @@ -88,6 +101,11 @@ if __name__ == "__main__": '--gst_style', help="Wav path file for GST stylereference.", default=None) + parser.add_argument( + '--save_spectogram', + type=bool, + help="If true save raw spectogram for further (vocoder) processing in out_path.", + default=False) args = parser.parse_args()