mirror of https://github.com/coqui-ai/TTS.git
Server updates
This commit is contained in:
parent
800b77eb10
commit
ff33604df1
|
@ -1,7 +1,12 @@
|
|||
{
|
||||
"model_path":"/home/erogol/projects/runs/2579/keep/November-04-2018_06+19PM-TTS-master-_tmp-debug/",
|
||||
"model_name":"best_model.pth.tar",
|
||||
"model_config":"config.json",
|
||||
"tts_path":"/media/erogol/data_ssd/Data/models/ljspeech_models/ljspeech-April-08-2019_07+32PM-8a47b46/", // tts model root folder
|
||||
"tts_file":"checkpoint_261000.pth.tar", // tts checkpoint file
|
||||
"tts_config":"config.json", // tts config.json file
|
||||
"wavernn_lib_path": "/home/erogol/projects/", // Rootpath to wavernn project folder to be important. If this is none, model uses GL for speech synthesis.
|
||||
"wavernn_path":"/media/erogol/data_ssd/Data/models/wavernn/ljspeech/mold_ljspeech_best_model/", // wavernn model root path
|
||||
"wavernn_file":"checkpoint_433000.pth.tar", // wavernn checkpoint file name
|
||||
"wavernn_config":"config.json", // wavernn config file
|
||||
"is_wavernn_batched":true,
|
||||
"port": 5002,
|
||||
"use_cuda": true
|
||||
}
|
||||
|
|
|
@ -11,10 +11,7 @@ args = parser.parse_args()
|
|||
|
||||
config = load_config(args.config_path)
|
||||
app = Flask(__name__)
|
||||
synthesizer = Synthesizer()
|
||||
synthesizer.load_model(config.model_path, config.model_name,
|
||||
config.model_config, config.use_cuda)
|
||||
|
||||
synthesizer = Synthesizer(config)
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
|
|
|
@ -1,40 +1,82 @@
|
|||
import io
|
||||
import os
|
||||
import librosa
|
||||
import torch
|
||||
import scipy
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from utils.text import text_to_sequence
|
||||
from utils.generic_utils import load_config
|
||||
from utils.audio import AudioProcessor
|
||||
import torch
|
||||
|
||||
from models.tacotron import Tacotron
|
||||
from matplotlib import pylab as plt
|
||||
from utils.audio import AudioProcessor
|
||||
from utils.generic_utils import load_config, setup_model
|
||||
from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence, sequence_to_phoneme
|
||||
|
||||
|
||||
class Synthesizer(object):
|
||||
def load_model(self, model_path, model_name, model_config, use_cuda):
|
||||
model_config = os.path.join(model_path, model_config)
|
||||
self.model_file = os.path.join(model_path, model_name)
|
||||
print(" > Loading model ...")
|
||||
print(" | > model config: ", model_config)
|
||||
print(" | > model file: ", self.model_file)
|
||||
config = load_config(model_config)
|
||||
def __init__(self, config):
|
||||
self.wavernn = None
|
||||
self.config = config
|
||||
self.use_cuda = use_cuda
|
||||
self.ap = AudioProcessor(**config.audio)
|
||||
self.model = Tacotron(config.embedding_size, self.ap.num_freq, self.ap.num_mels, config.r)
|
||||
self.use_cuda = config.use_cuda
|
||||
if self.use_cuda:
|
||||
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
|
||||
self.load_tts(self.config.tts_path, self.config.tts_file, self.config.tts_config, config.use_cuda)
|
||||
if self.config.wavernn_lib_path:
|
||||
self.load_wavernn(config.wavernn_lib_path, config.wavernn_path, config.wavernn_file, config.wavernn_config, config.use_cuda)
|
||||
|
||||
def load_tts(self, model_path, model_file, model_config, use_cuda):
|
||||
tts_config = os.path.join(model_path, model_config)
|
||||
self.model_file = os.path.join(model_path, model_file)
|
||||
print(" > Loading TTS model ...")
|
||||
print(" | > model config: ", tts_config)
|
||||
print(" | > model file: ", model_file)
|
||||
self.tts_config = load_config(tts_config)
|
||||
self.use_phonemes = self.tts_config.use_phonemes
|
||||
self.ap = AudioProcessor(**self.tts_config.audio)
|
||||
if self.use_phonemes:
|
||||
self.input_size = len(phonemes)
|
||||
self.input_adapter = lambda sen: phoneme_to_sequence(sen, [self.tts_config.text_cleaner], self.tts_config.phoneme_language, self.tts_config.enable_eos_bos_chars)
|
||||
else:
|
||||
self.input_size = len(symbols)
|
||||
self.input_adapter = lambda sen: text_to_sequence(sen, [self.tts_config.text_cleaner])
|
||||
self.tts_model = setup_model(self.input_size, self.tts_config)
|
||||
# load model state
|
||||
if use_cuda:
|
||||
cp = torch.load(self.model_file)
|
||||
else:
|
||||
cp = torch.load(
|
||||
self.model_file, map_location=lambda storage, loc: storage)
|
||||
cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
|
||||
# load the model
|
||||
self.model.load_state_dict(cp['model'])
|
||||
self.tts_model.load_state_dict(cp['model'])
|
||||
if use_cuda:
|
||||
self.model.cuda()
|
||||
self.model.eval()
|
||||
self.tts_model.cuda()
|
||||
self.tts_model.eval()
|
||||
|
||||
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
|
||||
sys.path.append(lib_path) # set this if TTS is not installed globally
|
||||
from WaveRNN.models.wavernn import Model
|
||||
wavernn_config = os.path.join(model_path, model_config)
|
||||
model_file = os.path.join(model_path, model_file)
|
||||
print(" > Loading WaveRNN model ...")
|
||||
print(" | > model config: ", wavernn_config)
|
||||
print(" | > model file: ", model_file)
|
||||
self.wavernn_config = load_config(wavernn_config)
|
||||
self.wavernn = Model(
|
||||
rnn_dims=512,
|
||||
fc_dims=512,
|
||||
mode=self.wavernn_config.mode,
|
||||
pad=2,
|
||||
upsample_factors=self.wavernn_config.upsample_factors, # set this depending on dataset
|
||||
feat_dims=80,
|
||||
compute_dims=128,
|
||||
res_out_dims=128,
|
||||
res_blocks=10,
|
||||
hop_length=self.ap.hop_length,
|
||||
sample_rate=self.ap.sample_rate,
|
||||
).cuda()
|
||||
|
||||
check = torch.load(model_file)
|
||||
self.wavernn.load_state_dict(check['model'])
|
||||
if use_cuda:
|
||||
self.wavernn.cuda()
|
||||
self.wavernn.eval()
|
||||
|
||||
def save_wav(self, wav, path):
|
||||
# wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
|
||||
|
@ -42,25 +84,35 @@ class Synthesizer(object):
|
|||
self.ap.save_wav(wav, path)
|
||||
|
||||
def tts(self, text):
|
||||
text_cleaner = [self.config.text_cleaner]
|
||||
wavs = []
|
||||
for sen in text.split('.'):
|
||||
if len(sen) < 3:
|
||||
continue
|
||||
sen = sen.strip()
|
||||
sen += '.'
|
||||
print(sen)
|
||||
sen = sen.strip()
|
||||
seq = np.array(text_to_sequence(sen, text_cleaner))
|
||||
|
||||
seq = np.array(self.input_adapter(sen))
|
||||
text_hat = sequence_to_phoneme(seq)
|
||||
print(text_hat)
|
||||
|
||||
chars_var = torch.from_numpy(seq).unsqueeze(0).long()
|
||||
|
||||
if self.use_cuda:
|
||||
chars_var = chars_var.cuda()
|
||||
mel_out, linear_out, alignments, stop_tokens = self.model.forward(
|
||||
decoder_out, postnet_out, alignments, stop_tokens = self.tts_model.inference(
|
||||
chars_var)
|
||||
linear_out = linear_out[0].data.cpu().numpy()
|
||||
wav = self.ap.inv_spectrogram(linear_out.T)
|
||||
out = io.BytesIO()
|
||||
postnet_out = postnet_out[0].data.cpu().numpy()
|
||||
if self.tts_config.model == "Tacotron":
|
||||
wav = self.ap.inv_spectrogram(postnet_out.T)
|
||||
elif self.tts_config.model == "Tacotron2":
|
||||
if self.wavernn:
|
||||
wav = self.wavernn.generate(torch.FloatTensor(postnet_out.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
|
||||
else:
|
||||
wav = self.ap.inv_mel_spectrogram(postnet_out.T)
|
||||
wavs += list(wav)
|
||||
wavs += [0] * 10000
|
||||
|
||||
out = io.BytesIO()
|
||||
self.save_wav(wavs, out)
|
||||
return out
|
|
@ -56,11 +56,10 @@
|
|||
<div class="container">
|
||||
<div class="row">
|
||||
<div class="col-lg-12 text-center">
|
||||
<h1 class="mt-5">Mozilla TTS</h1>
|
||||
<p class="lead">"work-in-progress"</p>
|
||||
<img class="mt-5" src="https://user-images.githubusercontent.com/1402048/52643646-c2102980-2edd-11e9-8c37-b72f3c89a640.png" alt=></img>
|
||||
<ul class="list-unstyled">
|
||||
</ul>
|
||||
<input id="text" placeholder="Enter text" size=45 type="text" name="text">
|
||||
<input id="text" placeholder="Type here..." size=45 type="text" name="text">
|
||||
<button id="speak-button" name="speak">Speak</button><br/><br/>
|
||||
<audio id="audio" controls autoplay hidden></audio>
|
||||
<p id="message"></p>
|
||||
|
|
Loading…
Reference in New Issue