Merge pull request #667 from coqui-ai/fix-test-sentences

Fix test runs and wavegrad test_run
This commit is contained in:
Eren Gölge 2021-07-16 14:11:40 +02:00 committed by GitHub
commit 9bb7f31f36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 47 additions and 24 deletions

View File

@ -764,11 +764,11 @@ class Trainer:
"""Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"):
if hasattr(self.eval_loader.load_test_samples):
samples = self.eval_loader.load_test_samples(1)
figures, audios = self.model.test_run(samples)
if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None)
else:
figures, audios = self.model.test_run()
figures, audios = self.model.test_run(self.ap)
self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, figures)
@ -790,7 +790,7 @@ class Trainer:
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
self.test_run()
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values

View File

@ -70,7 +70,7 @@ class BaseTTS(BaseModel):
def get_aux_input(self, **kwargs) -> Dict:
"""Prepare and return `aux_input` used by `forward()`"""
pass
return {"speaker_id": None, "style_wav": None, "d_vector": None}
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.
@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
)
return loader
def test_run(self) -> Tuple[Dict, Dict]:
def test_run(self, ap) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_aux_inputs()
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, model_outputs, _ = synthesis(
self.model,
self,
sen,
self.config,
self.use_cuda,
self.ap,
"cuda" in str(next(self.parameters()).device),
ap,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
).values()
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
return test_figures, test_audios

View File

@ -113,7 +113,7 @@ class GlowTTS(BaseTTS):
@staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
""" Compute and format the mode outputs with the given alignment map"""
"""Compute and format the mode outputs with the given alignment map"""
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']

View File

@ -2,6 +2,7 @@ import glob
import os
import random
from multiprocessing import Manager
from typing import List, Tuple
import numpy as np
import torch
@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
item = self.load_item(idx)
return item
def load_test_samples(self, num_samples):
def load_test_samples(self, num_samples: int) -> List[Tuple]:
"""Return test samples.
Args:
num_samples (int): Number of samples to return.
Returns:
List[Tuple]: melspectorgram and audio.
Shapes:
- melspectrogram (Tensor): :math:`[C, T]`
- audio (Tensor): :math:`[T_audio]`
"""
samples = []
return_segments = self.return_segments
self.return_segments = False

View File

@ -31,7 +31,7 @@ def setup_model(config: Coqpit):
def setup_generator(c):
""" TODO: use config object as arguments"""
"""TODO: use config object as arguments"""
print(" > Generator Model: {}".format(c.generator_model))
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
MyModel = getattr(MyModel, to_camel(c.generator_model))
@ -94,7 +94,7 @@ def setup_generator(c):
def setup_discriminator(c):
""" TODO: use config objekt as arguments"""
"""TODO: use config objekt as arguments"""
print(" > Discriminator Model: {}".format(c.discriminator_model))
if "parallel_wavegan" in c.discriminator_model:
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")

View File

@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
@torch.no_grad()
def inference(self, x, y_n=None):
"""x: B x D X T"""
"""
Shapes:
x: :math:`[B, C , T]`
y_n: :math:`[B, 1, T]`
"""
if y_n is None:
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
else:
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
y_n = y_n.type_as(x)
sqrt_alpha_hat = self.noise_level.to(x)
for n in range(len(self.alpha) - 1, -1, -1):
y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
@ -267,8 +272,10 @@ class Wavegrad(BaseModel):
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
self.compute_noise_level(betas)
for sample in samples:
x = sample["input"]
y = sample["waveform"]
x = sample[0]
x = x[None, :, :].to(next(self.parameters()).device)
y = sample[1]
y = y[None, :]
# compute voice
y_pred = self.inference(x)
# compute spectrograms

View File

@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
with torch.no_grad():
if isinstance(mels, np.ndarray):
mels = torch.FloatTensor(mels).type_as(mels)
mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
if mels.ndim == 2:
mels = mels.unsqueeze(0)
@ -576,7 +576,8 @@ class Wavernn(BaseVocoder):
figures = {}
audios = {}
for idx, sample in enumerate(samples):
x = sample["input"]
x = torch.FloatTensor(sample[0])
x = x.to(next(self.parameters()).device)
y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
x_hat = ap.melspectrogram(y_hat)
figures.update(
@ -585,7 +586,7 @@ class Wavernn(BaseVocoder):
f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
}
)
audios.update({f"test_{idx}/audio", y_hat})
audios.update({f"test_{idx}/audio": y_hat})
return figures, audios
@staticmethod

View File

@ -6,6 +6,7 @@ from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
def run_test_train():
command = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
@ -17,6 +18,7 @@ def run_test_train():
)
run_cli(command)
config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")