Changes for review

This commit is contained in:
WeberJulian 2021-07-15 11:38:45 +02:00
parent c79a82ed07
commit 25832eb97b
4 changed files with 9 additions and 13 deletions

View File

@@ -769,9 +769,9 @@ class Trainer: @@ -769,9 +769,9 @@ class Trainer:
return None # TODO: Fix inference on WaveGrad return None # TODO: Fix inference on WaveGrad
if hasattr(self.eval_loader.dataset, "load_test_samples"): if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1) samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda) figures, audios = self.model.test_run(self.ap, samples, None)
else: else:
figures, audios = self.model.test_run(self.ap, self.use_cuda) figures, audios = self.model.test_run(self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures) self.tb_logger.tb_test_figures(self.total_steps_done, figures)
return None return None

View File

@@ -200,7 +200,7 @@ class BaseTTS(BaseModel): @@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
) )
return loader return loader
def test_run(self, ap, use_cuda) -> Tuple[Dict, Dict]: def test_run(self, ap) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`. """Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour. You can override this for a different behaviour.
@@ -218,7 +218,7 @@ class BaseTTS(BaseModel): @@ -218,7 +218,7 @@ class BaseTTS(BaseModel):
self, self,
sen, sen,
self.config, self.config,
use_cuda, "cuda" in str(next(self.parameters()).device),
ap, ap,
speaker_id=aux_inputs["speaker_id"], speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"], d_vector=aux_inputs["d_vector"],

View File

@@ -261,9 +261,7 @@ class Wavegrad(BaseModel): @@ -261,9 +261,7 @@ class Wavegrad(BaseModel):
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
return None, None return None, None
def test_run( def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument
self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda
): # pylint: disable=unused-argument
# setup noise schedule and inference # setup noise schedule and inference
noise_schedule = self.config["test_noise_schedule"] noise_schedule = self.config["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
@@ -271,8 +269,7 @@ class Wavegrad(BaseModel): @@ -271,8 +269,7 @@ class Wavegrad(BaseModel):
for sample in samples: for sample in samples:
sample = self.format_batch(sample) sample = self.format_batch(sample)
x = sample["input"] x = sample["input"]
if use_cuda: x = x.to(next(self.parameters()).device)
x = x.cuda()
y = sample["waveform"] y = sample["waveform"]
# compute voice # compute voice
y_pred = self.inference(x) y_pred = self.inference(x)

View File

@@ -322,7 +322,7 @@ class Wavernn(BaseVocoder): @@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
with torch.no_grad(): with torch.no_grad():
if isinstance(mels, np.ndarray): if isinstance(mels, np.ndarray):
mels = torch.FloatTensor(mels) mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
if mels.ndim == 2: if mels.ndim == 2:
mels = mels.unsqueeze(0) mels = mels.unsqueeze(0)
@@ -571,14 +571,13 @@ class Wavernn(BaseVocoder): @@ -571,14 +571,13 @@ class Wavernn(BaseVocoder):
@torch.no_grad() @torch.no_grad()
def test_run( def test_run(
self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda # pylint: disable=unused-argument self, ap: AudioProcessor, samples: List[Dict], output: Dict # pylint: disable=unused-argument
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
figures = {} figures = {}
audios = {} audios = {}
for idx, sample in enumerate(samples): for idx, sample in enumerate(samples):
x = torch.FloatTensor(sample[0]) x = torch.FloatTensor(sample[0])
if use_cuda: x = x.to(next(self.parameters()).device)
x = x.cuda()
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples) y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
x_hat = ap.melspectrogram(y_hat) x_hat = ap.melspectrogram(y_hat)
figures.update( figures.update(