diff --git a/server/conf.json b/server/conf.json
index a19f2ac1..f1813073 100644
--- a/server/conf.json
+++ b/server/conf.json
@@ -1,7 +1,12 @@
 {
-    "model_path":"/home/erogol/projects/runs/2579/keep/November-04-2018_06+19PM-TTS-master-_tmp-debug/",
-    "model_name":"best_model.pth.tar",
-    "model_config":"config.json",
+    "tts_path":"/media/erogol/data_ssd/Data/models/ljspeech_models/ljspeech-April-08-2019_07+32PM-8a47b46/",  // tts model root folder
+    "tts_file":"checkpoint_261000.pth.tar",  // tts checkpoint file
+    "tts_config":"config.json",  // tts config.json file
+    "wavernn_lib_path": "/home/erogol/projects/",  // root path to the WaveRNN project folder to be imported. If null, the model uses Griffin-Lim (GL) for speech synthesis.
+    "wavernn_path":"/media/erogol/data_ssd/Data/models/wavernn/ljspeech/mold_ljspeech_best_model/",  // wavernn model root path
+    "wavernn_file":"checkpoint_433000.pth.tar",  // wavernn checkpoint file name
+    "wavernn_config":"config.json",  // wavernn config file
+    "is_wavernn_batched":true,
     "port": 5002,
     "use_cuda": true
 }
diff --git a/server/server.py b/server/server.py
index d5effaed..f5ad4088 100644
--- a/server/server.py
+++ b/server/server.py
@@ -11,10 +11,7 @@
 args = parser.parse_args()
 config = load_config(args.config_path)
 app = Flask(__name__)
-synthesizer = Synthesizer()
-synthesizer.load_model(config.model_path, config.model_name,
-                       config.model_config, config.use_cuda)
-
+synthesizer = Synthesizer(config)
 
 @app.route('/')
 def index():
diff --git a/server/synthesizer.py b/server/synthesizer.py
index 5c88c309..b8198978 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -1,40 +1,82 @@
 import io
 import os
-import librosa
-import torch
-import scipy
+import sys
+
 import numpy as np
-import soundfile as sf
-from utils.text import text_to_sequence
-from utils.generic_utils import load_config
-from utils.audio import AudioProcessor
+import torch
+
 from models.tacotron import Tacotron
-from matplotlib import pylab as plt
+from utils.audio import AudioProcessor
+from utils.generic_utils import load_config, setup_model
+from utils.text import phoneme_to_sequence, phonemes, symbols, text_to_sequence, sequence_to_phoneme
 
 
 class Synthesizer(object):
-    def load_model(self, model_path, model_name, model_config, use_cuda):
-        model_config = os.path.join(model_path, model_config)
-        self.model_file = os.path.join(model_path, model_name)
-        print(" > Loading model ...")
-        print(" | > model config: ", model_config)
-        print(" | > model file: ", self.model_file)
-        config = load_config(model_config)
-        self.config = config
-        self.use_cuda = use_cuda
-        self.ap = AudioProcessor(**config.audio)
-        self.model = Tacotron(config.embedding_size, self.ap.num_freq, self.ap.num_mels, config.r)
+    def __init__(self, config):
+        self.wavernn = None
+        self.config = config
+        self.use_cuda = config.use_cuda
+        if self.use_cuda:
+            assert torch.cuda.is_available(), "CUDA is not available on this machine."
+        self.load_tts(self.config.tts_path, self.config.tts_file, self.config.tts_config, config.use_cuda)
+        if self.config.wavernn_lib_path:
+            self.load_wavernn(config.wavernn_lib_path, config.wavernn_path, config.wavernn_file, config.wavernn_config, config.use_cuda)
+
+    def load_tts(self, model_path, model_file, model_config, use_cuda):
+        tts_config = os.path.join(model_path, model_config)
+        self.model_file = os.path.join(model_path, model_file)
+        print(" > Loading TTS model ...")
+        print(" | > model config: ", tts_config)
+        print(" | > model file: ", model_file)
+        self.tts_config = load_config(tts_config)
+        self.use_phonemes = self.tts_config.use_phonemes
+        self.ap = AudioProcessor(**self.tts_config.audio)
+        if self.use_phonemes:
+            self.input_size = len(phonemes)
+            self.input_adapter = lambda sen: phoneme_to_sequence(sen, [self.tts_config.text_cleaner], self.tts_config.phoneme_language, self.tts_config.enable_eos_bos_chars)
+        else:
+            self.input_size = len(symbols)
+            self.input_adapter = lambda sen: text_to_sequence(sen, [self.tts_config.text_cleaner])
+        self.tts_model = setup_model(self.input_size, self.tts_config)
         # load model state
         if use_cuda:
             cp = torch.load(self.model_file)
         else:
-            cp = torch.load(
-                self.model_file, map_location=lambda storage, loc: storage)
+            cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
         # load the model
-        self.model.load_state_dict(cp['model'])
+        self.tts_model.load_state_dict(cp['model'])
         if use_cuda:
-            self.model.cuda()
-        self.model.eval()
+            self.tts_model.cuda()
+        self.tts_model.eval()
+
+    def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
+        sys.path.append(lib_path)  # set this if WaveRNN is not installed globally
+        from WaveRNN.models.wavernn import Model
+        wavernn_config = os.path.join(model_path, model_config)
+        model_file = os.path.join(model_path, model_file)
+        print(" > Loading WaveRNN model ...")
+        print(" | > model config: ", wavernn_config)
+        print(" | > model file: ", model_file)
+        self.wavernn_config = load_config(wavernn_config)
+        self.wavernn = Model(
+            rnn_dims=512,
+            fc_dims=512,
+            mode=self.wavernn_config.mode,
+            pad=2,
+            upsample_factors=self.wavernn_config.upsample_factors,  # set this depending on dataset
+            feat_dims=80,
+            compute_dims=128,
+            res_out_dims=128,
+            res_blocks=10,
+            hop_length=self.ap.hop_length,
+            sample_rate=self.ap.sample_rate,
+        )  # moved to GPU below only when use_cuda is set
+
+        check = torch.load(model_file)
+        self.wavernn.load_state_dict(check['model'])
+        if use_cuda:
+            self.wavernn.cuda()
+        self.wavernn.eval()
 
     def save_wav(self, wav, path):
         # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
@@ -42,25 +84,35 @@ class Synthesizer(object):
         self.ap.save_wav(wav, path)
 
     def tts(self, text):
-        text_cleaner = [self.config.text_cleaner]
         wavs = []
         for sen in text.split('.'):
             if len(sen) < 3:
                 continue
             sen = sen.strip()
-            sen += '.'
             print(sen)
             sen = sen.strip()
-            seq = np.array(text_to_sequence(sen, text_cleaner))
+
+            seq = np.array(self.input_adapter(sen))
+            text_hat = sequence_to_phoneme(seq)
+            print(text_hat)
+
             chars_var = torch.from_numpy(seq).unsqueeze(0).long()
+
             if self.use_cuda:
                 chars_var = chars_var.cuda()
-            mel_out, linear_out, alignments, stop_tokens = self.model.forward(
+            decoder_out, postnet_out, alignments, stop_tokens = self.tts_model.inference(
                 chars_var)
-            linear_out = linear_out[0].data.cpu().numpy()
-            wav = self.ap.inv_spectrogram(linear_out.T)
-            out = io.BytesIO()
+            postnet_out = postnet_out[0].data.cpu().numpy()
+            if self.tts_config.model == "Tacotron":
+                wav = self.ap.inv_spectrogram(postnet_out.T)
+            elif self.tts_config.model == "Tacotron2":
+                if self.wavernn:
+                    wav = self.wavernn.generate(torch.FloatTensor(postnet_out.T).unsqueeze(0).to("cuda" if self.use_cuda else "cpu"), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
+                else:
+                    wav = self.ap.inv_mel_spectrogram(postnet_out.T)
             wavs += list(wav)
             wavs += [0] * 10000
+
+        out = io.BytesIO()
         self.save_wav(wavs, out)
-        return out
\ No newline at end of file
+        return out
diff --git a/server/templates/index.html b/server/templates/index.html
index 1186db53..b7f83c30 100644
--- a/server/templates/index.html
+++ b/server/templates/index.html
@@ -56,11 +56,10 @@
"work-in-progress"
+
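
Not part of the patch: a minimal sketch of how the reworked Synthesizer is driven end to end, assuming the repo root is on PYTHONPATH, that load_config tolerates the // comments in conf.json (as server.py already relies on), and placeholder paths in place of the author's local ones.

    # Illustration only -- not part of this diff. Assumes the new conf.json
    # schema above (tts_*, wavernn_*, is_wavernn_batched) with your own paths.
    from utils.generic_utils import load_config
    from server.synthesizer import Synthesizer

    config = load_config("server/conf.json")  # parsed config with the new keys
    synthesizer = Synthesizer(config)          # loads TTS; WaveRNN too if wavernn_lib_path is set,
                                               # otherwise tts() falls back to Griffin-Lim

    out = synthesizer.tts("Hello world.")      # io.BytesIO holding the rendered WAV
    out.seek(0)                                # save_wav leaves the cursor at end-of-stream
    with open("sample.wav", "wb") as f:
        f.write(out.read())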