From 6c60c182b5ff7f12eca05130b0fb897f12e799ad Mon Sep 17 00:00:00 2001
From: erogol
Date: Wed, 8 Jul 2020 10:20:31 +0200
Subject: [PATCH] remove PWGAN support on server and use only native vocoder
 implementations. Reformatting; remove extra lines

---
 config.json           |  2 +-
 models/tacotron2.py   | 11 +++++------
 server/conf.json      |  4 +++-
 server/synthesizer.py | 42 ++++++++++++++----------------------------
 4 files changed, 23 insertions(+), 36 deletions(-)

diff --git a/config.json b/config.json
index 77a223dc..5aa95448 100644
--- a/config.json
+++ b/config.json
@@ -86,7 +86,7 @@
     "prenet_type": "bn",        // "original" or "bn".
     "prenet_dropout": false,    // enable/disable dropout at prenet.
 
-    // ATTENTION
+    // TACOTRON ATTENTION
     "attention_type": "original",  // 'original' or 'graves'
     "attention_heads": 4,          // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",   // softmax or sigmoid.
diff --git a/models/tacotron2.py b/models/tacotron2.py
index d9570f8d..4a22b7fa 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -1,5 +1,3 @@
-from math import sqrt
-
 import torch
 from torch import nn
 
@@ -65,10 +63,11 @@ class Tacotron2(TacotronAbstract):
             self._init_backward_decoder()
         # setup DDC
         if self.double_decoder_consistency:
-            self.coarse_decoder = Decoder(decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win,
-                                          attn_norm, prenet_type, prenet_dropout,
-                                          forward_attn, trans_agent, forward_attn_mask,
-                                          location_attn, attn_K, separate_stopnet, proj_speaker_dim)
+            self.coarse_decoder = Decoder(
+                decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
+                attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
+                trans_agent, forward_attn_mask, location_attn, attn_K,
+                separate_stopnet, proj_speaker_dim)
 
     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
diff --git a/server/conf.json b/server/conf.json
index c8861cd1..00045365 100644
--- a/server/conf.json
+++ b/server/conf.json
@@ -3,11 +3,13 @@
     "tts_file":"best_model.pth.tar", // tts checkpoint file
     "tts_config":"config.json", // tts config.json file
     "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
+    "vocoder_config":null,
+    "vocoder_file": null,
     "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
     "wavernn_path":null, // wavernn model root path
     "wavernn_file":null, // wavernn checkpoint file name
     "wavernn_config": null, // wavernn config file
-    "is_wavernn_batched":true, 
+    "is_wavernn_batched":true,
     "port": 5002,
     "use_cuda": true,
     "debug": true
diff --git a/server/synthesizer.py b/server/synthesizer.py
index d85bbebc..b18d73ac 100644
--- a/server/synthesizer.py
+++ b/server/synthesizer.py
@@ -29,21 +29,18 @@ websites = r"[.](com|net|org|io|gov)"
 
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
-        self.pwgan = None
+        self.vocoder_model = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config, self.config.use_cuda)
-        if self.config.vocoder_checkpoint:
+        if self.config.vocoder_file:
             self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
         if self.config.wavernn_lib_path:
             self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
                               self.config.wavernn_config, self.config.use_cuda)
-        if self.config.pwgan_file:
-            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
-                            self.config.pwgan_config, self.config.use_cuda)
 
     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
         # pylint: disable=global-statement
@@ -129,27 +126,6 @@ class Synthesizer(object):
             self.wavernn.cuda()
         self.wavernn.eval()
 
-    def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
-        if lib_path:
-            # set this if ParallelWaveGAN is not installed globally
-            sys.path.append(lib_path)
-        try:
-            #pylint: disable=import-outside-toplevel
-            from parallel_wavegan.models import ParallelWaveGANGenerator
-        except ImportError as e:
-            raise RuntimeError(f"cannot import parallel-wavegan, either install it or set its directory using the --pwgan_lib_path command line argument: {e}")
-        print(" > Loading PWGAN model ...")
-        print(" | > model config: ", model_config)
-        print(" | > model file: ", model_file)
-        with open(model_config) as f:
-            self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
-        self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
-        self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
-        self.pwgan.remove_weight_norm()
-        if use_cuda:
-            self.pwgan.cuda()
-        self.pwgan.eval()
-
     def save_wav(self, wav, path):
         # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
         wav = np.array(wav)
@@ -202,9 +178,9 @@ class Synthesizer(object):
         inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
         inputs = inputs.unsqueeze(0)
         # synthesize voice
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
-        # convert outputs to numpy
+        _, postnet_output, _, _ = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
         if self.vocoder_model:
+            # use native vocoder model
            vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
            wav = self.vocoder_model.inference(vocoder_input)
            if self.use_cuda:
@@ -213,6 +189,7 @@ class Synthesizer(object):
                wav = wav.numpy()
            wav = wav.flatten()
        elif self.wavernn:
+            # use 3rd party wavernn
            vocoder_input = None
            if self.tts_config.model == "Tacotron":
                vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
@@ -221,6 +198,15 @@ class Synthesizer(object):
            if self.use_cuda:
                vocoder_input.cuda()
            wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
+        else:
+            # use GL
+            if self.use_cuda:
+                postnet_output = postnet_output[0].cpu()
+            else:
+                postnet_output = postnet_output[0]
+            postnet_output = postnet_output.numpy()
+            wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
+
         # trim silence
         wav = trim_silence(wav, self.ap)
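
Note for reviewers: after these hunks, server/synthesizer.py picks between three waveform
paths: the native vocoder loaded via load_vocoder(), the third-party WaveRNN, and the new
Griffin-Lim fallback. The sketch below condenses that control flow into a single free
function for readability. It is illustrative only, not part of the patch: the function
name and signature are mine, inv_spectrogram and trim_silence stand for the synthesis
helpers the file already imports, and the WaveRNN input shaping and CUDA transfers are
elided.

    def waveform_from_postnet(postnet_output, vocoder_model, wavernn, ap, tts_config,
                              inv_spectrogram, trim_silence, is_wavernn_batched=True):
        """Illustrative condensation of the branches above, not the literal file."""
        if vocoder_model is not None:
            # native vocoder: postnet mel [1, T, C] -> [1, C, T]
            vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
            wav = vocoder_model.inference(vocoder_input).cpu().numpy().flatten()
        elif wavernn is not None:
            # third-party WaveRNN (Tacotron linear-to-mel conversion elided)
            vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
            wav = wavernn.generate(vocoder_input, batched=is_wavernn_batched,
                                   target=11000, overlap=550)
        else:
            # Griffin-Lim fallback when no neural vocoder is configured
            wav = inv_spectrogram(postnet_output[0].cpu().numpy(), ap, tts_config)
        return trim_silence(wav, ap)

With vocoder_file and vocoder_config left at null in server/conf.json and no WaveRNN
configured, synthesis goes through the Griffin-Lim branch added in the last hunk.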