mirror of https://github.com/coqui-ai/TTS.git
Update test_run in wavernn and wavegrad
This commit is contained in:
parent
be3a03126a
commit
20a677c623
|
@ -270,12 +270,13 @@ class Wavegrad(BaseVocoder):
|
|||
) -> None:
|
||||
pass
|
||||
|
||||
def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict): # pylint: disable=unused-argument
|
||||
def test(self, assets: Dict, test_loader:"DataLoader", outputs=None): # pylint: disable=unused-argument
|
||||
# setup noise schedule and inference
|
||||
ap = assets["audio_processor"]
|
||||
noise_schedule = self.config["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for sample in samples:
|
||||
x = sample[0]
|
||||
x = x[None, :, :].to(next(self.parameters()).device)
|
||||
|
@ -307,12 +308,12 @@ class Wavegrad(BaseVocoder):
|
|||
return {"input": m, "waveform": y}
|
||||
|
||||
def get_data_loader(
|
||||
self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
|
||||
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
|
||||
):
|
||||
ap = assets["audio_processor"]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=data_items,
|
||||
items=samples,
|
||||
seq_len=self.config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=self.config.pad_short,
|
||||
|
|
|
@ -568,12 +568,13 @@ class Wavernn(BaseVocoder):
|
|||
return self.train_step(batch, criterion)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(
|
||||
self, assets: Dict, samples: List[Dict], output: Dict # pylint: disable=unused-argument
|
||||
def test(
|
||||
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
ap = assets["audio_processor"]
|
||||
figures = {}
|
||||
audios = {}
|
||||
samples = test_loader.dataset.load_test_samples(1)
|
||||
for idx, sample in enumerate(samples):
|
||||
x = torch.FloatTensor(sample[0])
|
||||
x = x.to(next(self.parameters()).device)
|
||||
|
@ -600,14 +601,14 @@ class Wavernn(BaseVocoder):
|
|||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: True,
|
||||
data_items: List,
|
||||
samples: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
):
|
||||
ap = assets["audio_processor"]
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=data_items,
|
||||
items=samples,
|
||||
seq_len=config.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=config.model_args.pad,
|
||||
|
|
Loading…
Reference in New Issue