diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index a4d70324..6cd28bc8 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -20,6 +20,7 @@ from TTS.tts.utils.io import load_checkpoint
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
+from TTS.utils.synthesizer import Synthesizer
 from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
 
 
@@ -34,121 +35,6 @@ def str2bool(v):
         raise argparse.ArgumentTypeError('Boolean value expected.')
 
 
-def load_tts_model(model_path, config_path, use_cuda, speakers_json=None, speaker_idx=None):
-    global phonemes
-    global symbols
-
-    # load the config
-    model_config = load_config(config_path)
-
-    # load the audio processor
-    ap = AudioProcessor(**model_config.audio)
-
-    # if the vocabulary was passed, replace the default
-    if 'characters' in model_config.keys():
-        symbols, phonemes = make_symbols(**model_config.characters)
-
-    # load speakers
-    speaker_embedding = None
-    speaker_embedding_dim = None
-    num_speakers = 0
-    if speakers_json is not None:
-        speaker_mapping = json.load(open(speakers_json, 'r'))
-        num_speakers = len(speaker_mapping)
-        if model_config.use_external_speaker_embedding_file:
-            if speaker_idx is not None:
-                speaker_embedding = speaker_mapping[speaker_idx]['embedding']
-            else: # if speaker_idx is not specificated use the first sample in speakers.json
-                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']
-            speaker_embedding_dim = len(speaker_embedding)
-
-    # load tts model
-    num_chars = len(phonemes) if model_config.use_phonemes else len(symbols)
-    model = setup_model(num_chars, num_speakers, model_config, speaker_embedding_dim)
-    model.load_checkpoint(model_config, model_path, eval=True)
-    if use_cuda:
-        model.cuda()
-    return model, model_config, ap, speaker_embedding
-
-
-def load_vocoder_model(model_path, config_path, use_cuda):
-    vocoder_config = load_config(vocoder_config_path)
-    vocoder_ap = AudioProcessor(**vocoder_config['audio'])
-    vocoder_model = setup_generator(vocoder_config)
-    vocoder_model.load_checkpoint(vocoder_config, model_path, eval=True)
-    if use_cuda:
-        vocoder_model.cuda()
-    return vocoder_model, vocoder_config, vocoder_ap
-
-
-def tts(model,
-        vocoder_model,
-        text,
-        model_config,
-        vocoder_config,
-        use_cuda,
-        ap,
-        vocoder_ap,
-        use_gl,
-        speaker_fileid,
-        speaker_embedding=None,
-        gst_style=None):
-    t_1 = time.time()
-    waveform, _, _, mel_postnet_spec, _, _ = synthesis(
-        model,
-        text,
-        model_config,
-        use_cuda,
-        ap,
-        speaker_fileid,
-        gst_style,
-        False,
-        model_config.enable_eos_bos_chars,
-        use_gl,
-        speaker_embedding=speaker_embedding)
-    # grab spectrogram (thx to the nice guys at mozilla discourse for codesnippet)
-    if args.save_spectogram:
-        spec_file_name = args.text.replace(" ", "_")[0:10]
-        spec_file_name = spec_file_name.translate(
-            str.maketrans('', '', string.punctuation.replace('_', ''))) + '.npy'
-        spec_file_name = os.path.join(args.out_path, spec_file_name)
-        spectrogram = mel_postnet_spec.T
-        spectrogram = spectrogram[0]
-        np.save(spec_file_name, spectrogram)
-        print(" > Saving raw spectogram to " + spec_file_name)
-    # convert linear spectrogram to melspectrogram for tacotron
-    if model_config.model == "Tacotron" and not use_gl:
-        mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T)
-    # run vocoder_model
-    if not use_gl:
-        # denormalize tts output based on tts audio config
-        mel_postnet_spec = ap._denormalize(mel_postnet_spec.T).T
-        device_type = "cuda" if use_cuda else "cpu"
-        # renormalize spectrogram based on vocoder config
-        vocoder_input = vocoder_ap._normalize(mel_postnet_spec.T)
-        # compute scale factor for possible sample rate mismatch
-        scale_factor = [1, vocoder_config['audio']['sample_rate'] / ap.sample_rate]
-        if scale_factor[1] != 1:
-            print(" > interpolating tts model output.")
-            vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
-        else:
-            vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
-        # run vocoder model
-        # [1, T, C]
-        waveform = vocoder_model.inference(vocoder_input.to(device_type))
-    if use_cuda and not use_gl:
-        waveform = waveform.cpu()
-    if not use_gl:
-        waveform = waveform.numpy()
-    waveform = waveform.squeeze()
-    rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)
-    tps = (time.time() - t_1) / len(waveform)
-    print(" > Run-time: {}".format(time.time() - t_1))
-    print(" > Real-time factor: {}".format(rtf))
-    print(" > Time per step: {}".format(tps))
-    return waveform
-
-
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
@@ -273,7 +159,9 @@ if __name__ == "__main__":
         manager = ModelManager(path)
 
     model_path = None
