Update test_run in wavernn and wavegrad

This commit is contained in:
Eren Gölge 2022-02-20 11:36:27 +01:00
parent be3a03126a
commit 20a677c623
2 changed files with 9 additions and 7 deletions

View File

@ -270,12 +270,13 @@ class Wavegrad(BaseVocoder):
) -> None:
pass
def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict): # pylint: disable=unused-argument
def test(self, assets: Dict, test_loader:"DataLoader", outputs=None): # pylint: disable=unused-argument
# setup noise schedule and inference
ap = assets["audio_processor"]
noise_schedule = self.config["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
samples = test_loader.dataset.load_test_samples(1)
for sample in samples:
x = sample[0]
x = x[None, :, :].to(next(self.parameters()).device)
@ -307,12 +308,12 @@ class Wavegrad(BaseVocoder):
return {"input": m, "waveform": y}
def get_data_loader(
self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
):
ap = assets["audio_processor"]
dataset = WaveGradDataset(
ap=ap,
items=data_items,
items=samples,
seq_len=self.config.seq_len,
hop_len=ap.hop_length,
pad_short=self.config.pad_short,

View File

@ -568,12 +568,13 @@ class Wavernn(BaseVocoder):
return self.train_step(batch, criterion)
@torch.no_grad()
def test_run(
self, assets: Dict, samples: List[Dict], output: Dict # pylint: disable=unused-argument
def test(
self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, Dict]:
ap = assets["audio_processor"]
figures = {}
audios = {}
samples = test_loader.dataset.load_test_samples(1)
for idx, sample in enumerate(samples):
x = torch.FloatTensor(sample[0])
x = x.to(next(self.parameters()).device)
@ -600,14 +601,14 @@ class Wavernn(BaseVocoder):
config: Coqpit,
assets: Dict,
is_eval: True,
data_items: List,
samples: List,
verbose: bool,
num_gpus: int,
):
ap = assets["audio_processor"]
dataset = WaveRNNDataset(
ap=ap,
items=data_items,
items=samples,
seq_len=config.seq_len,
hop_len=ap.hop_length,
pad=config.model_args.pad,