mirror of https://github.com/coqui-ai/TTS.git
Fix usage of abstract class in vocoders
This commit is contained in:
parent 18da8f5dbd
commit 9c86f1ac68
@@ -9,10 +9,10 @@ from torch.nn.utils import weight_norm
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

-from TTS.model import BaseModel
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_fsspec
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler
+from TTS.vocoder.base_vocoder import BaseVocoder
 from TTS.vocoder.datasets import WaveGradDataset
 from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -33,7 +33,7 @@ class WavegradArgs(Coqpit):
     )


-class Wavegrad(BaseModel):
+class Wavegrad(BaseVocoder):
     """🐸 🌊 WaveGrad 🌊 model.
     Paper - https://arxiv.org/abs/2009.00713

@@ -257,14 +257,18 @@ class Wavegrad(BaseModel):
         loss = criterion(noise, noise_hat)
         return {"model_output": noise_hat}, {"loss": loss}

-    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def train_log(  # pylint: disable=no-self-use
+        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         return None, None

     @torch.no_grad()
     def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
         return self.train_step(batch, criterion)

-    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+    def eval_log(  # pylint: disable=no-self-use
+        self, ap: AudioProcessor, batch: Dict, outputs: Dict  # pylint: disable=unused-argument
+    ) -> Tuple[Dict, np.ndarray]:
         return None, None

     def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict):  # pylint: disable=unused-argument
@@ -291,7 +295,8 @@ class Wavegrad(BaseModel):
     def get_scheduler(self, optimizer):
         return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

-    def get_criterion(self):
+    @staticmethod
+    def get_criterion():
         return torch.nn.L1Loss()

     @staticmethod
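
For context on the fix: a concrete vocoder such as Wavegrad is expected to match the method signatures of its abstract base class. Below is a minimal, illustrative sketch of what such a vocoder base class can look like; it is an assumption for illustration only, not the contents of TTS.vocoder.base_vocoder, whose actual interface may differ.

# Illustrative sketch of an abstract vocoder base class (hypothetical,
# not the real TTS.vocoder.base_vocoder.BaseVocoder).
from abc import ABC, abstractmethod
from typing import Dict, Tuple

import numpy as np
from torch import nn


class VocoderBaseSketch(ABC, nn.Module):
    """Subclasses (e.g. a Wavegrad-like model) override these hooks."""

    @abstractmethod
    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Run one training step and return (model outputs, losses)."""

    def train_log(self, ap, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        # Optional logging hook; models with nothing to log return (None, None),
        # which is exactly what the Wavegrad overrides above do.
        return None, None

    @abstractmethod
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Run one evaluation step; Wavegrad simply reuses train_step."""

    def eval_log(self, ap, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return None, None

    @staticmethod
    def get_criterion():
        # Declared as a staticmethod, which is why the Wavegrad override in the
        # last hunk drops `self` and adds the @staticmethod decorator.
        raise NotImplementedError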
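
A small usage sketch of the get_criterion change: as a @staticmethod, the L1 loss can be built from the class itself, before any Wavegrad instance exists. The import path below is assumed from the upstream repository layout and may differ.

# Assumed import path; adjust to the local TTS layout if needed.
import torch

from TTS.vocoder.models.wavegrad import Wavegrad

criterion = Wavegrad.get_criterion()  # no Wavegrad() instance needed
assert isinstance(criterion, torch.nn.L1Loss)

# Same objective train_step uses: L1 distance between noise and its prediction.
noise = torch.randn(4, 100)
noise_hat = torch.randn(4, 100)
loss = criterion(noise, noise_hat)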