From 9c86f1ac6858ac62ce57cba66c11686f8c2f558c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 30 Aug 2021 08:09:43 +0000
Subject: [PATCH] Fix usage of abstract class in vocoders

---
 TTS/vocoder/models/wavegrad.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 5dc878d7..2a76baa5 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -9,10 +9,10 @@ from torch.nn.utils import weight_norm
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from TTS.model import BaseModel
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
+from TTS.vocoder.base_vocoder import BaseVocoder
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
     )
 
 
-class Wavegrad(BaseModel):
+class Wavegrad(BaseVocoder):
     """🐸 🌊 WaveGrad 🌊 model.
 
     Paper - https://arxiv.org/abs/2009.00713
@@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
         loss = criterion(noise, noise_hat)
         return {"model_output": noise_hat}, {"loss": loss}
 
-    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def train_log(  # pylint: disable=no-self-use
+        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         return None, None
 
     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
         return self.train_step(batch, criterion)
 
-    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def eval_log(  # pylint: disable=no-self-use
+        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         return None, None
 
     def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
@@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
     def get_scheduler(self, optimizer):
         return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
 
-    def get_criterion(self):
+    @staticmethod
+    def get_criterion():
         return torch.nn.L1Loss()
 
     @staticmethod
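
Context for the patch above: it moves Wavegrad off the generic TTS.model.BaseModel onto the vocoder-specific abstract base TTS.vocoder.base_vocoder.BaseVocoder, and turns get_criterion into a @staticmethod since it uses no instance state. The snippet below is a minimal, self-contained sketch of that pattern; the FakeBaseVocoder/FakeWavegrad names and the abstract-method set are assumptions made for illustration, not the real BaseVocoder API.

# Minimal sketch of overriding an abstract method with a @staticmethod.
# FakeBaseVocoder is a hypothetical stand-in for TTS.vocoder.base_vocoder.BaseVocoder;
# its interface is assumed, since the real class is not shown in the diff.
from abc import ABC, abstractmethod

import torch


class FakeBaseVocoder(ABC):
    """Hypothetical abstract base, analogous to BaseVocoder."""

    @abstractmethod
    def get_criterion(self):
        """Return the training loss module."""


class FakeWavegrad(FakeBaseVocoder):
    """Stand-in for the patched Wavegrad class."""

    # Overriding the abstract method with a @staticmethod still satisfies the
    # ABC check: ABCMeta only requires a concrete attribute with the same name.
    @staticmethod
    def get_criterion():
        return torch.nn.L1Loss()


vocoder = FakeWavegrad()             # instantiation works; no abstract methods remain
criterion = vocoder.get_criterion()  # staticmethod is still callable on an instance
print(type(criterion).__name__)      # -> L1Loss

Because get_criterion neither reads nor writes instance state, making it a @staticmethod in the patch documents that fact and silences pylint's no-self-use warning without changing call sites.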