mirror of https://github.com/coqui-ai/TTS.git
Fix tests
This commit is contained in:
parent
32974dd6a9
commit
7d92b30946
|
@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
|
|||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.datasets import load_meta_data
|
||||
from TTS.tts.models import setup_model as setup_tts_model
|
||||
from TTS.vocoder.models.wavegrad import Wavegrad
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.callbacks import TrainerCallback
|
||||
|
@ -764,11 +765,13 @@ class Trainer:
|
|||
"""Run test and log the results. Test run must be defined by the model.
|
||||
Model must return figures and audios to be logged by the Tensorboard."""
|
||||
if hasattr(self.model, "test_run"):
|
||||
if hasattr(self.eval_loader, "load_test_samples"):
|
||||
samples = self.eval_loader.load_test_samples(1)
|
||||
figures, audios = self.model.test_run(samples)
|
||||
if isinstance(self.model, Wavegrad):
|
||||
return None # TODO: Fix inference on WaveGrad
|
||||
elif hasattr(self.eval_loader.dataset, "load_test_samples"):
|
||||
samples = self.eval_loader.dataset.load_test_samples(1)
|
||||
figures, audios = self.model.test_run(self.ap, samples, None, self.use_cuda)
|
||||
else:
|
||||
figures, audios = self.model.test_run(use_cuda=self.use_cuda, ap=self.ap)
|
||||
figures, audios = self.model.test_run(self.ap, self.use_cuda)
|
||||
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
|
||||
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ class BaseTTS(BaseModel):
|
|||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
pass
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": None}
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
"""Generic batch formatting for `TTSDataset`.
|
||||
|
@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
|
|||
)
|
||||
return loader
|
||||
|
||||
def test_run(self, use_cuda=True, ap=None) -> Tuple[Dict, Dict]:
|
||||
def test_run(self, ap, use_cuda) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
|
|
@ -261,13 +261,16 @@ class Wavegrad(BaseModel):
|
|||
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
|
||||
return None, None
|
||||
|
||||
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict): # pylint: disable=unused-argument
|
||||
def test_run(self, ap: AudioProcessor, samples: List[Dict], ouputs: Dict, use_cuda): # pylint: disable=unused-argument
|
||||
# setup noise schedule and inference
|
||||
noise_schedule = self.config["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
self.compute_noise_level(betas)
|
||||
for sample in samples:
|
||||
sample = self.format_batch(sample)
|
||||
x = sample["input"]
|
||||
if use_cuda:
|
||||
x = x.cuda()
|
||||
y = sample["waveform"]
|
||||
# compute voice
|
||||
y_pred = self.inference(x)
|
||||
|
|
|
@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
|
|||
|
||||
with torch.no_grad():
|
||||
if isinstance(mels, np.ndarray):
|
||||
mels = torch.FloatTensor(mels).type_as(mels)
|
||||
mels = torch.FloatTensor(mels)
|
||||
|
||||
if mels.ndim == 2:
|
||||
mels = mels.unsqueeze(0)
|
||||
|
@ -571,12 +571,14 @@ class Wavernn(BaseVocoder):
|
|||
|
||||
@torch.no_grad()
|
||||
def test_run(
|
||||
self, ap: AudioProcessor, samples: List[Dict], output: Dict # pylint: disable=unused-argument
|
||||
self, ap: AudioProcessor, samples: List[Dict], output: Dict, use_cuda # pylint: disable=unused-argument
|
||||
) -> Tuple[Dict, Dict]:
|
||||
figures = {}
|
||||
audios = {}
|
||||
for idx, sample in enumerate(samples):
|
||||
x = sample["input"]
|
||||
x = torch.FloatTensor(sample[0])
|
||||
if use_cuda:
|
||||
x = x.cuda()
|
||||
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
|
||||
x_hat = ap.melspectrogram(y_hat)
|
||||
figures.update(
|
||||
|
@ -585,7 +587,7 @@ class Wavernn(BaseVocoder):
|
|||
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
|
||||
}
|
||||
)
|
||||
audios.update({f"test_{idx}/audio", y_hat})
|
||||
audios.update({f"test_{idx}/audio": y_hat})
|
||||
return figures, audios
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue