diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 95aa3cd2..02c28c23 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -307,7 +307,9 @@ class Wavegrad(BaseVocoder): y = y.unsqueeze(1) return {"input": m, "waveform": y} - def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): + def get_data_loader( + self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int + ): ap = assets["audio_processor"] dataset = WaveGradDataset( ap=ap,