Fix usage of abstract class in vocoders

This commit is contained in:
Eren Gölge 2021-08-30 08:09:43 +00:00
parent 18da8f5dbd
commit 9c86f1ac68
1 changed files with 10 additions and 5 deletions

View File

@ -9,10 +9,10 @@ from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
from TTS.vocoder.base_vocoder import BaseVocoder
from TTS.vocoder.datasets import WaveGradDataset
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
from TTS.vocoder.utils.generic_utils import plot_results
@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
)
class Wavegrad(BaseModel):
class Wavegrad(BaseVocoder):
"""🐸 🌊 WaveGrad 🌊 model.
Paper - https://arxiv.org/abs/2009.00713
@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
loss = criterion(noise, noise_hat)
return {"model_output": noise_hat}, {"loss": loss}
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def train_log( # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None
@torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
def eval_log( # pylint: disable=no-self-use
self, ap: AudioProcessor, batch: Dict, outputs: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, np.ndarray]:
return None, None
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument
@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
def get_scheduler(self, optimizer):
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
def get_criterion(self):
@staticmethod
def get_criterion():
return torch.nn.L1Loss()
@staticmethod