From 5094499eba440efd41031ad8d7739c4b49c6045b Mon Sep 17 00:00:00 2001 From: vanIvan Date: Tue, 26 Jul 2022 16:05:11 +0300 Subject: [PATCH] Fix & update WaveRNN vocoder model (#1749) * Fixes KeyError bug. Adding logging to dashboard. * Make pep8 compliant * Make style compliant * Still fixing style --- TTS/vocoder/models/wavernn.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 6686db45..e0a25e32 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -233,6 +233,7 @@ class Wavernn(BaseVocoder): else: raise RuntimeError("Unknown model mode value - ", self.args.mode) + self.ap = AudioProcessor(**config.audio.to_dict()) self.aux_dims = self.args.res_out_dims // 4 if self.args.use_upsample_net: @@ -571,7 +572,7 @@ class Wavernn(BaseVocoder): def test( self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument ) -> Tuple[Dict, Dict]: - ap = assets["audio_processor"] + ap = self.ap figures = {} audios = {} samples = test_loader.dataset.load_test_samples(1) @@ -587,8 +588,16 @@ class Wavernn(BaseVocoder): } ) audios.update({f"test_{idx}/audio": y_hat}) + # audios.update({f"real_{idx}/audio": y_hat}) return figures, audios + def test_log( + self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + figures, audios = outputs + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + @staticmethod def format_batch(batch: Dict) -> Dict: waveform = batch[0] @@ -605,7 +614,7 @@ class Wavernn(BaseVocoder): verbose: bool, num_gpus: int, ): - ap = assets["audio_processor"] + ap = self.ap dataset = WaveRNNDataset( ap=ap, items=samples,