+    config_path = None
     vocoder_path = None
+    vocoder_config_path = None
     model = None
     vocoder_model = None
     vocoder_config = None
@@ -302,49 +190,36 @@ if __name__ == "__main__":
    # RUN THE SYNTHESIS
    # load models
-    model, model_config, ap, speaker_embedding = load_tts_model(model_path, config_path, args.use_cuda, args.speaker_idx)
-    if vocoder_path is not None:
-        vocoder_model, vocoder_config, vocoder_ap = load_vocoder_model(vocoder_path, vocoder_config_path, use_cuda=args.use_cuda)
+    synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, args.use_cuda)
 
     use_griffin_lim = vocoder_path is None
 
     print(" > Text: {}".format(args.text))
 
-    # handle multi-speaker setting
-    if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None:
-        if args.speaker_idx.isdigit():
-            args.speaker_idx = int(args.speaker_idx)
-        else:
-            args.speaker_idx = None
-    else:
-        args.speaker_idx = None
+    # # handle multi-speaker setting
+    # if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None:
+    #     if args.speaker_idx.isdigit():
+    #         args.speaker_idx = int(args.speaker_idx)
+    #     else:
+    #         args.speaker_idx = None
+    # else:
+    #     args.speaker_idx = None
 
-    if args.gst_style is None:
-        if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None:
-            gst_style = model_config.gst['gst_style_input']
-        else:
-            gst_style = None
-    else:
-        # check if gst_style string is a dict, if is dict convert else use string
-        try:
-            gst_style = json.loads(args.gst_style)
-            if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']:
-                raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens']))
-        except ValueError:
-            gst_style = args.gst_style
+    # if args.gst_style is None:
+    #     if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None:
+    #         gst_style = model_config.gst['gst_style_input']
+    #     else:
+    #         gst_style = None
+    # else:
+    #     # check if gst_style string is a dict, if is dict convert else use string
+    #     try:
+    #         gst_style = json.loads(args.gst_style)
+    #         if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']:
+    #             raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens']))
+    #     except ValueError:
+    #         gst_style = args.gst_style
 
     # kick it
-    wav = tts(model,
-              vocoder_model,
-              args.text,
-              model_config,
-              vocoder_config,
-              args.use_cuda,
-              ap,
-              vocoder_ap,
-              use_griffin_lim,
-              args.speaker_idx,
-              speaker_embedding=speaker_embedding,
-              gst_style=gst_style)
+    wav = synthesizer.tts(args.text)
 
     # save the results
     file_name = args.text.replace(" ", "_")[0:20]
@@ -352,4 +227,4 @@ if __name__ == "__main__":
         str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
     out_path = os.path.join(args.out_path, file_name)
     print(" > Saving output to {}".format(out_path))
