remove PWGAN support on server and use only native vocoder implementations. Reformatting: remove extra lines

erogol 2020-07-08 10:20:31 +02:00
parent b1935c97fa
commit 6c60c182b5
4 changed files with 23 additions and 36 deletions

View File

@@ -86,7 +86,7 @@
     "prenet_type": "bn", // "original" or "bn".
     "prenet_dropout": false, // enable/disable dropout at prenet.
-    // ATTENTION
+    // TACOTRON ATTENTION
     "attention_type": "original", // 'original' or 'graves'
     "attention_heads": 4, // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid", // softmax or sigmoid.

View File

@@ -1,5 +1,3 @@
-from math import sqrt
 import torch
 from torch import nn
@@ -65,10 +63,11 @@ class Tacotron2(TacotronAbstract):
         self._init_backward_decoder()
         # setup DDC
         if self.double_decoder_consistency:
-            self.coarse_decoder = Decoder(decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win,
-                                          attn_norm, prenet_type, prenet_dropout,
-                                          forward_attn, trans_agent, forward_attn_mask,
-                                          location_attn, attn_K, separate_stopnet, proj_speaker_dim)
+            self.coarse_decoder = Decoder(
+                decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
+                attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
+                trans_agent, forward_attn_mask, location_attn, attn_K,
+                separate_stopnet, proj_speaker_dim)

     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):

View File

@@ -3,6 +3,8 @@
     "tts_file":"best_model.pth.tar", // tts checkpoint file
     "tts_config":"config.json", // tts config.json file
     "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
+    "vocoder_config":null,
+    "vocoder_file": null,
     "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
     "wavernn_path":null, // wavernn model root path
     "wavernn_file":null, // wavernn checkpoint file name

View File

@@ -29,21 +29,18 @@ websites = r"[.](com|net|org|io|gov)"
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
-        self.pwgan = None
+        self.vocoder_model = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not available on this machine."
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
-        if self.config.vocoder_checkpoint:
+        if self.config.vocoder_file:
             self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
         if self.config.wavernn_lib_path:
             self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
                               self.config.wavernn_config, self.config.use_cuda)
-        if self.config.pwgan_file:
-            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
-                            self.config.pwgan_config, self.config.use_cuda)

     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
         # pylint: disable=global-statement
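The load_vocoder method called above is not part of this diff. A plausible sketch of it, mirroring the removed load_pwgan below but going through the repo's native vocoder stack; the setup_generator/load_config helpers, their module paths, and the checkpoint layout are assumptions, not confirmed by this commit:

    def load_vocoder(self, model_file, model_config, use_cuda):
        # assumed helpers; module paths may differ between TTS versions
        from TTS.utils.io import load_config
        from TTS.vocoder.utils.generic_utils import setup_generator

        print(" > Loading vocoder model ...")
        self.vocoder_config = load_config(model_config)
        # instantiate the generator class named in the config, then load its weights
        self.vocoder_model = setup_generator(self.vocoder_config)
        self.vocoder_model.load_state_dict(
            torch.load(model_file, map_location="cpu")["model"])
        self.vocoder_model.remove_weight_norm()
        if use_cuda:
            self.vocoder_model.cuda()
        self.vocoder_model.eval()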
@@ -129,27 +126,6 @@ class Synthesizer(object):
             self.wavernn.cuda()
         self.wavernn.eval()

-    def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
-        if lib_path:
-            # set this if ParallelWaveGAN is not installed globally
-            sys.path.append(lib_path)
-        try:
-            #pylint: disable=import-outside-toplevel
-            from parallel_wavegan.models import ParallelWaveGANGenerator
-        except ImportError as e:
-            raise RuntimeError(f"cannot import parallel-wavegan, either install it or set its directory using the --pwgan_lib_path command line argument: {e}")
-        print(" > Loading PWGAN model ...")
-        print(" | > model config: ", model_config)
-        print(" | > model file: ", model_file)
-        with open(model_config) as f:
-            self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
-        self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
-        self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
-        self.pwgan.remove_weight_norm()
-        if use_cuda:
-            self.pwgan.cuda()
-        self.pwgan.eval()
-

     def save_wav(self, wav, path):
         # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
         wav = np.array(wav)
@@ -202,9 +178,9 @@ class Synthesizer(object):
         inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
         inputs = inputs.unsqueeze(0)
         # synthesize voice
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
-        # convert outputs to numpy
+        _, postnet_output, _, _ = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
         if self.vocoder_model:
+            # use native vocoder model
            vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
            wav = self.vocoder_model.inference(vocoder_input)
            if self.use_cuda:
@@ -213,6 +189,7 @@ class Synthesizer(object):
                 wav = wav.numpy()
             wav = wav.flatten()
         elif self.wavernn:
+            # use 3rd party wavernn
             vocoder_input = None
             if self.tts_config.model == "Tacotron":
                 vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
@@ -221,6 +198,15 @@ class Synthesizer(object):
             if self.use_cuda:
                 vocoder_input.cuda()
             wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
+        else:
+            # use GL
+            if self.use_cuda:
+                postnet_output = postnet_output[0].cpu()
+            else:
+                postnet_output = postnet_output[0]
+            postnet_output = postnet_output.numpy()
+            wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
+

         # trim silence
         wav = trim_silence(wav, self.ap)
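For reference, the inv_spectrogram helper used by the new GL branch dispatches on the model type, since Tacotron predicts linear spectrograms while Tacotron2 predicts mels. Roughly (a sketch, not the exact helper from the repo):

def inv_spectrogram(postnet_output, ap, CONFIG):
    # Tacotron emits linear spectrograms; Tacotron2 emits mel spectrograms
    if CONFIG.model.lower() in ["tacotron"]:
        wav = ap.inv_spectrogram(postnet_output.T)
    else:
        wav = ap.inv_melspectrogram(postnet_output.T)
    return wav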