mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #667 from coqui-ai/fix-test-sentences
Fix test runs and wavegrad test_run
commit 9bb7f31f36
@@ -764,11 +764,11 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
 
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
-            if hasattr(self.eval_loader.load_test_samples):
-                samples = self.eval_loader.load_test_samples(1)
-                figures, audios = self.model.test_run(samples)
+            if hasattr(self.eval_loader.dataset, "load_test_samples"):
+                samples = self.eval_loader.dataset.load_test_samples(1)
+                figures, audios = self.model.test_run(self.ap, samples, None)
             else:
-                figures, audios = self.model.test_run()
+                figures, audios = self.model.test_run(self.ap)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
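The removed check made two mistakes at once: it called hasattr() with a single argument (a TypeError in itself) and probed the DataLoader rather than its dataset, where load_test_samples actually lives. A minimal sketch with hypothetical stand-in classes (not the repo's API) shows why the fixed probe works:

    class _Dataset:
        def load_test_samples(self, n):
            return [("mel", "audio")] * n

    class _Loader:  # stands in for torch.utils.data.DataLoader
        def __init__(self, dataset):
            self.dataset = dataset

    loader = _Loader(_Dataset())
    print(hasattr(loader.dataset, "load_test_samples"))  # True -> sample-based test run
    # hasattr(loader.load_test_samples) would fail before the check even runs:
    # the loader has no such attribute, and hasattr() always takes two arguments.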
@@ -790,7 +790,7 @@ class Trainer:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
+            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                 self.test_run()
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
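Assuming the usual convention that rank 0 marks the main process (and is the default outside distributed runs), the old "< 0" comparison was never true, so test runs were silently skipped everywhere; that is the bug the commit title refers to. A small sketch of the corrected guard:

    def should_test(epoch, test_delay_epochs, rank):
        return epoch >= test_delay_epochs and rank <= 0

    assert should_test(5, 2, 0)      # main process after the delay: run tests
    assert not should_test(5, 2, 1)  # DDP worker: skip
    assert not should_test(1, 2, 0)  # before test_delay_epochs: skip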
@@ -70,7 +70,7 @@ class BaseTTS(BaseModel):
 
     def get_aux_input(self, **kwargs) -> Dict:
         """Prepare and return `aux_input` used by `forward()`"""
-        pass
+        return {"speaker_id": None, "style_wav": None, "d_vector": None}
 
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
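With a bare "pass" the method returned None, and test_run below subscripts the result (aux_inputs["speaker_id"] and friends), which would raise TypeError. Returning a dict of None defaults keeps those lookups safe; a minimal sketch:

    def get_aux_input(**kwargs):
        return {"speaker_id": None, "style_wav": None, "d_vector": None}

    aux_inputs = get_aux_input()
    print(aux_inputs["speaker_id"])  # None is fine; subscripting None is not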
@@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
         )
         return loader
 
-    def test_run(self) -> Tuple[Dict, Dict]:
+    def test_run(self, ap) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        aux_inputs = self._get_aux_inputs()
+        aux_inputs = self.get_aux_input()
         for idx, sen in enumerate(test_sentences):
             wav, alignment, model_outputs, _ = synthesis(
-                self.model,
+                self,
                 sen,
                 self.config,
-                self.use_cuda,
-                self.ap,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
@@ -229,6 +229,6 @@ class BaseTTS(BaseModel):
             ).values()
 
             test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
         return test_figures, test_audios
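Taken together, these three hunks make BaseTTS.test_run self-contained: the model passes itself to synthesis (not a self.model attribute), receives the audio processor "ap" from the Trainer instead of assuming a self.ap, and derives CUDA use from its own parameters. The device-detection idiom is worth noting, since an nn.Module has no .device attribute; a minimal sketch:

    import torch

    model = torch.nn.Linear(4, 4)
    use_cuda = "cuda" in str(next(model.parameters()).device)
    print(use_cuda)  # False on a CPU-only machine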
@@ -113,7 +113,7 @@ class GlowTTS(BaseTTS):
 
     @staticmethod
     def compute_outputs(attn, o_mean, o_log_scale, x_mask):
-        """ Compute and format the mode outputs with the given alignment map"""
+        """Compute and format the mode outputs with the given alignment map"""
         y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
             1, 2
         )  # [b, t', t], [b, t, d] -> [b, d, t']
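Only the docstring's leading space changes here, but the shape comment on the matmul deserves unpacking. A worked example with illustrative sizes (batch 2, 80 mel channels, encoder length 6, decoder length 4), assuming attn enters as [B, 1, T_en, T_de]:

    import torch

    attn = torch.ones(2, 1, 6, 4)   # [B, 1, T_en, T_de] alignment map
    o_mean = torch.randn(2, 80, 6)  # [B, C, T_en] encoder-side means
    y_mean = torch.matmul(
        attn.squeeze(1).transpose(1, 2),  # [B, T_de, T_en], i.e. [b, t', t]
        o_mean.transpose(1, 2),           # [B, T_en, C],   i.e. [b, t, d]
    ).transpose(1, 2)
    print(y_mean.shape)  # torch.Size([2, 80, 4]) -> [b, d, t']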
@@ -2,6 +2,7 @@ import glob
 import os
 import random
 from multiprocessing import Manager
+from typing import List, Tuple
 
 import numpy as np
 import torch
@@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
             item = self.load_item(idx)
         return item
 
-    def load_test_samples(self, num_samples):
+    def load_test_samples(self, num_samples: int) -> List[Tuple]:
+        """Return test samples.
+
+        Args:
+            num_samples (int): Number of samples to return.
+
+        Returns:
+            List[Tuple]: melspectrogram and audio.
+
+        Shapes:
+            - melspectrogram (Tensor): :math:`[C, T]`
+            - audio (Tensor): :math:`[T_audio]`
+        """
         samples = []
         return_segments = self.return_segments
         self.return_segments = False
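The hunk ends right after the segmentation flag is cleared; presumably the method then collects num_samples full-length items and restores the flag for training. A hypothetical minimal version under that assumption (the restore step is not shown in the diff):

    class _DatasetLike:
        def __init__(self, items):
            self.items = items           # list of (mel, audio) pairs
            self.return_segments = True  # training mode: fixed-size segments

        def load_item(self, idx):
            return self.items[idx]       # segmentation elided in this sketch

        def load_test_samples(self, num_samples):
            samples = []
            return_segments = self.return_segments
            self.return_segments = False  # yield full-length clips for testing
            for idx in range(num_samples):
                samples.append(self.load_item(idx))
            self.return_segments = return_segments  # restore training behaviour
            return samples

    ds = _DatasetLike([("mel_0", "audio_0"), ("mel_1", "audio_1")])
    print(ds.load_test_samples(1))  # [('mel_0', 'audio_0')]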
@@ -31,7 +31,7 @@ def setup_model(config: Coqpit):
 
 
 def setup_generator(c):
-    """ TODO: use config object as arguments"""
+    """TODO: use config object as arguments"""
     print(" > Generator Model: {}".format(c.generator_model))
     MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
@@ -94,7 +94,7 @@ def setup_generator(c):
 
 
 def setup_discriminator(c):
-    """ TODO: use config object as arguments"""
+    """TODO: use config object as arguments"""
     print(" > Discriminator Model: {}".format(c.discriminator_model))
     if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
@@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
 
     @torch.no_grad()
     def inference(self, x, y_n=None):
-        """x: B x D X T"""
+        """
+        Shapes:
+            x: :math:`[B, C, T]`
+            y_n: :math:`[B, 1, T]`
+        """
         if y_n is None:
-            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
         else:
-            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
+        y_n = y_n.type_as(x)
         sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
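Both branches used to move the tensor with .to(x); the new code does it once after the branch with type_as, which casts to the other tensor's dtype (and, for CUDA tensors, its device). Illustrative shapes, assuming a hop length of 256:

    import torch

    hop_len = 256
    x = torch.randn(1, 80, 10)  # [B, C, T] conditioning spectrogram
    y_n = torch.randn(x.shape[0], 1, hop_len * x.shape[-1])
    y_n = y_n.type_as(x)        # match x's dtype/device in one call
    print(y_n.shape)            # torch.Size([1, 1, 2560])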
@@ -267,8 +272,10 @@ class Wavegrad(BaseModel):
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
         for sample in samples:
-            x = sample["input"]
-            y = sample["waveform"]
+            x = sample[0]
+            x = x[None, :, :].to(next(self.parameters()).device)
+            y = sample[1]
+            y = y[None, :]
             # compute voice
             y_pred = self.inference(x)
             # compute spectrograms
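test_run now consumes the (melspectrogram, audio) tuples produced by WaveGradDataset.load_test_samples, so it indexes the tuple and adds batch dimensions before inference. An illustrative unpacking with made-up sizes, following the shapes in the new dataset docstring:

    import torch

    sample = (torch.randn(80, 10), torch.randn(2560))  # ([C, T], [T_audio])
    x = sample[0][None, :, :]                          # -> [1, C, T]
    y = sample[1][None, :]                             # -> [1, T_audio]
    print(x.shape, y.shape)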
@@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
 
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).type_as(mels)
+                mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
 
             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
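The old line tried to type_as against the very numpy array being converted, which can never place the tensor on the model's GPU. The fix reads the device off a parameter; .to() accepts the stringified device. A small sketch:

    import numpy as np
    import torch

    net = torch.nn.Linear(10, 10)  # stand-in for the Wavernn module
    mels = np.zeros((80, 10), dtype=np.float32)
    mels = torch.FloatTensor(mels).to(str(next(net.parameters()).device))
    print(mels.device)             # cpu here; cuda:0 if the model lived on a GPU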
@@ -576,7 +576,8 @@ class Wavernn(BaseVocoder):
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
-            x = sample["input"]
+            x = torch.FloatTensor(sample[0])
+            x = x.to(next(self.parameters()).device)
             y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
             x_hat = ap.melspectrogram(y_hat)
             figures.update(
@@ -585,7 +586,7 @@ class Wavernn(BaseVocoder):
                     f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
                 }
             )
-            audios.update({f"test_{idx}/audio", y_hat})
+            audios.update({f"test_{idx}/audio": y_hat})
         return figures, audios
 
     @staticmethod
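The one-character fix here swaps a comma for a colon: {key, value} is a set literal, not a dict, and dict.update() cannot consume it. With a string value the failure mode looks like this (the real y_hat is an array, which would fail even earlier as an unhashable set element):

    audios = {}
    try:
        audios.update({"test_0/audio", "y_hat"})  # set literal: bare elements
    except ValueError as err:
        print(err)  # update() wants key/value pairs, not lone strings
    audios.update({"test_0/audio": "y_hat"})      # dict literal: the fixed form
    print(audios)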
@@ -6,6 +6,7 @@ from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.config.shared_configs import BaseAudioConfig
 from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
 
 
 def run_test_train():
     command = (
         f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
@@ -17,6 +18,7 @@ def run_test_train():
     )
     run_cli(command)
 
 
 config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")
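These tests shell out to the training CLI with the GPU pinned via CUDA_VISIBLE_DEVICES and fail on a non-zero exit status. run_cli is the repo's test helper; a rough equivalent, assumed rather than copied from the source, looks like:

    import os

    def run_cli(command):
        exit_status = os.system(command)
        assert exit_status == 0, f" [!] command failed: {command}"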