Update test_run in wavernn and wavegrad

Eren Gölge 2022-02-20 11:36:27 +01:00
parent be3a03126a
commit 20a677c623
2 changed files with 9 additions and 7 deletions

@@ -270,12 +270,13 @@ class Wavegrad(BaseVocoder):
     ) -> None:
         pass
 
-    def test_run(self, assets: Dict, samples: List[Dict], outputs: Dict):  # pylint: disable=unused-argument
+    def test(self, assets: Dict, test_loader: "DataLoader", outputs=None):  # pylint: disable=unused-argument
         # setup noise schedule and inference
         ap = assets["audio_processor"]
         noise_schedule = self.config["test_noise_schedule"]
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
+        samples = test_loader.dataset.load_test_samples(1)
         for sample in samples:
             x = sample[0]
             x = x[None, :, :].to(next(self.parameters()).device)
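
The hunk above leaves the test-time noise-schedule setup unchanged: a linear ramp of betas built with np.linspace and handed to compute_noise_level. Below is a small self-contained sketch of that setup; the schedule values are illustrative assumptions (the commit does not show the config), and the cumulative noise level is the usual DDPM-style derivation, not necessarily the exact body of compute_noise_level.

import numpy as np

# Illustrative test_noise_schedule values; not read from this commit's config.
noise_schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 50}
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])

# Roughly what a DDPM-style compute_noise_level derives from the betas:
alphas = 1.0 - betas
alpha_hat = np.cumprod(alphas)      # cumulative signal retention per step
noise_level = np.sqrt(alpha_hat)    # scale of the clean signal at each step

print(betas.shape, noise_level[0], noise_level[-1])  # (50,), near 1.0, then decreasing
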
@@ -307,12 +308,12 @@ class Wavegrad(BaseVocoder):
         return {"input": m, "waveform": y}
 
     def get_data_loader(
-        self, config: Coqpit, assets: Dict, is_eval: True, data_items: List, verbose: bool, num_gpus: int
+        self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int
     ):
         ap = assets["audio_processor"]
         dataset = WaveGradDataset(
             ap=ap,
-            items=data_items,
+            items=samples,
             seq_len=self.config.seq_len,
             hop_len=ap.hop_length,
             pad_short=self.config.pad_short,

@@ -568,12 +568,13 @@ class Wavernn(BaseVocoder):
         return self.train_step(batch, criterion)
 
     @torch.no_grad()
-    def test_run(
-        self, assets: Dict, samples: List[Dict], output: Dict  # pylint: disable=unused-argument
+    def test(
+        self, assets: Dict, test_loader: "DataLoader", output: Dict  # pylint: disable=unused-argument
     ) -> Tuple[Dict, Dict]:
         ap = assets["audio_processor"]
         figures = {}
         audios = {}
+        samples = test_loader.dataset.load_test_samples(1)
         for idx, sample in enumerate(samples):
             x = torch.FloatTensor(sample[0])
             x = x.to(next(self.parameters()).device)
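
Both test methods move inputs with x.to(next(self.parameters()).device). A minimal, self-contained illustration of that idiom; the Linear layer is only a stand-in for the vocoder, not code from this commit.

import torch

model = torch.nn.Linear(4, 2)                  # stand-in module
x = torch.FloatTensor([[1.0, 2.0, 3.0, 4.0]])
x = x.to(next(model.parameters()).device)      # follow whatever device the weights live on
print(model(x).shape)                          # torch.Size([1, 2])
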
@@ -600,14 +601,14 @@ class Wavernn(BaseVocoder):
         config: Coqpit,
         assets: Dict,
         is_eval: True,
-        data_items: List,
+        samples: List,
         verbose: bool,
         num_gpus: int,
     ):
         ap = assets["audio_processor"]
         dataset = WaveRNNDataset(
             ap=ap,
-            items=data_items,
+            items=samples,
             seq_len=config.seq_len,
             hop_len=ap.hop_length,
             pad=config.model_args.pad,
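
Taken together, both vocoders now share the same flow: the trainer builds a loader through get_data_loader(..., samples=...) and hands that loader to test(), which pulls a few full-length items back out via loader.dataset.load_test_samples(). Below is a self-contained toy mirror of that interface, assuming nothing about the real WaveGradDataset/WaveRNNDataset beyond what the diff shows; ToyDataset and ToyVocoder are hypothetical names used only for this sketch.

from typing import Dict, List

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __init__(self, items: List):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

    def load_test_samples(self, num_samples: int) -> List:
        # Hand back a few untrimmed items for inference-time testing.
        return self.items[:num_samples]


class ToyVocoder:
    def get_data_loader(self, config: Dict, assets: Dict, is_eval: bool, samples: List, verbose: bool, num_gpus: int):
        return DataLoader(ToyDataset(samples), batch_size=1)

    def test(self, assets: Dict, test_loader: DataLoader, outputs=None) -> Dict:
        samples = test_loader.dataset.load_test_samples(1)
        return {"num_test_samples": len(samples)}


model = ToyVocoder()
loader = model.get_data_loader({}, {}, is_eval=True, samples=[1, 2, 3], verbose=False, num_gpus=0)
print(model.test({}, loader))  # {'num_test_samples': 1}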