From 231c69b12e60e2a6aa921a4a32c215e6a93a76f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 19 Apr 2022 09:19:07 +0200 Subject: [PATCH] Remove AP from FastPitchE2e --- TTS/tts/models/forward_tts_e2e.py | 616 ++++++++++++++++++++++++++---- 1 file changed, 547 insertions(+), 69 deletions(-) diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py index eaba994f..d22e301d 100644 --- a/TTS/tts/models/forward_tts_e2e.py +++ b/TTS/tts/models/forward_tts_e2e.py @@ -1,48 +1,297 @@ +import os from dataclasses import dataclass, field from itertools import chain from typing import Dict, List, Tuple, Union +import numpy as np +import pyworld as pw + import torch +import torch.distributed as dist from coqpit import Coqpit from torch import nn from torch.cuda.amp.autocast_mode import autocast +from torch.utils.data import DataLoader from trainer.trainer_utils import get_optimizer, get_scheduler -from TTS.tts.layers.losses import ForwardTTSE2ELoss, VitsDiscriminatorLoss +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.layers.losses import ForwardTTSE2eLoss, VitsDiscriminatorLoss from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.models.base_tts import BaseTTSE2E from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs -from TTS.tts.models.vits import wav_to_mel -from TTS.tts.utils.helpers import rand_segments, segment +from TTS.tts.models.vits import load_audio, wav_to_mel +from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0, mel_to_wav as mel_to_wav_numpy +from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.tts.utils.visual import plot_alignment +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +############################## +# DATASET +############################## + + +class ForwardTTSE2eF0Dataset(F0Dataset): + """Override F0Dataset to avoid the AudioProcessor.""" + + def __init__( + self, + audio_config: "AudioConfig", + samples: Union[List[List], List[Dict]], + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + self.audio_config = audio_config + super().__init__( + samples=samples, + ap=None, + verbose=verbose, + cache_path=cache_path, + precompute_num_workers=precompute_num_workers, + normalize_f0=normalize_f0, + ) + + @staticmethod + def _compute_and_save_pitch(config, wav_file, pitch_file=None): + wav, _ = load_audio(wav_file) + f0 = compute_f0(x=wav.numpy()[0], 
sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax) + # skip the last F0 value to align with the spectrogram + if wav.shape[1] % config.hop_length != 0: + f0 = f0[:-1] + if pitch_file: + np.save(pitch_file, f0) + return f0 + + def compute_or_load(self, wav_file): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + +class ForwardTTSE2eDataset(TTSDataset): + def __init__(self, *args, **kwargs): + # don't init the default F0Dataset in TTSDataset + compute_f0 = kwargs.pop("compute_f0", False) + kwargs["compute_f0"] = False + + self.audio_config = kwargs["audio_config"] + del kwargs["audio_config"] + + super().__init__(*args, **kwargs) + + self.compute_f0 = compute_f0 + self.pad_id = self.tokenizer.characters.pad_id + if self.compute_f0: + self.f0_dataset = ForwardTTSE2eF0Dataset( + audio_config=self.audio_config, + samples=self.samples, + cache_path=kwargs["f0_cache_path"], + precompute_num_workers=kwargs["precompute_num_workers"], + ) + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "pitch": f0, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - pitch :math:`[B, T]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + pitch_lens = [p.shape[0] for p in batch["pitch"]] + pitch_lens = torch.LongTensor(pitch_lens) + pitch_lens_max = torch.max(pitch_lens) + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + pitch_padded = 
torch.FloatTensor(B, 1, pitch_lens_max) + + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + pitch_padded = pitch_padded.zero_() + self.pad_id + + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + pitch = batch["pitch"][i] + pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) + + return { + "text_input": token_padded, + "text_lengths": token_lens, + "text_rel_lens": token_rel_lens, + "pitch": pitch_padded, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + } + + +############################## +# CONFIG DEFINITIONS +############################## @dataclass -class ForwardTTSE2EArgs(ForwardTTSArgs): +class ForwardTTSE2eAudio(Coqpit): + sample_rate: int = 22050 + hop_length: int = 256 + win_length: int = 1024 + fft_size: int = 1024 + mel_fmin: float = 0.0 + mel_fmax: float = 8000 + num_mels: int = 80 + pitch_fmax: float = 640.0 + + +@dataclass +class ForwardTTSE2eArgs(ForwardTTSArgs): # vocoder_config: BaseGANVocoderConfig = None num_chars: int = 100 encoder_out_channels: int = 80 - spec_segment_size: int = 32 + spec_segment_size: int = 80 # duration predictor detach_duration_predictor: bool = True + duration_predictor_dropout_p: float = 0.1 + # pitch predictor + pitch_predictor_dropout_p: float = 0.1 # discriminator init_discriminator: bool = True use_spectral_norm_discriminator: bool = False # model parameters detach_vocoder_input: bool = False - hidden_channels: int = 192 + hidden_channels: int = 256 encoder_type: str = "fftransformer" encoder_params: dict = field( - default_factory=lambda: {"hidden_channels_ffn": 768, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + default_factory=lambda: { + "hidden_channels_ffn": 1024, + "num_heads": 2, + "num_layers": 4, + "dropout_p": 0.1, + "kernel_size_fft": 9, + } ) decoder_type: str = "fftransformer" decoder_params: dict = field( - default_factory=lambda: {"hidden_channels_ffn": 768, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + default_factory=lambda: { + "hidden_channels_ffn": 1024, + "num_heads": 2, + "num_layers": 4, + "dropout_p": 0.1, + "kernel_size_fft": 9, + } ) # generator resblock_type_decoder: str = "1" @@ -61,35 +310,39 @@ class ForwardTTSE2EArgs(ForwardTTSArgs): d_vector_dim: int = 0 -class ForwardTTSE2E(BaseTTSE2E): +############################## +# MODEL DEFINITION +############################## + + +class ForwardTTSE2e(BaseTTSE2E): """ Model training:: text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg spec --------^ Examples: - >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2E, ForwardTTSE2EConfig - >>> config = ForwardTTSE2EConfig() - >>> model = ForwardTTSE2E(config) + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig + >>> config = ForwardTTSE2eConfig() + >>> model = ForwardTTSE2e(config) """ # pylint: disable=dangerous-default-value def __init__( self, config: Coqpit, - ap: "AudioProcessor" = None, tokenizer: "TTSTokenizer" = None, speaker_manager: SpeakerManager = None, ): - super().__init__(config, ap, tokenizer, speaker_manager) + super().__init__(config=config, 
tokenizer=tokenizer, speaker_manager=speaker_manager) self._set_model_args(config) self.init_multispeaker(config) - self.encoder_model = ForwardTTS(config=self.args, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) + self.encoder_model = ForwardTTS(config=self.args, ap=None, tokenizer=tokenizer, speaker_manager=speaker_manager) # self.vocoder_model = GAN(config=self.args.vocoder_config, ap=ap) self.waveform_decoder = HifiganGenerator( - self.args.out_channels, + self.args.hidden_channels, 1, self.args.resblock_type_decoder, self.args.resblock_dilation_sizes_decoder, @@ -179,16 +432,16 @@ class ForwardTTSE2E(BaseTTSE2E): encoder_outputs = self.encoder_model( x=x, x_lengths=x_lengths, y_lengths=spec_lengths, y=spec, dr=dr, pitch=pitch, aux_input=aux_input ) - spec_encoder_output = encoder_outputs["model_outputs"] - spec_encoder_output_slices, slice_ids = rand_segments( - x=spec_encoder_output.transpose(1, 2), + o_en_ex = encoder_outputs["o_en_ex"].transpose(1, 2) # [B, C_en, T_max2] -> [B, T_max2, C_en] + o_en_ex_slices, slice_ids = rand_segments( + x=o_en_ex.transpose(1, 2), x_lengths=spec_lengths, segment_size=self.args.spec_segment_size, let_short_samples=True, pad_short=True, ) vocoder_output = self.waveform_decoder( - x=spec_encoder_output_slices.detach() if self.args.detach_vocoder_input else spec_encoder_output_slices, + x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices, g=encoder_outputs["g"], ) wav_seg = segment( @@ -205,33 +458,35 @@ class ForwardTTSE2E(BaseTTSE2E): return model_outputs @torch.no_grad() - def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument - encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input) - # vocoder_output = self.vocoder_model.model_g(x=encoder_outputs["model_outputs"].transpose(1, 2)) - vocoder_output = self.waveform_decoder( - x=encoder_outputs["model_outputs"].transpose(1, 2), g=encoder_outputs["g"] - ) + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True) + o_en_ex = encoder_outputs["o_en_ex"] + vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"]) model_outputs = {**encoder_outputs} - model_outputs["encoder_outputs"] = encoder_outputs["model_outputs"] model_outputs["model_outputs"] = vocoder_output return model_outputs + @torch.no_grad() + def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=False) + model_outputs = {**encoder_outputs} + return model_outputs + @staticmethod def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False): """Initiate model from config Args: - config (ForwardTTSE2EConfig): Model config. + config (ForwardTTSE2eConfig): Model config. samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. Defaults to None. 
""" - from TTS.utils.audio import AudioProcessor + from TTS.utils.audio.processor import AudioProcessor - ap = AudioProcessor.init_from_config(config, verbose=verbose) tokenizer, new_config = TTSTokenizer.init_from_config(config) speaker_manager = SpeakerManager.init_from_config(config, samples) # language_manager = LanguageManager.init_from_config(config) - return ForwardTTSE2E(config=new_config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) + return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager) def load_checkpoint( self, config, checkpoint_path, eval=False @@ -248,7 +503,7 @@ class ForwardTTSE2E(BaseTTSE2E): token_lenghts = batch["text_lengths"] spec = batch["mel_input"] spec_lens = batch["mel_lengths"] - waveform = batch["waveform"].transpose(1, 2) # [B, T, C] -> [B, C, T] + waveform = batch["waveform"] # [B, T, C] -> [B, C, T] pitch = batch["pitch"] d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] @@ -316,6 +571,8 @@ class ForwardTTSE2E(BaseTTSE2E): pitch_output=self.model_outputs_cache["pitch_avg"] if self.args.use_pitch else None, pitch_target=self.model_outputs_cache["pitch_avg_gt"] if self.args.use_pitch else None, input_lens=batch["text_lengths"], + waveform=self.model_outputs_cache["waveform_seg"], + waveform_hat=self.model_outputs_cache["model_outputs"], aligner_logprob=self.model_outputs_cache["aligner_logprob"], aligner_hard=self.model_outputs_cache["aligner_mas"], aligner_soft=self.model_outputs_cache["aligner_soft"], @@ -340,29 +597,53 @@ class ForwardTTSE2E(BaseTTSE2E): def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): return self.train_step(batch, criterion, optimizer_idx) - @staticmethod - def __copy_for_logging(outputs): - """Change keys and copy values for logging.""" - encoder_outputs = outputs[1].copy() - encoder_outputs["model_outputs"] = encoder_outputs["encoder_outputs"] - vocoder_outputs = outputs.copy() - vocoder_outputs[1]["model_outputs"] = outputs[1]["model_outputs"] - return encoder_outputs, vocoder_outputs + def _log(self, batch, outputs, name_prefix="train"): + figures, audios = {}, {} - def _log(self, ap, batch, outputs, name_prefix="train"): - encoder_outputs, vocoder_outputs = self.__copy_for_logging(outputs) - y_hat = vocoder_outputs[1]["model_outputs"] - y = vocoder_outputs[1]["waveform_seg"] # encoder outputs - encoder_figures, encoder_audios = self.encoder_model.create_logs( - batch=batch, outputs=encoder_outputs, ap=self.ap - ) + model_outputs = outputs[1]["encoder_outputs"] + alignments = outputs[1]["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, None, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, None, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch_avg = abs(outputs[1]["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs[1]["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot the attention mask computed from the predicted durations + if 
"attn_durations" in outputs[1]: + alignments_hat = outputs[1]["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + encoder_audio = mel_to_wav_numpy(mel=pred_spec.T, mel_basis=self.__mel_basis, **self.config.audio) + audios[f"{name_prefix}/encoder_audio"] = encoder_audio + # vocoder outputs - vocoder_figures = plot_results(y_hat, y, ap, name_prefix) + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + + vocoder_figures = plot_results(y_hat=y_hat, y=y, audio_config=self.config.audio, name_prefix=name_prefix) + figures.update(vocoder_figures) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - audios = {f"{name_prefix}/real_audio": sample_voice} - audios[f"{name_prefix}/encoder_audio"] = encoder_audios["audio"] - figures = {**encoder_figures, **vocoder_figures} + audios[f"{name_prefix}/real_audio"] = sample_voice return figures, audios def train_log( @@ -374,21 +655,20 @@ class ForwardTTSE2E(BaseTTSE2E): be projected onto Tensorboard. Args: - ap (AudioProcessor): audio processor used at training. batch (Dict): Model inputs used at the previous training step. outputs (Dict): Model outputs generated at the previous training step. Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - figures, audios = self._log(ap=self.ap, batch=batch, outputs=outputs, name_prefix="vocoder/") + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") logger.train_figures(steps, figures) - logger.train_audios(steps, audios, self.ap.sample_rate) + logger.train_audios(steps, audios, self.config.audio.sample_rate) def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: - figures, audios = self._log(ap=self.ap, batch=batch, outputs=outputs, name_prefix="vocoder/") + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") logger.eval_figures(steps, figures) - logger.eval_audios(steps, audios, self.ap.sample_rate) + logger.eval_audios(steps, audios, self.config.audio.sample_rate) def get_aux_input_from_test_sentences(self, sentence_info): if hasattr(self.config, "model_args"): @@ -438,6 +718,78 @@ class ForwardTTSE2E(BaseTTSE2E): "language_name": None, } + def synthesize(self, text: str, speaker_id, language_id, d_vector): + # TODO: add language_id + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=language_id), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + # if language_id is not None: + # language_id = id_to_torch(language_id, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}) + + # collect outputs + wav = outputs["model_outputs"][0].data.cpu().numpy() + alignments = outputs["alignments"] + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + def synthesize_with_gl(self, text: str, speaker_id, language_id, d_vector): + # TODO: add language_id + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token 
IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=language_id), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + # if language_id is not None: + # language_id = id_to_torch(language_id, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}) + + # collect outputs + wav = mel_to_wav_numpy(mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio) + alignments = outputs["alignments"] + return_dict = { + "wav": wav[None, :], + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + @torch.no_grad() def test_run(self, assets) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. @@ -453,30 +805,147 @@ class ForwardTTSE2E(BaseTTSE2E): test_sentences = self.config.test_sentences for idx, s_info in enumerate(test_sentences): aux_inputs = self.get_aux_input_from_test_sentences(s_info) - wav, alignment, _, _ = synthesis( - self, + outputs = self.synthesize( aux_inputs["text"], - self.config, - "cuda" in str(next(self.parameters()).device), speaker_id=aux_inputs["speaker_id"], d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], language_id=aux_inputs["language_id"], - use_griffin_lim=True, - do_trim_silence=False, - ).values() - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) + ) + outputs_gl = self.synthesize_with_gl( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + language_id=aux_inputs["language_id"], + ) + test_audios["{}-audio".format(idx)] = outputs["wav"].T + test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) return {"figures": test_figures, "audios": test_audios} def test_log( self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument ) -> None: - logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) logger.test_figures(steps, outputs["figures"]) + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.d_vectors + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if ( + self.language_manager is not None + 
and self.language_manager.language_id_mapping + and self.args.use_language_embedding + ): + language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + # compute spectrograms + batch["mel_input"] = wav_to_mel( + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + center=False, + ) + + assert ( + batch["pitch"].shape[2] == batch["mel_input"].shape[2] + ), f"{batch['pitch'].shape[2]}, {batch['mel'].shape[2]}" + batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() + + # zero the padding frames + batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) + batch["mel_input"] = batch["mel_input"].transpose(1, 2) + return batch + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = ForwardTTSE2eDataset( + samples=samples, + audio_config=self.config.audio, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + compute_f0=config.compute_f0, + f0_cache_path=config.f0_cache_path, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + def get_criterion(self): - return [VitsDiscriminatorLoss(self.config), ForwardTTSE2ELoss(self.config)] + return [VitsDiscriminatorLoss(self.config), ForwardTTSE2eLoss(self.config)] def get_optimizer(self) -> List: """Initiate and return the GAN optimizers based on the config parameters. @@ -516,3 +985,12 @@ class ForwardTTSE2E(BaseTTSE2E): """Schedule binary loss weight.""" self.encoder_model.config.binary_loss_warmup_epochs = self.config.binary_loss_warmup_epochs self.encoder_model.on_train_step_start(trainer) + + def on_init_start(self, trainer: "Trainer"): + self.__mel_basis = build_mel_basis( + sample_rate=self.config.audio.sample_rate, + fft_size=self.config.audio.fft_size, + num_mels=self.config.audio.num_mels, + mel_fmax=self.config.audio.mel_fmax, + mel_fmin=self.config.audio.mel_fmin, + )
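
For context, a minimal usage sketch of the refactored model after this patch. It is a sketch, not part of the change: it assumes ForwardTTSE2eConfig is importable alongside ForwardTTSE2e (as shown in the class docstring above), that a trained checkpoint exists, and it uses soundfile only as an illustrative WAV writer; the checkpoint path and output filename are placeholders. The point it illustrates is that no AudioProcessor is constructed anywhere: the tokenizer and speaker manager come from the config, and audio parameters are read from config.audio (the new ForwardTTSE2eAudio dataclass).

    import soundfile as sf  # placeholder writer, not a dependency of this patch

    from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig

    config = ForwardTTSE2eConfig()

    # init_from_config now builds only the tokenizer and speaker manager;
    # there is no AudioProcessor argument any more.
    model = ForwardTTSE2e.init_from_config(config)
    model.load_checkpoint(config, "checkpoint.pth", eval=True)  # placeholder path

    outputs = model.synthesize(
        "Hello world.",
        speaker_id=None,   # speaker name resolved by SpeakerManager in multi-speaker setups
        language_id=None,  # language handling is still a TODO in this patch
        d_vector=None,
    )

    # outputs["wav"] has shape [1, T]; the sample rate comes from config.audio,
    # which replaces the removed AudioProcessor settings.
    sf.write("out.wav", outputs["wav"].squeeze(0), config.audio.sample_rate)

For the spectrogram-decoder path, synthesize_with_gl follows the same call pattern but vocodes the predicted mel with mel_to_wav_numpy and the mel basis built in on_init_start.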