Merge pull request #344 from mozilla/server-pwgan

Adapting server to ParallelWaveGAN
This commit is contained in:
Eren Gölge 2020-02-13 15:46:04 +01:00 committed by GitHub
commit 2a78725b68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 20 deletions

View File

@ -6,6 +6,10 @@ Instructions below are based on a Ubuntu 18.04 machine, but it should be simple
#### Development server: #### Development server:
##### Using server.py
If you have the environment set already for TTS, then you can directly call ```setup.py```.
##### Using .whl
1. apt-get install -y espeak libsndfile1 python3-venv 1. apt-get install -y espeak libsndfile1 python3-venv
2. python3 -m venv /tmp/venv 2. python3 -m venv /tmp/venv
3. source /tmp/venv/bin/activate 3. source /tmp/venv/bin/activate

View File

@ -14,10 +14,13 @@ def create_argparser():
parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file') parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
parser.add_argument('--tts_config', type=str, help='path to TTS config.json file') parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model') parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.') parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.') parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.')
parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.')
parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')

View File

@ -1,17 +1,17 @@
import io import io
import os import re
import sys
import numpy as np import numpy as np
import torch import torch
import sys import yaml
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config, setup_model from TTS.utils.generic_utils import load_config, setup_model
from TTS.utils.text import phonemes, symbols
from TTS.utils.speakers import load_speaker_mapping from TTS.utils.speakers import load_speaker_mapping
from TTS.utils.synthesis import * from TTS.utils.synthesis import *
from TTS.utils.text import phonemes, symbols
import re
alphabets = r"([A-Za-z])" alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]" prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)" suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@ -23,6 +23,7 @@ websites = r"[.](com|net|org|io|gov)"
class Synthesizer(object): class Synthesizer(object):
def __init__(self, config): def __init__(self, config):
self.wavernn = None self.wavernn = None
self.pwgan = None
self.config = config self.config = config
self.use_cuda = self.config.use_cuda self.use_cuda = self.config.use_cuda
if self.use_cuda: if self.use_cuda:
@ -30,9 +31,11 @@ class Synthesizer(object):
self.load_tts(self.config.tts_checkpoint, self.config.tts_config, self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
self.config.use_cuda) self.config.use_cuda)
if self.config.wavernn_lib_path: if self.config.wavernn_lib_path:
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path, self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
self.config.wavernn_file, self.config.wavernn_config, self.config.wavernn_config, self.config.use_cuda)
self.config.use_cuda) if self.config.pwgan_lib_path:
self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
self.config.pwgan_config, self.config.use_cuda)
def load_tts(self, tts_checkpoint, tts_config, use_cuda): def load_tts(self, tts_checkpoint, tts_config, use_cuda):
print(" > Loading TTS model ...") print(" > Loading TTS model ...")
@ -45,9 +48,9 @@ class Synthesizer(object):
self.input_size = len(phonemes) self.input_size = len(phonemes)
else: else:
self.input_size = len(symbols) self.input_size = len(symbols)
# load speakers # TODO: fix this for multi-speaker model - load speakers
if self.config.tts_speakers is not None: if self.config.tts_speakers is not None:
self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers)) self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
num_speakers = len(self.tts_speakers) num_speakers = len(self.tts_speakers)
else: else:
num_speakers = 0 num_speakers = 0
@ -63,16 +66,17 @@ class Synthesizer(object):
if 'r' in cp: if 'r' in cp:
self.tts_model.decoder.set_r(cp['r']) self.tts_model.decoder.set_r(cp['r'])
def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda): def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
# TODO: set a function in wavernn code base for model setup and call it here. # TODO: set a function in wavernn code base for model setup and call it here.
sys.path.append(lib_path) # set this if TTS is not installed globally sys.path.append(lib_path) # set this if WaveRNN is not installed globally
#pylint: disable=import-outside-toplevel
from WaveRNN.models.wavernn import Model from WaveRNN.models.wavernn import Model
wavernn_config = os.path.join(model_path, model_config)
model_file = os.path.join(model_path, model_file)
print(" > Loading WaveRNN model ...") print(" > Loading WaveRNN model ...")
print(" | > model config: ", wavernn_config) print(" | > model config: ", model_config)
print(" | > model file: ", model_file) print(" | > model file: ", model_file)
self.wavernn_config = load_config(wavernn_config) self.wavernn_config = load_config(model_config)
# This is the default architecture we use for our models.
# You might need to update it
self.wavernn = Model( self.wavernn = Model(
rnn_dims=512, rnn_dims=512,
fc_dims=512, fc_dims=512,
@ -91,11 +95,27 @@ class Synthesizer(object):
).cuda() ).cuda()
check = torch.load(model_file) check = torch.load(model_file)
self.wavernn.load_state_dict(check['model']) self.wavernn.load_state_dict(check['model'], map_location="cpu")
if use_cuda: if use_cuda:
self.wavernn.cuda() self.wavernn.cuda()
self.wavernn.eval() self.wavernn.eval()
def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
sys.path.append(lib_path) # set this if ParallelWaveGAN is not installed globally
#pylint: disable=import-outside-toplevel
from parallel_wavegan.models import ParallelWaveGANGenerator
print(" > Loading PWGAN model ...")
print(" | > model config: ", model_config)
print(" | > model file: ", model_file)
with open(model_config) as f:
self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
self.pwgan.remove_weight_norm()
if use_cuda:
self.pwgan.cuda()
self.pwgan.eval()
def save_wav(self, wav, path): def save_wav(self, wav, path):
# wav *= 32767 / max(1e-8, np.max(np.abs(wav))) # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
wav = np.array(wav) wav = np.array(wav)

View File

@ -3,9 +3,11 @@
"tts_config":"dummy_model_config.json", // tts config.json file "tts_config":"dummy_model_config.json", // tts config.json file
"tts_speakers": null, // json file listing speaker ids. null if no speaker embedding. "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
"wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis. "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
"wavernn_path": null, // wavernn model root path
"wavernn_file": null, // wavernn checkpoint file name "wavernn_file": null, // wavernn checkpoint file name
"wavernn_config": null, // wavernn config file "wavernn_config": null, // wavernn config file
"pwgan_lib_path": null,
"pwgan_file": null,
"pwgan_config": null,
"is_wavernn_batched":true, "is_wavernn_batched":true,
"port": 5002, "port": 5002,
"use_cuda": false, "use_cuda": false,