remove PWGAN support on server and use only native vocoder implementations. Reformatting, remove extra lines

This commit is contained in:
erogol 2020-07-08 10:20:31 +02:00
parent b1935c97fa
commit 6c60c182b5
4 changed files with 23 additions and 36 deletions

View File

@ -86,7 +86,7 @@
"prenet_type": "bn", // "original" or "bn".
"prenet_dropout": false, // enable/disable dropout at prenet.
// ATTENTION
// TACOTRON ATTENTION
"attention_type": "original", // 'original' or 'graves'
"attention_heads": 4, // number of attention heads (only for 'graves')
"attention_norm": "sigmoid", // softmax or sigmoid.

View File

@ -1,5 +1,3 @@
from math import sqrt
import torch
from torch import nn
@ -65,10 +63,11 @@ class Tacotron2(TacotronAbstract):
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self.coarse_decoder = Decoder(decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet, proj_speaker_dim)
self.coarse_decoder = Decoder(
decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet, proj_speaker_dim)
@staticmethod
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):

View File

@ -3,11 +3,13 @@
"tts_file":"best_model.pth.tar", // tts checkpoint file
"tts_config":"config.json", // tts config.json file
"tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
"vocoder_config":null,
"vocoder_file": null,
"wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
"wavernn_path":null, // wavernn model root path
"wavernn_file":null, // wavernn checkpoint file name
"wavernn_config": null, // wavernn config file
"is_wavernn_batched":true,
"is_wavernn_batched":true,
"port": 5002,
"use_cuda": true,
"debug": true

View File

@ -29,21 +29,18 @@ websites = r"[.](com|net|org|io|gov)"
class Synthesizer(object):
def __init__(self, config):
self.wavernn = None
self.pwgan = None
self.vocoder_model = None
self.config = config
self.use_cuda = self.config.use_cuda
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
self.config.use_cuda)
if self.config.vocoder_checkpoint:
if self.config.vocoder_file:
self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
if self.config.wavernn_lib_path:
self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
self.config.wavernn_config, self.config.use_cuda)
if self.config.pwgan_file:
self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
self.config.pwgan_config, self.config.use_cuda)
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
# pylint: disable=global-statement
@ -129,27 +126,6 @@ class Synthesizer(object):
self.wavernn.cuda()
self.wavernn.eval()
def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
# Load a ParallelWaveGAN generator and store it on `self.pwgan`.
#
# Args:
#   lib_path: optional directory appended to `sys.path` before importing
#     `parallel_wavegan`; skipped when falsy.
#   model_file: checkpoint path passed to `torch.load`.
#   model_config: path to the YAML config; the parsed result is kept on
#     `self.pwgan_config`.
#   use_cuda: when truthy, `.cuda()` is called on the loaded generator.
#
# Raises:
#   RuntimeError: if `parallel_wavegan` cannot be imported.
if lib_path:
# set this if ParallelWaveGAN is not installed globally
sys.path.append(lib_path)
try:
#pylint: disable=import-outside-toplevel
from parallel_wavegan.models import ParallelWaveGANGenerator
except ImportError as e:
# Surface the import failure as a RuntimeError with a hint on how to fix it.
raise RuntimeError(f"cannot import parallel-wavegan, either install it or set its directory using the --pwgan_lib_path command line argument: {e}")
print(" > Loading PWGAN model ...")
print(" | > model config: ", model_config)
print(" | > model file: ", model_file)
with open(model_config) as f:
# NOTE(review): yaml.Loader is PyYAML's full loader and can construct
# arbitrary Python objects -- only point this at trusted config files.
self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
# The generator's constructor arguments come from the "generator_params" config entry.
self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
# Checkpoint tensors are mapped to CPU on load; the generator weights are
# read from the nested ["model"]["generator"] entry of the checkpoint.
self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
# presumably strips weight-norm wrappers for inference -- confirm against parallel_wavegan
self.pwgan.remove_weight_norm()
if use_cuda:
self.pwgan.cuda()
# Switch the generator to evaluation mode.
self.pwgan.eval()
def save_wav(self, wav, path):
# wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
wav = np.array(wav)
@ -202,9 +178,9 @@ class Synthesizer(object):
inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
inputs = inputs.unsqueeze(0)
# synthesize voice
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
# convert outputs to numpy
_, postnet_output, _, _ = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
if self.vocoder_model:
# use native vocoder model
vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
wav = self.vocoder_model.inference(vocoder_input)
if self.use_cuda:
@ -213,6 +189,7 @@ class Synthesizer(object):
wav = wav.numpy()
wav = wav.flatten()
elif self.wavernn:
# use 3rd party wavernn
vocoder_input = None
if self.tts_config.model == "Tacotron":
vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
@ -221,6 +198,15 @@ class Synthesizer(object):
if self.use_cuda:
vocoder_input.cuda()
wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
else:
# use GL
if self.use_cuda:
postnet_output = postnet_output[0].cpu()
else:
postnet_output = postnet_output[0]
postnet_output = postnet_output.numpy()
wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
# trim silence
wav = trim_silence(wav, self.ap)