diff --git a/TTS/trainer.py b/TTS/trainer.py
index c56be140..903aee5f 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -764,11 +764,11 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
-            if hasattr(self.eval_loader.load_test_samples):
-                samples = self.eval_loader.load_test_samples(1)
-                figures, audios = self.model.test_run(samples)
+            if hasattr(self.eval_loader.dataset, "load_test_samples"):
+                samples = self.eval_loader.dataset.load_test_samples(1)
+                figures, audios = self.model.test_run(self.ap, samples, None)
             else:
-                figures, audios = self.model.test_run()
+                figures, audios = self.model.test_run(self.ap)
             self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"])
             self.tb_logger.tb_test_figures(self.total_steps_done, figures)
@@ -790,7 +790,7 @@ class Trainer:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
+            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                 self.test_run()
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 2ec268d6..561b76fb 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -70,7 +70,7 @@ class BaseTTS(BaseModel):
 
     def get_aux_input(self, **kwargs) -> Dict:
         """Prepare and return `aux_input` used by `forward()`"""
-        pass
+        return {"speaker_id": None, "style_wav": None, "d_vector": None}
 
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
@@ -200,7 +200,7 @@ class BaseTTS(BaseModel):
             )
         return loader
 
-    def test_run(self) -> Tuple[Dict, Dict]:
+    def test_run(self, ap) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
         You can override this for a different behaviour.
@@ -212,14 +212,14 @@ class BaseTTS(BaseModel):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        aux_inputs = self._get_aux_inputs()
+        aux_inputs = self.get_aux_input()
         for idx, sen in enumerate(test_sentences):
             wav, alignment, model_outputs, _ = synthesis(
-                self.model,
+                self,
                 sen,
                 self.config,
-                self.use_cuda,
-                self.ap,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
                 style_wav=aux_inputs["style_wav"],
@@ -229,6 +229,6 @@
             ).values()
             test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
         return test_figures, test_audios
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 9f235fad..b3bceb09 100755
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -113,7 +113,7 @@ class GlowTTS(BaseTTS):
 
     @staticmethod
     def compute_outputs(attn, o_mean, o_log_scale, x_mask):
-        """ Compute and format the mode outputs with the given alignment map"""
+        """Compute and format the mode outputs with the given alignment map"""
         y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
             1, 2
         )  # [b, t', t], [b, t, d] -> [b, d, t']
diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py
index d99fc417..05e0fae8 100644
--- a/TTS/vocoder/datasets/wavegrad_dataset.py
+++ b/TTS/vocoder/datasets/wavegrad_dataset.py
@@ -2,6 +2,7 @@ import glob
 import os
 import random
 from multiprocessing import Manager
+from typing import List, Tuple
 
 import numpy as np
 import torch
@@ -67,7 +68,19 @@ class WaveGradDataset(Dataset):
         item = self.load_item(idx)
         return item
 
-    def load_test_samples(self, num_samples):
+    def load_test_samples(self, num_samples: int) -> List[Tuple]:
+        """Return test samples.
+
+        Args:
+            num_samples (int): Number of samples to return.
+
+        Returns:
+            List[Tuple]: melspectrogram and audio.
+
+        Shapes:
+            - melspectrogram (Tensor): :math:`[C, T]`
+            - audio (Tensor): :math:`[T_audio]`
+        """
         samples = []
         return_segments = self.return_segments
         self.return_segments = False
diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py
index 9479095e..7c209af4 100644
--- a/TTS/vocoder/models/__init__.py
+++ b/TTS/vocoder/models/__init__.py
@@ -31,7 +31,7 @@ def setup_model(config: Coqpit):
 
 
 def setup_generator(c):
-    """ TODO: use config object as arguments"""
+    """TODO: use config object as arguments"""
     print(" > Generator Model: {}".format(c.generator_model))
     MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
     MyModel = getattr(MyModel, to_camel(c.generator_model))
@@ -94,7 +94,7 @@ def setup_generator(c):
 
 
 def setup_discriminator(c):
-    """ TODO: use config objekt as arguments"""
+    """TODO: use config object as arguments"""
     print(" > Discriminator Model: {}".format(c.discriminator_model))
     if "parallel_wavegan" in c.discriminator_model:
         MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 03d5160e..d2983be2 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -124,11 +124,16 @@ class Wavegrad(BaseModel):
 
     @torch.no_grad()
     def inference(self, x, y_n=None):
-        """x: B x D X T"""
+        """
+        Shapes:
+            x: :math:`[B, C, T]`
+            y_n: :math:`[B, 1, T]`
+        """
         if y_n is None:
-            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
+            y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1])
         else:
-            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x)
+            y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0)
+        y_n = y_n.type_as(x)
         sqrt_alpha_hat = self.noise_level.to(x)
         for n in range(len(self.alpha) - 1, -1, -1):
             y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0])))
@@ -267,8 +272,10 @@ class Wavegrad(BaseModel):
         betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
         self.compute_noise_level(betas)
         for sample in samples:
-            x = sample["input"]
-            y = sample["waveform"]
+            x = sample[0]
+            x = x[None, :, :].to(next(self.parameters()).device)
+            y = sample[1]
+            y = y[None, :]
             # compute voice
             y_pred = self.inference(x)
             # compute spectrograms
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index a5d89d5a..c2e47120 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -322,7 +322,7 @@ class Wavernn(BaseVocoder):
 
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).type_as(mels)
+                mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device))
 
             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
@@ -576,7 +576,8 @@ class Wavernn(BaseVocoder):
         figures = {}
         audios = {}
         for idx, sample in enumerate(samples):
-            x = sample["input"]
+            x = torch.FloatTensor(sample[0])
+            x = x.to(next(self.parameters()).device)
             y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples)
             x_hat = ap.melspectrogram(y_hat)
             figures.update(
@@ -585,7 +586,7 @@
                     f"test_{idx}/prediction": plot_spectrogram(x_hat.T),
                 }
             )
-            audios.update({f"test_{idx}/audio", y_hat})
+            audios.update({f"test_{idx}/audio": y_hat})
         return figures, audios
 
     @staticmethod
diff --git a/tests/test_speaker_encoder_train.py b/tests/test_speaker_encoder_train.py
index 4419a00f..7901fe5a 100644
--- a/tests/test_speaker_encoder_train.py
+++ b/tests/test_speaker_encoder_train.py
@@ -6,6 +6,7 @@ from tests import get_device_id, get_tests_output_path, run_cli
 from TTS.config.shared_configs import BaseAudioConfig
 from TTS.speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig
 
+
 def run_test_train():
     command = (
         f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_encoder.py --config_path {config_path} "
@@ -17,6 +18,7 @@
     )
     run_cli(command)
 
+
 config_path = os.path.join(get_tests_output_path(), "test_speaker_encoder_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")