Merge pull request #667 from coqui-ai/fix-test-sentences

Fix test runs and wavegrad test_run
Eren Gölge 2021-07-16 14:11:40 +02:00 committed by GitHub
commit 9bb7f31f36
8 changed files with 47 additions and 24 deletions


@@ -764,11 +764,11 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
-            if hasattr(self.eval_loader.load_test_samples):
-                samples = self.eval_loader.load_test_samples(1)
-                figures, audios = self.model.test_run(samples)
+            if hasattr(self.eval_loader.dataset, "load_test_samples"):
+                samples = self.eval_loader.dataset.load_test_samples(1)
+                figures, audios = self.model.test_run(self.ap, samples, None)
             else:
-                figures, audios = self.model.test_run()
+                figures, audios = self.model.test_run(self.ap)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
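
Besides routing `load_test_samples` through `loader.dataset`, this fixes a latent crash: the old guard passed a single argument to `hasattr`, which raises `TypeError` the moment the branch is reached. A minimal sketch of the corrected dispatch; `DummyDataset` and `DummyLoader` are hypothetical stand-ins, not classes from the codebase:

    # Hypothetical stand-ins showing the corrected hasattr() dispatch.
    class DummyDataset:
        def load_test_samples(self, num_samples):
            return [("mel", "audio")] * num_samples

    class DummyLoader:  # mimics a torch DataLoader exposing `.dataset`
        dataset = DummyDataset()

    loader = DummyLoader()
    # Old code: hasattr(loader.load_test_samples) -> TypeError (hasattr takes 2 args).
    if hasattr(loader.dataset, "load_test_samples"):
        samples = loader.dataset.load_test_samples(1)  # query the dataset, not the loader
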
@@ -790,7 +790,7 @@ class Trainer:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
+            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                 self.test_run()
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
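
In distributed training `args.rank` is 0 on the main process and positive on workers, so the old `rank < 0` condition could never be true and the test run was silently skipped; `rank <= 0` runs it on the main process only. A toy illustration of the guard:

    # Illustrative values only: rank 0 is the main DDP process.
    for rank in (0, 1, 2):
        print(rank, rank <= 0)  # old guard `rank < 0` was False for every rank
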


@@ -70,7 +70,7 @@ class BaseTTS(BaseModel):
     def get_aux_input(self, **kwargs) -> Dict:
         """Prepare and return `aux_input` used by `forward()`"""
-        pass
+        return {"speaker_id": None, "style_wav": None, "d_vector": None}

     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
@@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
         )
         return loader

-    def test_run(self) -> Tuple[Dict, Dict]:
+    def test_run(self, ap) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.

         You can override this for a different behaviour.
@@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        aux_inputs = self._get_aux_inputs()
+        aux_inputs = self.get_aux_input()
         for idx, sen in enumerate(test_sentences):
             wav, alignment, model_outputs, _ = synthesis(
-                self.model,
+                self,
                 sen,
                 self.config,
-                self.use_cuda,
-                self.ap,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
@@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
             ).values()
             test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
         return test_figures, test_audios
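
`test_run` now receives the `AudioProcessor` as an argument instead of reading a `self.ap` attribute, and the CUDA flag is derived from the model's own parameters rather than a `self.use_cuda` field. A minimal, standalone sketch of that device probe, which works for any `torch.nn.Module` with at least one parameter:

    import torch

    model = torch.nn.Linear(4, 4)  # any module with parameters
    use_cuda = "cuda" in str(next(model.parameters()).device)
    print(use_cuda)  # False on CPU; True after model.cuda()
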


@@ -113,7 +113,7 @@ class GlowTTS(BaseTTS):
     @staticmethod
     def compute_outputs(attn, o_mean, o_log_scale, x_mask):
-        """ Compute and format the mode outputs with the given alignment map"""
+        """Compute and format the mode outputs with the given alignment map"""
         y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
             1, 2
         )  # [b, t', t], [b, t, d] -> [b, d, t']


@@ -2,6 +2,7 @@ import glob
 import os
 import random
 from multiprocessing import Manager
+from typing import List, Tuple

 import numpy as np
 import torch
@@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
         item = self.load_item(idx)
         return item

-    def load_test_samples(self, num_samples):
+    def load_test_samples(self, num_samples: int) -> List[Tuple]:
+        """Return test samples.
+
+        Args:
+            num_samples (int): Number of samples to return.
+
+        Returns:
+            List[Tuple]: melspectorgram and audio.
+
+        Shapes:
+            - melspectrogram (Tensor): :math:`[C, T]`
+            - audio (Tensor): :math:`[T_audio]`
+        """
         samples = []
         return_segments = self.return_segments
         self.return_segments = False
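
The new docstring pins down the contract the `Trainer` now relies on: each test sample is an unbatched `(melspectrogram, audio)` tuple. A hedged usage sketch with illustrative sizes (80 mel bands, 100 frames, hop length 256; none of these values come from the diff):

    import torch

    mel = torch.randn(80, 100)      # [C, T]
    audio = torch.randn(100 * 256)  # [T_audio] = T * hop_length
    samples = [(mel, audio)]        # shape of load_test_samples(1)'s return value
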


@@ -31,7 +31,7 @@ def setup_model(config: Coqpit):
 def setup_generator(c):
-    """ TODO: use config object as arguments"""
+    """TODO: use config object as arguments"""
     print(" > Generator Model: {}".format(c.generator_model))
     MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
@@ -94,7 +94,7 @@ def setup_generator(c):
 def setup_discriminator(c):
-    """ TODO: use config objekt as arguments"""
+    """TODO: use config objekt as arguments"""
     print(" > Discriminator Model: {}".format(c.discriminator_model))
     if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")


@@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
     @torch.no_grad()
     def inference(self, x, y_n=None):
-        """x: B x D X T"""
+        """
+        Shapes:
+            x: :math:`[B, C, T]`
+            y_n: :math:`[B, 1, T]`
+        """
         if y_n is None:
-            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
         else:
-            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
+        y_n = y_n.type_as(x)
         sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
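
Hoisting the conversion into a single `y_n = y_n.type_as(x)` call matches the noise tensor to the conditioning tensor's dtype and device in one place, whichever branch produced it. A standalone demonstration of `Tensor.type_as` (tensor sizes are illustrative):

    import torch

    x = torch.randn(2, 80, 10, dtype=torch.float32)     # conditioning spectrogram
    y_n = torch.randn(2, 1, 2560, dtype=torch.float64)  # initial noise
    y_n = y_n.type_as(x)  # now float32, on the same device as x
    print(y_n.dtype, y_n.device)
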
@@ -267,8 +272,10 @@ class Wavegrad(BaseModel):
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
         for sample in samples:
-            x = sample["input"]
-            y = sample["waveform"]
+            x = sample[0]
+            x = x[None, :, :].to(next(self.parameters()).device)
+            y = sample[1]
+            y = y[None, :]
             # compute voice
             y_pred = self.inference(x)
             # compute spectrograms
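
Because `load_test_samples` now yields tuples, the loop indexes positionally and adds the batch dimension itself before moving the mel to the model's device. A sketch of that reshaping under the shapes documented above; the `Linear` module is a stand-in for the Wavegrad model, used only for its parameter device:

    import torch

    model = torch.nn.Linear(1, 1)  # hypothetical stand-in module
    sample = (torch.randn(80, 100), torch.randn(25600))
    x = sample[0][None, :, :].to(next(model.parameters()).device)  # [1, C, T]
    y = sample[1][None, :]                                         # [1, T_audio]
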


@@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).type_as(mels)
+                mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
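
The replaced line called `.type_as(mels)` while `mels` on the right-hand side was still the numpy array, so the new tensor was typed against a non-tensor and the call failed; the fix moves the tensor to the model's device instead. A sketch of the corrected conversion, again with a hypothetical stand-in module:

    import numpy as np
    import torch

    model = torch.nn.Linear(1, 1)  # stand-in, used only for its parameter device
    mels = np.random.rand(80, 100).astype(np.float32)
    if isinstance(mels, np.ndarray):
        mels = torch.FloatTensor(mels).to(str(next(model.parameters()).device))
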
@@ -576,7 +576,8 @@ class Wavernn(BaseVocoder):
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
-            x = sample["input"]
+            x = torch.FloatTensor(sample[0])
+            x = x.to(next(self.parameters()).device)
             y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
             x_hat = ap.melspectrogram(y_hat)
             figures.update(
@@ -585,7 +586,7 @@ class Wavernn(BaseVocoder):
                     f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
                 }
             )
-            audios.update({f"test_{idx}/audio", y_hat})
+            audios.update({f"test_{idx}/audio": y_hat})
         return figures, audios

     @staticmethod
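
The old `{f"test_{idx}/audio", y_hat}` is a set literal, not a dict, so building it or passing it to `dict.update` fails at runtime (the waveform is unhashable); the colon makes it the intended one-entry mapping. A two-line illustration:

    audios = {}
    audios.update({"test_0/audio": [0.1, 0.2]})    # dict literal: OK
    # audios.update({"test_0/audio", [0.1, 0.2]})  # set literal: TypeError (list unhashable)
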


@@ -6,6 +6,7 @@ from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.config.shared_configs import BaseAudioConfig
 from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig

+
 def run_test_train():
     command = (
         f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
@@ -17,6 +18,7 @@ def run_test_train():
     )
     run_cli(command)

+
 config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")