From 02e6d0538272f589d6c3c290b81575b7bd866991 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 13 Feb 2020 15:49:46 +0100 Subject: [PATCH 1/4] Use PWGAN if available in Synthesizer.tts --- server/synthesizer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/server/synthesizer.py b/server/synthesizer.py index 75fd4e76..455bd332 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -168,9 +168,16 @@ class Synthesizer(object): postnet_output, decoder_output, _ = parse_outputs( postnet_output, decoder_output, alignments) + if self.pwgan: + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length) if self.wavernn: - postnet_output = postnet_output[0].data.cpu().numpy() - wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550) + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550) else: wav = inv_spectrogram(postnet_output, self.ap, self.tts_config) # trim silence From b539ffafc0a0c185438bab262719f4259b6c8f9f Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 13 Feb 2020 15:54:30 +0100 Subject: [PATCH 2/4] Load PWGAN/WaveRNN embedded files if present --- server/server.py | 47 +++++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/server/server.py b/server/server.py index 6af119bf..705937e2 100644 --- a/server/server.py +++ b/server/server.py @@ -18,9 +18,9 @@ def create_argparser(): parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.') parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.') 
parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') - parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') - parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.') - parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.') + parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.') + parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') @@ -29,28 +29,35 @@ def create_argparser(): synthesizer = None -embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar') -config_file = os.path.join(embedded_model_folder, 'config.json') +embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -# Default options with embedded model files -if os.path.isfile(checkpoint_file): - default_tts_checkpoint = checkpoint_file -else: - default_tts_checkpoint = None +embedded_tts_folder = os.path.join(embedded_models_folder, 'tts') +tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar') +tts_config_file = os.path.join(embedded_tts_folder, 'config.json') -if os.path.isfile(config_file): - 
default_tts_config = config_file -else: - default_tts_config = None +embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn') +wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar') +wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json') + +embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan') +pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl') +pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml') args = create_argparser().parse_args() -# If these were not specified in the CLI args, use default values -if not args.tts_checkpoint: - args.tts_checkpoint = default_tts_checkpoint -if not args.tts_config: - args.tts_config = default_tts_config +# If these were not specified in the CLI args, use default values with embedded model files +if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file): + args.tts_checkpoint = tts_checkpoint_file +if not args.tts_config and os.path.isfile(tts_config_file): + args.tts_config = tts_config_file +if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file): + args.wavernn_file = wavernn_checkpoint_file +if not args.wavernn_config and os.path.isfile(wavernn_config_file): + args.wavernn_config = wavernn_config_file +if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file): + args.pwgan_file = pwgan_checkpoint_file +if not args.pwgan_config and os.path.isfile(pwgan_config_file): + args.pwgan_config = pwgan_config_file synthesizer = Synthesizer(args) From 995eb1bf074caae257a87f5ef54ae5f63617b227 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 13 Feb 2020 16:03:30 +0100 Subject: [PATCH 3/4] Fix bug where sometimes the second sentence disappears if it doesn't end with punctuation --- server/synthesizer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/server/synthesizer.py b/server/synthesizer.py index 455bd332..1082b73a 100644 --- a/server/synthesizer.py 
+++ b/server/synthesizer.py @@ -122,7 +122,7 @@ class Synthesizer(object): self.ap.save_wav(wav, path) def split_into_sentences(self, text): - text = " " + text + " " + text = " " + text + " <stop>" text = text.replace("\n", " ") text = re.sub(prefixes, "\\1<prd>", text) text = re.sub(websites, "<prd>\\1", text) @@ -149,15 +149,13 @@ class Synthesizer(object): text = text.replace("<prd>", ".") sentences = text.split("<stop>") sentences = sentences[:-1] - sentences = [s.strip() for s in sentences] + sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences return sentences def tts(self, text): wavs = [] sens = self.split_into_sentences(text) print(sens) - if not sens: - sens = [text+'.'] for sen in sens: # preprocess the given text inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda) From ffd00ce295e8b68e59dccda99bc467823a62940d Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 13 Feb 2020 17:30:41 +0100 Subject: [PATCH 4/4] Fix linter and server package test --- server/synthesizer.py | 3 ++- setup.py | 7 ++++--- tests/test_server_package.sh | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/server/synthesizer.py b/server/synthesizer.py index 1082b73a..fcdc8787 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -121,7 +121,8 @@ class Synthesizer(object): wav = np.array(wav) self.ap.save_wav(wav, path) - def split_into_sentences(self, text): + @staticmethod + def split_into_sentences(text): text = " " + text + " <stop>" text = text.replace("\n", " ") text = re.sub(prefixes, "\\1<prd>", text) diff --git a/setup.py b/setup.py index 63782800..f92dac8a 100644 --- a/setup.py +++ b/setup.py @@ -61,10 +61,11 @@ package_data = ['server/templates/*'] if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config: print('Embedding model in wheel file...') model_dir = os.path.join('server', 'model') - os.makedirs(model_dir, exist_ok=True) - embedded_checkpoint_path = os.path.join(model_dir, 'checkpoint.pth.tar') + tts_dir =
os.path.join(model_dir, 'tts') + os.makedirs(tts_dir, exist_ok=True) + embedded_checkpoint_path = os.path.join(tts_dir, 'checkpoint.pth.tar') shutil.copy(args.checkpoint, embedded_checkpoint_path) - embedded_config_path = os.path.join(model_dir, 'config.json') + embedded_config_path = os.path.join(tts_dir, 'config.json') shutil.copy(args.model_config, embedded_config_path) package_data.extend([embedded_checkpoint_path, embedded_config_path]) diff --git a/tests/test_server_package.sh b/tests/test_server_package.sh index 01e42843..9fe5e8b1 100755 --- a/tests/test_server_package.sh +++ b/tests/test_server_package.sh @@ -11,7 +11,7 @@ source /tmp/venv/bin/activate pip install --quiet --upgrade pip setuptools wheel rm -f dist/*.whl -python setup.py bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json +python setup.py --quiet bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json pip install --quiet dist/TTS*.whl python -m TTS.server.server &