-    ap.save_wav(wav, out_path)
+    synthesizer.save_wav(wav, out_path)
diff --git a/TTS/server/server.py b/TTS/server/server.py
index ed98d35e..a89f4021 100644
--- a/TTS/server/server.py
+++ b/TTS/server/server.py
@@ -2,10 +2,11 @@ import argparse
 import os
 import sys
+import io
 from pathlib import Path
 
 from flask import Flask, render_template, request, send_file
 
-from TTS.server.synthesizer import Synthesizer
+from TTS.utils.synthesizer import Synthesizer
 from TTS.utils.manage import ModelManager
 
 
@@ -71,10 +72,11 @@ if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
 if not args.vocoder_config and os.path.isfile(vocoder_config_file):
     args.vocoder_config = vocoder_config_file
 
-synthesizer = Synthesizer(args)
+synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)
 
 app = Flask(__name__)
 
+
 @app.route('/')
 def index():
     return render_template('index.html')
@@ -84,8 +86,10 @@ def index():
 def tts():
     text = request.args.get('text')
     print(" > Model input: {}".format(text))
-    data = synthesizer.tts(text)
-    return send_file(data, mimetype='audio/wav')
+    wavs = synthesizer.tts(text)
+    out = io.BytesIO()
+    synthesizer.save_wav(wavs, out)
+    return send_file(out, mimetype='audio/wav')
 
 
 def main():
diff --git a/TTS/server/synthesizer.py b/TTS/utils/synthesizer.py
similarity index 83%
rename from TTS/server/synthesizer.py
rename to TTS/utils/synthesizer.py
index a76badd6..f7ca5f44 100644
--- a/TTS/server/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -1,5 +1,3 @@
-import io
-import sys
 import time
 
 import numpy as np
@@ -19,20 +17,35 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols
 
 
 class Synthesizer(object):
-    def __init__(self, config):
+    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda):
+        """Encapsulation of tts and vocoder models for inference.
+
+        TODO: handle multi-speaker and GST inference.
+
+        Args:
+            tts_checkpoint (str): path to the tts model file.
+            tts_config (str): path to the tts config file.
+            vocoder_checkpoint (str): path to the vocoder model file.
+            vocoder_config (str): path to the vocoder config file.
+            use_cuda (bool): enable/disable cuda.
+        """
+        self.tts_checkpoint = tts_checkpoint
+        self.tts_config = tts_config
+        self.vocoder_checkpoint = vocoder_checkpoint
+        self.vocoder_config = vocoder_config
+        self.use_cuda = use_cuda
         self.wavernn = None
         self.vocoder_model = None
         self.num_speakers = 0
         self.tts_speakers = None
-        self.config = config
         self.seg = self.get_segmenter("en")
-        self.use_cuda = self.config.use_cuda
+        self.use_cuda = use_cuda
         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
-        self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
-                      self.config.use_cuda)
-        if self.config.vocoder_checkpoint:
-            self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
+        self.load_tts(tts_checkpoint, tts_config,
+                      use_cuda)
+        if vocoder_checkpoint:
+            self.load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
 
     @staticmethod
     def get_segmenter(lang):
@@ -41,7 +54,7 @@ class Synthesizer(object):
     def load_speakers(self):
         # load speakers
         if self.model_config.use_speaker_embedding is not None:
-            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
+            self.tts_speakers = load_speaker_mapping(self.tts_config.tts_speakers_json)
             self.num_speakers = len(self.tts_speakers)
         else:
             self.num_speakers = 0
@@ -147,12 +160,9 @@ class Synthesizer(object):
             wavs += list(waveform)
             wavs += [0] * 10000
 
-        out = io.BytesIO()
-        self.save_wav(wavs, out)
-
         # compute stats
         process_time = time.time() - start_time
         audio_time = len(wavs) / self.tts_config.audio['sample_rate']
        print(f" > Processing time: {process_time}")
        print(f" > Real-time factor: {process_time / audio_time}")
-        return out
+        return wavs