diff --git a/server/server.py b/server/server.py index 6af119bf..705937e2 100644 --- a/server/server.py +++ b/server/server.py @@ -18,9 +18,9 @@ def create_argparser(): parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.') parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.') parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') - parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') - parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.') - parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.') + parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.') + parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') @@ -29,28 +29,35 @@ def create_argparser(): synthesizer = None -embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar') -config_file = os.path.join(embedded_model_folder, 'config.json') +embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -# Default options with embedded model files -if os.path.isfile(checkpoint_file): - default_tts_checkpoint = checkpoint_file -else: - default_tts_checkpoint = None +embedded_tts_folder = os.path.join(embedded_models_folder, 'tts') +tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar') +tts_config_file = os.path.join(embedded_tts_folder, 'config.json') -if os.path.isfile(config_file): - default_tts_config = config_file -else: - default_tts_config = None +embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn') +wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar') +wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json') + +embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan') +pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl') +pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml') args = create_argparser().parse_args() -# If these were not specified in the CLI args, use default values -if not args.tts_checkpoint: - args.tts_checkpoint = default_tts_checkpoint -if not args.tts_config: - args.tts_config = default_tts_config +# If these were not specified in the CLI args, use default values with embedded model files +if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file): + args.tts_checkpoint = tts_checkpoint_file +if not args.tts_config and os.path.isfile(tts_config_file): + args.tts_config = tts_config_file +if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file): + args.wavernn_file = wavernn_checkpoint_file +if not args.wavernn_config and os.path.isfile(wavernn_config_file): + args.wavernn_config = wavernn_config_file +if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file): + args.pwgan_file = pwgan_checkpoint_file +if not args.pwgan_config and os.path.isfile(pwgan_config_file): + args.pwgan_config = pwgan_config_file synthesizer = Synthesizer(args) diff --git a/server/synthesizer.py b/server/synthesizer.py index 75fd4e76..fcdc8787 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -121,8 +121,9 @@ class Synthesizer(object): wav = np.array(wav) self.ap.save_wav(wav, path) - def split_into_sentences(self, text): - text = " " + text + " " + @staticmethod + def split_into_sentences(text): + text = " " + text + " " text = text.replace("\n", " ") text = re.sub(prefixes, "\\1", text) text = re.sub(websites, "\\1", text) @@ -149,15 +150,13 @@ class Synthesizer(object): text = text.replace("", ".") sentences = text.split("") sentences = sentences[:-1] - sentences = [s.strip() for s in sentences] + sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences return sentences def tts(self, text): wavs = [] sens = self.split_into_sentences(text) print(sens) - if not sens: - sens = [text+'.'] for sen in sens: # preprocess the given text inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda) @@ -168,9 +167,16 @@ class Synthesizer(object): postnet_output, decoder_output, _ = parse_outputs( postnet_output, decoder_output, alignments) + if self.pwgan: + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length) if self.wavernn: - postnet_output = postnet_output[0].data.cpu().numpy() - wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550) + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550) else: wav = inv_spectrogram(postnet_output, self.ap, self.tts_config) # trim silence diff --git a/setup.py b/setup.py index 63782800..f92dac8a 100644 --- a/setup.py +++ b/setup.py @@ -61,10 +61,11 @@ package_data = ['server/templates/*'] if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config: print('Embedding model in wheel file...') model_dir = os.path.join('server', 'model') - os.makedirs(model_dir, exist_ok=True) - embedded_checkpoint_path = os.path.join(model_dir, 'checkpoint.pth.tar') + tts_dir = os.path.join(model_dir, 'tts') + os.makedirs(tts_dir, exist_ok=True) + embedded_checkpoint_path = os.path.join(tts_dir, 'checkpoint.pth.tar') shutil.copy(args.checkpoint, embedded_checkpoint_path) - embedded_config_path = os.path.join(model_dir, 'config.json') + embedded_config_path = os.path.join(tts_dir, 'config.json') shutil.copy(args.model_config, embedded_config_path) package_data.extend([embedded_checkpoint_path, embedded_config_path]) diff --git a/tests/test_server_package.sh b/tests/test_server_package.sh index 01e42843..9fe5e8b1 100755 --- a/tests/test_server_package.sh +++ b/tests/test_server_package.sh @@ -11,7 +11,7 @@ source /tmp/venv/bin/activate pip install --quiet --upgrade pip setuptools wheel rm -f dist/*.whl -python setup.py bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json +python setup.py --quiet bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json pip install --quiet dist/TTS*.whl python -m TTS.server.server &