use synthesizer in both synthesize.py and server.py

Eren Gölge 2021-01-21 15:54:33 +01:00
parent 9addfabc43
commit 0ab2eb2664
3 changed files with 59 additions and 170 deletions

View File: TTS/bin/synthesize.py

@@ -20,6 +20,7 @@ from TTS.tts.utils.io import load_checkpoint
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
@@ -34,121 +35,6 @@ def str2bool(v):
raise argparse.ArgumentTypeError('Boolean value expected.')
def load_tts_model(model_path, config_path, use_cuda, speakers_json=None, speaker_idx=None):
global phonemes
global symbols
# load the config
model_config = load_config(config_path)
# load the audio processor
ap = AudioProcessor(**model_config.audio)
# if the vocabulary was passed, replace the default
if 'characters' in model_config.keys():
symbols, phonemes = make_symbols(**model_config.characters)
# load speakers
speaker_embedding = None
speaker_embedding_dim = None
num_speakers = 0
if speakers_json is not None:
speaker_mapping = json.load(open(speakers_json, 'r'))
num_speakers = len(speaker_mapping)
if model_config.use_external_speaker_embedding_file:
if speaker_idx is not None:
speaker_embedding = speaker_mapping[speaker_idx]['embedding']
else: # if speaker_idx is not specified, use the first sample in speakers.json
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']
speaker_embedding_dim = len(speaker_embedding)
# load tts model
num_chars = len(phonemes) if model_config.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, model_config, speaker_embedding_dim)
model.load_checkpoint(model_config, model_path, eval=True)
if use_cuda:
model.cuda()
return model, model_config, ap, speaker_embedding
def load_vocoder_model(model_path, config_path, use_cuda):
vocoder_config = load_config(config_path)
vocoder_ap = AudioProcessor(**vocoder_config['audio'])
vocoder_model = setup_generator(vocoder_config)
vocoder_model.load_checkpoint(vocoder_config, model_path, eval=True)
if use_cuda:
vocoder_model.cuda()
return vocoder_model, vocoder_config, vocoder_ap
def tts(model,
vocoder_model,
text,
model_config,
vocoder_config,
use_cuda,
ap,
vocoder_ap,
use_gl,
speaker_fileid,
speaker_embedding=None,
gst_style=None):
t_1 = time.time()
waveform, _, _, mel_postnet_spec, _, _ = synthesis(
model,
text,
model_config,
use_cuda,
ap,
speaker_fileid,
gst_style,
False,
model_config.enable_eos_bos_chars,
use_gl,
speaker_embedding=speaker_embedding)
# grab the spectrogram (thanks to the folks on the Mozilla Discourse for the code snippet)
if args.save_spectogram:
spec_file_name = args.text.replace(" ", "_")[0:10]
spec_file_name = spec_file_name.translate(
str.maketrans('', '', string.punctuation.replace('_', ''))) + '.npy'
spec_file_name = os.path.join(args.out_path, spec_file_name)
spectrogram = mel_postnet_spec.T
spectrogram = spectrogram[0]
np.save(spec_file_name, spectrogram)
print(" > Saving raw spectogram to " + spec_file_name)
# convert linear spectrogram to melspectrogram for tacotron
if model_config.model == "Tacotron" and not use_gl:
mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T)
# run vocoder_model
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = ap._denormalize(mel_postnet_spec.T).T
device_type = "cuda" if use_cuda else "cpu"
# renormalize spectrogram based on vocoder config
vocoder_input = vocoder_ap._normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
scale_factor = [1, vocoder_config['audio']['sample_rate'] / ap.sample_rate]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
# run vocoder model
# [1, T, C]
waveform = vocoder_model.inference(vocoder_input.to(device_type))
if use_cuda and not use_gl:
waveform = waveform.cpu()
if not use_gl:
waveform = waveform.numpy()
waveform = waveform.squeeze()
rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)
tps = (time.time() - t_1) / len(waveform)
print(" > Run-time: {}".format(time.time() - t_1))
print(" > Real-time factor: {}".format(rtf))
print(" > Time per step: {}".format(tps))
return waveform
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
@@ -273,7 +159,9 @@ if __name__ == "__main__":
manager = ModelManager(path)
model_path = None
config_path = None
vocoder_path = None
vocoder_config_path = None
model = None
vocoder_model = None
vocoder_config = None
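An aside on the plumbing the removed tts() helper above handled explicitly: when the TTS and vocoder models were trained at different sample rates, the spectrogram had to be rescaled before vocoding. A minimal worked example of that scale-factor check, with hypothetical rates:

# hypothetical sample rates; the real values come from the two audio configs
tts_sample_rate = 22050
vocoder_sample_rate = 24000
# same shape as scale_factor in the removed code: [1, time-axis factor]
scale_factor = [1, vocoder_sample_rate / tts_sample_rate]  # [1, ~1.088]
if scale_factor[1] != 1:
    # mismatch: the spectrogram goes through interpolate_vocoder_input()
    print(" > interpolating tts model output.")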
@@ -302,49 +190,36 @@ if __name__ == "__main__":
# RUN THE SYNTHESIS
# load models
model, model_config, ap, speaker_embedding = load_tts_model(model_path, config_path, args.use_cuda, args.speaker_idx)
if vocoder_path is not None:
vocoder_model, vocoder_config, vocoder_ap = load_vocoder_model(vocoder_path, vocoder_config_path, use_cuda=args.use_cuda)
synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, args.use_cuda)
use_griffin_lim = vocoder_path is None
print(" > Text: {}".format(args.text))
# handle multi-speaker setting
if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None:
if args.speaker_idx.isdigit():
args.speaker_idx = int(args.speaker_idx)
else:
args.speaker_idx = None
else:
args.speaker_idx = None
# # handle multi-speaker setting
# if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None:
# if args.speaker_idx.isdigit():
# args.speaker_idx = int(args.speaker_idx)
# else:
# args.speaker_idx = None
# else:
# args.speaker_idx = None
if args.gst_style is None:
if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None:
gst_style = model_config.gst['gst_style_input']
else:
gst_style = None
else:
# check if the gst_style string encodes a dict; if so, parse it, else use the raw string
try:
gst_style = json.loads(args.gst_style)
if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']:
raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens']))
except ValueError:
gst_style = args.gst_style
# if args.gst_style is None:
# if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None:
# gst_style = model_config.gst['gst_style_input']
# else:
# gst_style = None
# else:
# # check if the gst_style string encodes a dict; if so, parse it, else use the raw string
# try:
# gst_style = json.loads(args.gst_style)
# if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']:
# raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens']))
# except ValueError:
# gst_style = args.gst_style
# kick it
wav = tts(model,
vocoder_model,
args.text,
model_config,
vocoder_config,
args.use_cuda,
ap,
vocoder_ap,
use_griffin_lim,
args.speaker_idx,
speaker_embedding=speaker_embedding,
gst_style=gst_style)
wav = synthesizer.tts(args.text)
# save the results
file_name = args.text.replace(" ", "_")[0:20]
@@ -352,4 +227,4 @@ if __name__ == "__main__":
str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
out_path = os.path.join(args.out_path, file_name)
print(" > Saving output to {}".format(out_path))
ap.save_wav(wav, out_path)
synthesizer.save_wav(wav, out_path)
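Taken together, the synthesize.py changes collapse the manual load/synthesize/vocode pipeline into three calls on the new class. A minimal sketch of the resulting flow, with placeholder paths (in the script they come from ModelManager or the CLI arguments):

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    "tts_model.pth.tar",       # model_path (placeholder)
    "config.json",             # config_path (placeholder)
    "vocoder_model.pth.tar",   # vocoder_path; None falls back to Griffin-Lim
    "vocoder_config.json",     # vocoder_config_path (placeholder)
    False)                     # use_cuda
wav = synthesizer.tts("Hello world.")
synthesizer.save_wav(wav, "output.wav")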

View File: TTS/server/server.py

@@ -2,10 +2,11 @@
import argparse
import os
import sys
import io
from pathlib import Path
from flask import Flask, render_template, request, send_file
from TTS.server.synthesizer import Synthesizer
from TTS.utils.synthesizer import Synthesizer
from TTS.utils.manage import ModelManager
@@ -71,10 +72,11 @@ if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
args.vocoder_config = vocoder_config_file
synthesizer = Synthesizer(args)
synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)
app = Flask(__name__)
@app.route('/')
def index():
return render_template('index.html')
@@ -84,8 +86,10 @@ def index():
def tts():
text = request.args.get('text')
print(" > Model input: {}".format(text))
data = synthesizer.tts(text)
return send_file(data, mimetype='audio/wav')
wavs = synthesizer.tts(text)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype='audio/wav')
def main():
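For reference, a client-side sketch of exercising the rewritten endpoint. The route decorator is not shown in this hunk; the path and port below (/api/tts on 5002) are assumptions based on the rest of the server:

import requests

# endpoint path and port are assumptions; only the handler body appears above
resp = requests.get("http://localhost:5002/api/tts",
                    params={"text": "Hello world."})
with open("reply.wav", "wb") as f:
    f.write(resp.content)  # the WAV bytes built via io.BytesIO() on the server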

View File: TTS/utils/synthesizer.py

@@ -1,5 +1,3 @@
import io
import sys
import time
import numpy as np
@@ -19,20 +17,35 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols
class Synthesizer(object):
def __init__(self, config):
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda):
"""Encapsulation of tts and vocoder models for inference.
TODO: handle multi-speaker and GST inference.
Args:
tts_checkpoint (str): path to the tts model file.
tts_config (str): path to the tts config file.
vocoder_checkpoint (str): path to the vocoder model file.
vocoder_config (str): path to the vocoder config file.
use_cuda (bool): enable/disable cuda.
"""
self.tts_checkpoint = tts_checkpoint
self.tts_config = tts_config
self.vocoder_checkpoint = vocoder_checkpoint
self.vocoder_config = vocoder_config
self.use_cuda = use_cuda
self.wavernn = None
self.vocoder_model = None
self.num_speakers = 0
self.tts_speakers = None
self.config = config
self.seg = self.get_segmenter("en")
self.use_cuda = self.config.use_cuda
self.use_cuda = use_cuda
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not available on this machine."
self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
self.config.use_cuda)
if self.config.vocoder_checkpoint:
self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
self.load_tts(tts_checkpoint, tts_config,
use_cuda)
if vocoder_checkpoint:
self.load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
@staticmethod
def get_segmenter(lang):
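The get_segmenter() body is cut off by the hunk above. Judging from the seg attribute's use for sentence splitting, it presumably wraps pysbd; a sketch under that assumption:

import pysbd

# assumed equivalent of Synthesizer.get_segmenter("en")
seg = pysbd.Segmenter(language="en", clean=True)
print(seg.segment("Hello world. How are you?"))
# ['Hello world.', 'How are you?']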
@@ -41,7 +54,7 @@ class Synthesizer(object):
def load_speakers(self):
# load speakers
if self.model_config.use_speaker_embedding is not None:
self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
self.tts_speakers = load_speaker_mapping(self.tts_config.tts_speakers_json)
self.num_speakers = len(self.tts_speakers)
else:
self.num_speakers = 0
@@ -147,12 +160,9 @@ class Synthesizer(object):
wavs += list(waveform)
wavs += [0] * 10000
out = io.BytesIO()
self.save_wav(wavs, out)
# compute stats
process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio['sample_rate']
print(f" > Processing time: {process_time}")
print(f" > Real-time factor: {process_time / audio_time}")
return out
return wavs
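With tts() now returning the raw sample list instead of an io.BytesIO buffer, serialization is the caller's job: the CLI passes the samples to synthesizer.save_wav() and the server wraps them for send_file(). A standalone sketch of consuming that return value, assuming 22050 Hz output and using scipy instead of the class's own saver:

import numpy as np
from scipy.io import wavfile

wavs = synthesizer.tts("One sentence. Another sentence.")
# wavs is a flat list of float samples, with short silences between sentences
wavfile.write("out.wav", 22050, np.array(wavs, dtype=np.float32))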