Fix usage of abstract class in vocoders

This commit is contained in:
Eren Gölge 2021-08-30 08:09:43 +00:00
parent 18da8f5dbd
commit 9c86f1ac68
1 changed file with 10 additions and 5 deletions

View File

@@ -9,10 +9,10 @@ from torch.nn.utils import weight_norm
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from TTS.model import BaseModel
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
+from TTS.vocoder.base_vocoder import BaseVocoder
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
 )
-class Wavegrad(BaseModel):
+class Wavegrad(BaseVocoder):
     """🐸 🌊 WaveGrad 🌊 model.
     Paper - https://arxiv.org/abs/2009.00713
@@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
         loss = criterion(noise, noise_hat)
         return {"model_output": noise_hat}, {"loss": loss}

-    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def train_log(  # pylint: disable=no-self-use
+        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         return None, None

     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
         return self.train_step(batch, criterion)

-    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def eval_log(  # pylint: disable=no-self-use
+        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         return None, None

     def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
@@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
     def get_scheduler(self, optimizer):
         return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

-    def get_criterion(self):
+    @staticmethod
+    def get_criterion():
         return torch.nn.L1Loss()

     @staticmethod