mirror of https://github.com/coqui-ai/TTS.git
update server to enable native vocoder inference and remove pwgan support
This commit is contained in:
parent
77c962561e
commit
b1935c97fa
|
@ -21,6 +21,8 @@ def create_argparser():
|
|||
parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
|
||||
parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.')
|
||||
parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.')
|
||||
parser.add_argument('--vocoder_config', type=str, default=None, help='path to TTS.vocoder config file.')
|
||||
parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to TTS.vocoder checkpoint file.')
|
||||
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
|
||||
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
|
||||
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
|
||||
|
@ -35,6 +37,11 @@ embedded_tts_folder = os.path.join(embedded_models_folder, 'tts')
|
|||
tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar')
|
||||
tts_config_file = os.path.join(embedded_tts_folder, 'config.json')
|
||||
|
||||
embedded_vocoder_folder = os.path.join(embedded_models_folder, 'vocoder')
|
||||
vocoder_checkpoint_file = os.path.join(embedded_vocoder_folder, 'checkpoint.pth.tar')
|
||||
vocoder_config_file = os.path.join(embedded_vocoder_folder, 'config.json')
|
||||
|
||||
# These models are soon to be deprecated
|
||||
embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn')
|
||||
wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar')
|
||||
wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json')
|
||||
|
@ -50,6 +57,11 @@ if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file):
|
|||
args.tts_checkpoint = tts_checkpoint_file
|
||||
if not args.tts_config and os.path.isfile(tts_config_file):
|
||||
args.tts_config = tts_config_file
|
||||
if not args.vocoder_checkpoint and os.path.isfile(tts_checkpoint_file):
|
||||
args.tts_checkpoint = tts_checkpoint_file
|
||||
if not args.vocoder_config and os.path.isfile(tts_config_file):
|
||||
args.tts_config = tts_config_file
|
||||
|
||||
if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file):
|
||||
args.wavernn_file = wavernn_checkpoint_file
|
||||
if not args.wavernn_config and os.path.isfile(wavernn_config_file):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import io
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -10,6 +11,7 @@ from TTS.utils.audio import AudioProcessor
|
|||
from TTS.utils.io import load_config
|
||||
from TTS.utils.generic_utils import setup_model
|
||||
from TTS.utils.speakers import load_speaker_mapping
|
||||
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from TTS.utils.synthesis import *
|
||||
|
@ -34,6 +36,8 @@ class Synthesizer(object):
|
|||
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
|
||||
self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
|
||||
self.config.use_cuda)
|
||||
if self.config.vocoder_checkpoint:
|
||||
self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
|
||||
if self.config.wavernn_lib_path:
|
||||
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
|
||||
self.config.wavernn_config, self.config.use_cuda)
|
||||
|
@ -77,6 +81,19 @@ class Synthesizer(object):
|
|||
self.tts_model.decoder.max_decoder_steps = 3000
|
||||
if 'r' in cp:
|
||||
self.tts_model.decoder.set_r(cp['r'])
|
||||
print(f" > model reduction factor: {cp['r']}")
|
||||
|
||||
def load_vocoder(self, model_file, model_config, use_cuda):
|
||||
self.vocoder_config = load_config(model_config)
|
||||
self.vocoder_model = setup_generator(self.vocoder_config)
|
||||
self.vocoder_model.load_state_dict(torch.load(model_file, map_location="cpu")["model"])
|
||||
self.vocoder_model.remove_weight_norm()
|
||||
self.vocoder_model.inference_padding = 0
|
||||
self.vocoder_config = load_config(model_config)
|
||||
|
||||
if use_cuda:
|
||||
self.vocoder_model.cuda()
|
||||
self.vocoder_model.eval()
|
||||
|
||||
def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
|
||||
# TODO: set a function in wavernn code base for model setup and call it here.
|
||||
|
@ -171,6 +188,7 @@ class Synthesizer(object):
|
|||
return sentences
|
||||
|
||||
def tts(self, text, speaker_id=None):
|
||||
start_time = time.time()
|
||||
wavs = []
|
||||
sens = self.split_into_sentences(text)
|
||||
print(sens)
|
||||
|
@ -184,29 +202,25 @@ class Synthesizer(object):
|
|||
inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
|
||||
inputs = inputs.unsqueeze(0)
|
||||
# synthesize voice
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
|
||||
self.tts_model, inputs, self.tts_config, False, speaker_id, None)
|
||||
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
|
||||
# convert outputs to numpy
|
||||
postnet_output, decoder_output, _, _ = parse_outputs_torch(
|
||||
postnet_output, decoder_output, alignments, stop_tokens)
|
||||
|
||||
if self.pwgan:
|
||||
vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
|
||||
if self.vocoder_model:
|
||||
vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
|
||||
wav = self.vocoder_model.inference(vocoder_input)
|
||||
if self.use_cuda:
|
||||
vocoder_input.cuda()
|
||||
wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length)
|
||||
wav = wav.cpu().numpy()
|
||||
else:
|
||||
wav = wav.numpy()
|
||||
wav = wav.flatten()
|
||||
elif self.wavernn:
|
||||
vocoder_input = None
|
||||
if self.tts_config.model == "Tacotron":
|
||||
vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
|
||||
else:
|
||||
vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
|
||||
|
||||
vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
|
||||
if self.use_cuda:
|
||||
vocoder_input.cuda()
|
||||
wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
|
||||
else:
|
||||
wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
|
||||
# trim silence
|
||||
wav = trim_silence(wav, self.ap)
|
||||
|
||||
|
@ -215,4 +229,10 @@ class Synthesizer(object):
|
|||
|
||||
out = io.BytesIO()
|
||||
self.save_wav(wavs, out)
|
||||
|
||||
# compute stats
|
||||
process_time = time.time() - start_time
|
||||
audio_time = len(wavs) / self.tts_config.audio['sample_rate']
|
||||
print(f" > Processing time: {process_time}")
|
||||
print(f" > Real-time factor: {process_time / audio_time}")
|
||||
return out
|
||||
|
|
|
@ -8,10 +8,10 @@
|
|||
<meta name="description" content="">
|
||||
<meta name="author" content="">
|
||||
|
||||
<title>Mozillia - Text2Speech engine</title>
|
||||
<title>Mozilla - Text2Speech engine</title>
|
||||
|
||||
<!-- Bootstrap core CSS -->
|
||||
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
|
||||
<link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
|
||||
integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous" rel="stylesheet">
|
||||
|
||||
<!-- Custom styles for this template -->
|
||||
|
@ -27,7 +27,7 @@
|
|||
|
||||
</style>
|
||||
</head>
|
||||
|
||||
|
||||
<body>
|
||||
<a href="https://github.com/mozilla/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;" src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
|
||||
|
||||
|
@ -60,7 +60,7 @@
|
|||
<h1 class="mt-5">Mozilla TTS</h1>
|
||||
<ul class="list-unstyled">
|
||||
</ul>
|
||||
<input id="text" placeholder="Type here..." size=45 type="text" name="text">
|
||||
<input id="text" placeholder="Type here..." size=45 type="text" name="text">
|
||||
<button id="speak-button" name="speak">Speak</button><br/><br/>
|
||||
<audio id="audio" controls autoplay hidden></audio>
|
||||
<p id="message"></p>
|
||||
|
|
Loading…
Reference in New Issue