From b1935c97fa1175908c579a4db06214174253f5f4 Mon Sep 17 00:00:00 2001
From: erogol
Date: Fri, 26 Jun 2020 14:35:40 +0200
Subject: [PATCH] update server to enable native vocoder inference and remove
 pwgan support

---
 server/server.py            | 12 ++++++++
 server/synthesizer.py       | 47 ++++++++++++++++++++++++------------
 server/templates/index.html |  8 ++++----
 3 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/server/server.py b/server/server.py
index 593eeb18..43f1b3c4 100644
--- a/server/server.py
+++ b/server/server.py
@@ -21,6 +21,8 @@ def create_argparser():
     parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
     parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.')
     parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.')
+    parser.add_argument('--vocoder_config', type=str, default=None, help='path to TTS.vocoder config file.')
+    parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to TTS.vocoder checkpoint file.')
     parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
     parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
     parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
@@ -35,6 +37,11 @@ embedded_tts_folder = os.path.join(embedded_models_folder, 'tts')
 tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar')
 tts_config_file = os.path.join(embedded_tts_folder, 'config.json')
 
+embedded_vocoder_folder = os.path.join(embedded_models_folder, 'vocoder')
+vocoder_checkpoint_file = os.path.join(embedded_vocoder_folder, 'checkpoint.pth.tar')
+vocoder_config_file = os.path.join(embedded_vocoder_folder, 'config.json')
+
+# These models are soon to be deprecated
 embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn')
 wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar')
 wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json')
@@ -50,6 +57,11 @@ if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file):
     args.tts_checkpoint = tts_checkpoint_file
 if not args.tts_config and os.path.isfile(tts_config_file):
     args.tts_config = tts_config_file
+if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
+    args.vocoder_checkpoint = vocoder_checkpoint_file
+if not args.vocoder_config and os.path.isfile(vocoder_config_file):
+    args.vocoder_config = vocoder_config_file
+
 if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file):
     args.wavernn_file = wavernn_checkpoint_file
 if not args.wavernn_config and os.path.isfile(wavernn_config_file):
diff --git a/server/synthesizer.py b/server/synthesizer.py
index c6fde902..d85bbebc 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -1,6 +1,7 @@
 import io
 import re
 import sys
+import time
 
 import numpy as np
 import torch
@@ -10,6 +11,7 @@ from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
 from TTS.utils.generic_utils import setup_model
 from TTS.utils.speakers import load_speaker_mapping
+from TTS.vocoder.utils.generic_utils import setup_generator
 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.utils.synthesis import *
@@ -34,6 +36,9 @@ class Synthesizer(object):
             assert torch.cuda.is_available(), "CUDA is not available on this machine."
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
+        self.vocoder_model = None
+        if self.config.vocoder_checkpoint:
+            self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
         if self.config.wavernn_lib_path:
             self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
                               self.config.wavernn_config, self.config.use_cuda)
@@ -77,6 +82,19 @@ class Synthesizer(object):
             self.tts_model.decoder.max_decoder_steps = 3000
         if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])
+            print(f" > model reduction factor: {cp['r']}")
+
+    def load_vocoder(self, model_file, model_config, use_cuda):
+        # build the generator from its config, then load the checkpoint weights
+        self.vocoder_config = load_config(model_config)
+        self.vocoder_model = setup_generator(self.vocoder_config)
+        self.vocoder_model.load_state_dict(torch.load(model_file, map_location="cpu")["model"])
+        self.vocoder_model.remove_weight_norm()
+        self.vocoder_model.inference_padding = 0
+
+        if use_cuda:
+            self.vocoder_model.cuda()
+        self.vocoder_model.eval()
 
     def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
         # TODO: set a function in wavernn code base for model setup and call it here.
@@ -171,6 +189,7 @@ class Synthesizer(object):
         return sentences
 
     def tts(self, text, speaker_id=None):
+        start_time = time.time()
         wavs = []
         sens = self.split_into_sentences(text)
         print(sens)
@@ -184,29 +203,25 @@
             inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
             inputs = inputs.unsqueeze(0)
             # synthesize voice
-            decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
-                self.tts_model, inputs, self.tts_config, False, speaker_id, None)
+            decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
             # convert outputs to numpy
-            postnet_output, decoder_output, _, _ = parse_outputs_torch(
-                postnet_output, decoder_output, alignments, stop_tokens)
-
-            if self.pwgan:
-                vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
+            if self.vocoder_model:
+                vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
+                wav = self.vocoder_model.inference(vocoder_input)
                 if self.use_cuda:
-                    vocoder_input.cuda()
-                wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length)
+                    wav = wav.cpu().numpy()
+                else:
+                    wav = wav.numpy()
+                wav = wav.flatten()
             elif self.wavernn:
                 vocoder_input = None
                 if self.tts_config.model == "Tacotron":
                     vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
                 else:
-                    vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
-
+                    vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
                 if self.use_cuda:
                     vocoder_input.cuda()
                 wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
-            else:
-                wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
 
             # trim silence
             wav = trim_silence(wav, self.ap)
@@ -215,4 +230,10 @@
 
         out = io.BytesIO()
         self.save_wav(wavs, out)
+
+        # compute stats
+        process_time = time.time() - start_time
+        audio_time = len(wavs) / self.tts_config.audio['sample_rate']
+        print(f" > Processing time: {process_time}")
+        print(f" > Real-time factor: {process_time / audio_time}")
         return out
diff --git a/server/templates/index.html b/server/templates/index.html
index d1bde024..45b874a9 100644
--- a/server/templates/index.html
+++ b/server/templates/index.html
@@ -8,10 +8,10 @@
-  <title>Mozillia - Text2Speech engine</title>
+  <title>Mozilla - Text2Speech engine</title>
   [surrounding markup and one further changed line in this hunk were lost in extraction]
@@ -27,7 +27,7 @@
   [changed markup lost in extraction; the hunk involves the "Fork me on GitHub" ribbon link]
@@ -60,7 +60,7 @@
   [changed markup lost in extraction; the hunk touches the "Mozilla TTS" heading block]
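
---
Usage sketch for the flags this patch adds. All paths below are placeholders, and
the demo page's /api/tts query route is an assumption (this patch does not touch
the Flask routes):

    # start the server with a natively trained TTS.vocoder model
    python server/server.py \
        --tts_checkpoint /path/to/tts/checkpoint.pth.tar \
        --tts_config /path/to/tts/config.json \
        --vocoder_checkpoint /path/to/vocoder/checkpoint.pth.tar \
        --vocoder_config /path/to/vocoder/config.json \
        --port 5002

    # fetch synthesized audio from the running server
    curl "http://localhost:5002/api/tts?text=hello%20world" -o out.wav

If --vocoder_checkpoint is omitted (and no embedded vocoder model is found), the
server falls back to the WaveRNN path when configured; the old --pwgan_* flags are
left in place by this patch but are no longer used at synthesis time.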