diff --git a/TTS/tts/models/forward_tts_e2e.py b/TTS/tts/models/forward_tts_e2e.py
index eaba994f..d22e301d 100644
--- a/TTS/tts/models/forward_tts_e2e.py
+++ b/TTS/tts/models/forward_tts_e2e.py
@@ -1,48 +1,297 @@
+import os
 from dataclasses import dataclass, field
 from itertools import chain
 from typing import Dict, List, Tuple, Union
+import numpy as np
+import pyworld as pw
+
 
 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
+from torch.utils.data import DataLoader
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
-from TTS.tts.layers.losses import ForwardTTSE2ELoss, VitsDiscriminatorLoss
+from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample
+from TTS.tts.layers.losses import ForwardTTSE2eLoss, VitsDiscriminatorLoss
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.models.base_tts import BaseTTSE2E
 from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
-from TTS.tts.models.vits import wav_to_mel
-from TTS.tts.utils.helpers import rand_segments, segment
+from TTS.tts.models.vits import load_audio, wav_to_mel
+from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0, mel_to_wav as mel_to_wav_numpy
+from TTS.tts.utils.helpers import rand_segments, segment, sequence_mask
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment
+from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
+from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+
+
+def id_to_torch(aux_id, cuda=False):
+    if aux_id is not None:
+        aux_id = np.asarray(aux_id)
+        aux_id = torch.from_numpy(aux_id)
+    if cuda:
+        return aux_id.cuda()
+    return aux_id
+
+
+def embedding_to_torch(d_vector, cuda=False):
+    if d_vector is not None:
+        d_vector = np.asarray(d_vector)
+        d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
+        d_vector = d_vector.squeeze().unsqueeze(0)
+    if cuda:
+        return d_vector.cuda()
+    return d_vector
+
+
+def numpy_to_torch(np_array, dtype, cuda=False):
+    if np_array is None:
+        return None
+    tensor = torch.as_tensor(np_array, dtype=dtype)
+    if cuda:
+        return tensor.cuda()
+    return tensor
+
+
+##############################
+# DATASET
+##############################
+
+
+class ForwardTTSE2eF0Dataset(F0Dataset):
+    """Override F0Dataset to avoid the AudioProcessor."""
+
+    def __init__(
+        self,
+        audio_config: "AudioConfig",
+        samples: Union[List[List], List[Dict]],
+        verbose=False,
+        cache_path: str = None,
+        precompute_num_workers=0,
+        normalize_f0=True,
+    ):
+        self.audio_config = audio_config
+        super().__init__(
+            samples=samples,
+            ap=None,
+            verbose=verbose,
+            cache_path=cache_path,
+            precompute_num_workers=precompute_num_workers,
+            normalize_f0=normalize_f0,
+        )
+
+    @staticmethod
+    def _compute_and_save_pitch(config, wav_file, pitch_file=None):
+        wav, _ = load_audio(wav_file)
+        f0 = compute_f0(x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax)
+        # skip the last F0 value to align with the spectrogram
+        if wav.shape[1] % config.hop_length != 0:
+            f0 = f0[:-1]
+        if pitch_file:
+            np.save(pitch_file, f0)
+        return f0
+
+    def compute_or_load(self, wav_file):
+        """
+        compute pitch and return a numpy array of pitch values
+        """
+        pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
+        if not os.path.exists(pitch_file):
+            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
+        else:
+            pitch = np.load(pitch_file)
+        return pitch.astype(np.float32)
+
+
+class ForwardTTSE2eDataset(TTSDataset):
+    def __init__(self, *args, **kwargs):
+        # don't init the default F0Dataset in TTSDataset
+        compute_f0 = kwargs.pop("compute_f0", False)
+        kwargs["compute_f0"] = False
+
+        self.audio_config = kwargs["audio_config"]
+        del kwargs["audio_config"]
+
+        super().__init__(*args, **kwargs)
+
+        self.compute_f0 = compute_f0
+        self.pad_id = self.tokenizer.characters.pad_id
+        if self.compute_f0:
+            self.f0_dataset = ForwardTTSE2eF0Dataset(
+                audio_config=self.audio_config,
+                samples=self.samples,
+                cache_path=kwargs["f0_cache_path"],
+                precompute_num_workers=kwargs["precompute_num_workers"],
+            )
+
+    def __getitem__(self, idx):
+        item = self.samples[idx]
+        raw_text = item["text"]
+
+        wav, _ = load_audio(item["audio_file"])
+        wav_filename = os.path.basename(item["audio_file"])
+
+        token_ids = self.get_token_ids(idx, item["text"])
+
+        f0 = None
+        if self.compute_f0:
+            f0 = self.get_f0(idx)["f0"]
+
+        # after phonemization the text length may change
+        # this is a shameful 🤭 hack to prevent longer phonemes
+        # TODO: find a better fix
+        if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len:
+            self.rescue_item_idx += 1
+            return self.__getitem__(self.rescue_item_idx)
+
+        return {
+            "raw_text": raw_text,
+            "token_ids": token_ids,
+            "token_len": len(token_ids),
+            "wav": wav,
+            "pitch": f0,
+            "wav_file": wav_filename,
+            "speaker_name": item["speaker_name"],
+            "language_name": item["language"],
+        }
+
+    @property
+    def lengths(self):
+        lens = []
+        for item in self.samples:
+            _, wav_file, *_ = _parse_sample(item)
+            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
+            lens.append(audio_len)
+        return lens
+
+    def collate_fn(self, batch):
+        """
+        Return Shapes:
+            - tokens: :math:`[B, T]`
+            - token_lens :math:`[B]`
+            - token_rel_lens :math:`[B]`
+            - pitch :math:`[B, T]`
+            - waveform: :math:`[B, 1, T]`
+            - waveform_lens: :math:`[B]`
+            - waveform_rel_lens: :math:`[B]`
+            - speaker_names: :math:`[B]`
+            - language_names: :math:`[B]`
+            - audiofile_paths: :math:`[B]`
+            - raw_texts: :math:`[B]`
+        """
+        # convert list of dicts to dict of lists
+        B = len(batch)
+        batch = {k: [dic[k] for dic in batch] for k in batch[0]}
+
+        _, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True
+        )
+
+        max_text_len = max([len(x) for x in batch["token_ids"]])
+        token_lens = torch.LongTensor(batch["token_len"])
+        token_rel_lens = token_lens / token_lens.max()
+
+        wav_lens = [w.shape[1] for w in batch["wav"]]
+        wav_lens = torch.LongTensor(wav_lens)
+        wav_lens_max = torch.max(wav_lens)
+        wav_rel_lens = wav_lens / wav_lens_max
+
+        pitch_lens = [p.shape[0] for p in batch["pitch"]]
+        pitch_lens = torch.LongTensor(pitch_lens)
+        pitch_lens_max = torch.max(pitch_lens)
+
+        token_padded = torch.LongTensor(B, max_text_len)
+        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
+        pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max)
+
+        token_padded = token_padded.zero_() + self.pad_id
+        wav_padded = wav_padded.zero_() + self.pad_id
+        pitch_padded = pitch_padded.zero_() + self.pad_id
+
+        for i in range(len(ids_sorted_decreasing)):
+            token_ids = batch["token_ids"][i]
+            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)
+
+            wav = batch["wav"][i]
+            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)
+
+            pitch = batch["pitch"][i]
+            pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch)
+
+        return {
+            "text_input": token_padded,
+            "text_lengths": token_lens,
+            "text_rel_lens": token_rel_lens,
+            "pitch": pitch_padded,
+            "waveform": wav_padded,  # (B x T)
+            "waveform_lens": wav_lens,  # (B)
+            "waveform_rel_lens": wav_rel_lens,
+            "speaker_names": batch["speaker_name"],
+            "language_names": batch["language_name"],
+            "audio_files": batch["wav_file"],
+            "raw_text": batch["raw_text"],
+        }
+
+
+##############################
+# CONFIG DEFINITIONS
+##############################
 
 
 @dataclass
-class ForwardTTSE2EArgs(ForwardTTSArgs):
+class ForwardTTSE2eAudio(Coqpit):
+    sample_rate: int = 22050
+    hop_length: int = 256
+    win_length: int = 1024
+    fft_size: int = 1024
+    mel_fmin: float = 0.0
+    mel_fmax: float = 8000
+    num_mels: int = 80
+    pitch_fmax: float = 640.0
+
+
+@dataclass
+class ForwardTTSE2eArgs(ForwardTTSArgs):
     # vocoder_config: BaseGANVocoderConfig = None
     num_chars: int = 100
     encoder_out_channels: int = 80
-    spec_segment_size: int = 32
+    spec_segment_size: int = 80
     # duration predictor
     detach_duration_predictor: bool = True
+    duration_predictor_dropout_p: float = 0.1
+    # pitch predictor
+    pitch_predictor_dropout_p: float = 0.1
     # discriminator
     init_discriminator: bool = True
     use_spectral_norm_discriminator: bool = False
     # model parameters
     detach_vocoder_input: bool = False
-    hidden_channels: int = 192
+    hidden_channels: int = 256
     encoder_type: str = "fftransformer"
     encoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 768, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 2,
+            "num_layers": 4,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     decoder_type: str = "fftransformer"
     decoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 768, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 2,
+            "num_layers": 4,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     # generator
     resblock_type_decoder: str = "1"
@@ -61,35 +310,39 @@ class ForwardTTSE2EArgs(ForwardTTSArgs):
     d_vector_dim: int = 0
 
 
-class ForwardTTSE2E(BaseTTSE2E):
+##############################
+# MODEL DEFINITION
+##############################
+
+
+class ForwardTTSE2e(BaseTTSE2E):
     """
     Model training::
         text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg
         spec --------^
 
     Examples:
-        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2E, ForwardTTSE2EConfig
-        >>> config = ForwardTTSE2EConfig()
-        >>> model = ForwardTTSE2E(config)
+        >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig
+        >>> config = ForwardTTSE2eConfig()
+        >>> model = ForwardTTSE2e(config)
     """
 
     # pylint: disable=dangerous-default-value
     def __init__(
         self,
         config: Coqpit,
-        ap: "AudioProcessor" = None,
         tokenizer: "TTSTokenizer" = None,
         speaker_manager: SpeakerManager = None,
     ):
-        super().__init__(config, ap, tokenizer, speaker_manager)
+        super().__init__(config=config, tokenizer=tokenizer, speaker_manager=speaker_manager)
         self._set_model_args(config)
 
         self.init_multispeaker(config)
 
-        self.encoder_model = ForwardTTS(config=self.args, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager)
+        self.encoder_model = ForwardTTS(config=self.args, ap=None, tokenizer=tokenizer, speaker_manager=speaker_manager)
         # self.vocoder_model = GAN(config=self.args.vocoder_config, ap=ap)
         self.waveform_decoder = HifiganGenerator(
-            self.args.out_channels,
+            self.args.hidden_channels,
             1,
             self.args.resblock_type_decoder,
             self.args.resblock_dilation_sizes_decoder,
@@ -179,16 +432,16 @@ class ForwardTTSE2E(BaseTTSE2E):
         encoder_outputs = self.encoder_model(
             x=x, x_lengths=x_lengths, y_lengths=spec_lengths, y=spec, dr=dr, pitch=pitch, aux_input=aux_input
         )
-        spec_encoder_output = encoder_outputs["model_outputs"]
-        spec_encoder_output_slices, slice_ids = rand_segments(
-            x=spec_encoder_output.transpose(1, 2),
+        o_en_ex = encoder_outputs["o_en_ex"].transpose(1, 2)  # [B, C_en, T_max2] -> [B, T_max2, C_en]
+        o_en_ex_slices, slice_ids = rand_segments(
+            x=o_en_ex.transpose(1, 2),
             x_lengths=spec_lengths,
             segment_size=self.args.spec_segment_size,
             let_short_samples=True,
             pad_short=True,
         )
         vocoder_output = self.waveform_decoder(
-            x=spec_encoder_output_slices.detach() if self.args.detach_vocoder_input else spec_encoder_output_slices,
+            x=o_en_ex_slices.detach() if self.args.detach_vocoder_input else o_en_ex_slices,
             g=encoder_outputs["g"],
         )
         wav_seg = segment(
@@ -205,33 +458,35 @@ class ForwardTTSE2E(BaseTTSE2E):
         return model_outputs
 
     @torch.no_grad()
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
-        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input)
-        # vocoder_output = self.vocoder_model.model_g(x=encoder_outputs["model_outputs"].transpose(1, 2))
-        vocoder_output = self.waveform_decoder(
-            x=encoder_outputs["model_outputs"].transpose(1, 2), g=encoder_outputs["g"]
-        )
+    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=True)
+        o_en_ex = encoder_outputs["o_en_ex"]
+        vocoder_output = self.waveform_decoder(x=o_en_ex, g=encoder_outputs["g"])
         model_outputs = {**encoder_outputs}
-        model_outputs["encoder_outputs"] = encoder_outputs["model_outputs"]
         model_outputs["model_outputs"] = vocoder_output
         return model_outputs
 
+    @torch.no_grad()
+    def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+        encoder_outputs = self.encoder_model.inference(x=x, aux_input=aux_input, skip_decoder=False)
+        model_outputs = {**encoder_outputs}
+        return model_outputs
+
     @staticmethod
     def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False):
         """Initiate model from config
 
         Args:
-            config (ForwardTTSE2EConfig): Model config.
+            config (ForwardTTSE2eConfig): Model config.
             samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                 Defaults to None.
         """
-        from TTS.utils.audio import AudioProcessor
+        from TTS.utils.audio.processor import AudioProcessor
 
-        ap = AudioProcessor.init_from_config(config, verbose=verbose)
         tokenizer, new_config = TTSTokenizer.init_from_config(config)
         speaker_manager = SpeakerManager.init_from_config(config, samples)
         # language_manager = LanguageManager.init_from_config(config)
-        return ForwardTTSE2E(config=new_config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager)
+        return ForwardTTSE2e(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager)
 
     def load_checkpoint(
         self, config, checkpoint_path, eval=False
@@ -248,7 +503,7 @@ class ForwardTTSE2E(BaseTTSE2E):
             token_lenghts = batch["text_lengths"]
             spec = batch["mel_input"]
             spec_lens = batch["mel_lengths"]
-            waveform = batch["waveform"].transpose(1, 2)  # [B, T, C] -> [B, C, T]
+            waveform = batch["waveform"]  # [B, T, C] -> [B, C, T]
             pitch = batch["pitch"]
             d_vectors = batch["d_vectors"]
             speaker_ids = batch["speaker_ids"]
@@ -316,6 +571,8 @@ class ForwardTTSE2E(BaseTTSE2E):
                     pitch_output=self.model_outputs_cache["pitch_avg"] if self.args.use_pitch else None,
                     pitch_target=self.model_outputs_cache["pitch_avg_gt"] if self.args.use_pitch else None,
                     input_lens=batch["text_lengths"],
+                    waveform=self.model_outputs_cache["waveform_seg"],
+                    waveform_hat=self.model_outputs_cache["model_outputs"],
                     aligner_logprob=self.model_outputs_cache["aligner_logprob"],
                     aligner_hard=self.model_outputs_cache["aligner_mas"],
                     aligner_soft=self.model_outputs_cache["aligner_soft"],
@@ -340,29 +597,53 @@ class ForwardTTSE2E(BaseTTSE2E):
     def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
         return self.train_step(batch, criterion, optimizer_idx)
 
-    @staticmethod
-    def __copy_for_logging(outputs):
-        """Change keys and copy values for logging."""
-        encoder_outputs = outputs[1].copy()
-        encoder_outputs["model_outputs"] = encoder_outputs["encoder_outputs"]
-        vocoder_outputs = outputs.copy()
-        vocoder_outputs[1]["model_outputs"] = outputs[1]["model_outputs"]
-        return encoder_outputs, vocoder_outputs
+    def _log(self, batch, outputs, name_prefix="train"):
+        figures, audios = {}, {}
 
-    def _log(self, ap, batch, outputs, name_prefix="train"):
-        encoder_outputs, vocoder_outputs = self.__copy_for_logging(outputs)
-        y_hat = vocoder_outputs[1]["model_outputs"]
-        y = vocoder_outputs[1]["waveform_seg"]
         # encoder outputs
-        encoder_figures, encoder_audios = self.encoder_model.create_logs(
-            batch=batch, outputs=encoder_outputs, ap=self.ap
-        )
+        model_outputs = outputs[1]["encoder_outputs"]
+        alignments = outputs[1]["alignments"]
+        mel_input = batch["mel_input"]
+
+        pred_spec = model_outputs[0].data.cpu().numpy()
+        gt_spec = mel_input[0].data.cpu().numpy()
+        align_img = alignments[0].data.cpu().numpy()
+
+        figures = {
+            "prediction": plot_spectrogram(pred_spec, None, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, None, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
+        }
+
+        # plot pitch figures
+        if self.args.use_pitch:
+            pitch_avg = abs(outputs[1]["pitch_avg_gt"][0, 0].data.cpu().numpy())
+            pitch_avg_hat = abs(outputs[1]["pitch_avg"][0, 0].data.cpu().numpy())
+            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
+            pitch_figures = {
+                "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
+                "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
+            }
+            figures.update(pitch_figures)
+
+        # plot the attention mask computed from the predicted durations
+        if "attn_durations" in outputs[1]:
+            alignments_hat = outputs[1]["attn_durations"][0].data.cpu().numpy()
+            figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
+
+        # Sample audio
+        encoder_audio = mel_to_wav_numpy(mel=pred_spec.T, mel_basis=self.__mel_basis, **self.config.audio)
+        audios[f"{name_prefix}/encoder_audio"] = encoder_audio
+
         # vocoder outputs
-        vocoder_figures = plot_results(y_hat, y, ap, name_prefix)
+        y_hat = outputs[1]["model_outputs"]
+        y = outputs[1]["waveform_seg"]
+
+        vocoder_figures = plot_results(y_hat=y_hat, y=y, audio_config=self.config.audio, name_prefix=name_prefix)
+        figures.update(vocoder_figures)
+
         sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
-        audios = {f"{name_prefix}/real_audio": sample_voice}
-        audios[f"{name_prefix}/encoder_audio"] = encoder_audios["audio"]
-        figures = {**encoder_figures, **vocoder_figures}
+        audios[f"{name_prefix}/real_audio"] = sample_voice
         return figures, audios
 
     def train_log(
@@ -374,21 +655,20 @@ class ForwardTTSE2E(BaseTTSE2E):
         be projected onto Tensorboard.
 
         Args:
-            ap (AudioProcessor): audio processor used at training.
             batch (Dict): Model inputs used at the previous training step.
             outputs (Dict): Model outputs generated at the previous training step.
 
         Returns:
             Tuple[Dict, np.ndarray]: training plots and output waveform.
         """
-        figures, audios = self._log(ap=self.ap, batch=batch, outputs=outputs, name_prefix="vocoder/")
+        figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/")
         logger.train_figures(steps, figures)
-        logger.train_audios(steps, audios, self.ap.sample_rate)
+        logger.train_audios(steps, audios, self.config.audio.sample_rate)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        figures, audios = self._log(ap=self.ap, batch=batch, outputs=outputs, name_prefix="vocoder/")
+        figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/")
         logger.eval_figures(steps, figures)
-        logger.eval_audios(steps, audios, self.ap.sample_rate)
+        logger.eval_audios(steps, audios, self.config.audio.sample_rate)
 
     def get_aux_input_from_test_sentences(self, sentence_info):
         if hasattr(self.config, "model_args"):
@@ -438,6 +718,78 @@ class ForwardTTSE2E(BaseTTSE2E):
             "language_name": None,
         }
 
+    def synthesize(self, text: str, speaker_id, language_id, d_vector):
+        # TODO: add language_id
+        is_cuda = next(self.parameters()).is_cuda
+
+        # convert text to sequence of token IDs
+        text_inputs = np.asarray(
+            self.tokenizer.text_to_ids(text, language=language_id),
+            dtype=np.int32,
+        )
+        # pass tensors to backend
+        if speaker_id is not None:
+            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)
+
+        if d_vector is not None:
+            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)
+
+        # if language_id is not None:
+        #     language_id = id_to_torch(language_id, cuda=is_cuda)
+
+        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
+        text_inputs = text_inputs.unsqueeze(0)
+
+        # synthesize voice
+        outputs = self.inference(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})
+
+        # collect outputs
+        wav = outputs["model_outputs"][0].data.cpu().numpy()
+        alignments = outputs["alignments"]
+        return_dict = {
+            "wav": wav,
+            "alignments": alignments,
+            "text_inputs": text_inputs,
+            "outputs": outputs,
+        }
+        return return_dict
+
+    def synthesize_with_gl(self, text: str, speaker_id, language_id, d_vector):
+        # TODO: add language_id
+        is_cuda = next(self.parameters()).is_cuda
+
+        # convert text to sequence of token IDs
+        text_inputs = np.asarray(
+            self.tokenizer.text_to_ids(text, language=language_id),
+            dtype=np.int32,
+        )
+        # pass tensors to backend
+        if speaker_id is not None:
+            speaker_id = id_to_torch(speaker_id, cuda=is_cuda)
+
+        if d_vector is not None:
+            d_vector = embedding_to_torch(d_vector, cuda=is_cuda)
+
+        # if language_id is not None:
+        #     language_id = id_to_torch(language_id, cuda=is_cuda)
+
+        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda)
+        text_inputs = text_inputs.unsqueeze(0)
+
+        # synthesize voice
+        outputs = self.inference_spec_decoder(text_inputs, aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id})
+
+        # collect outputs
+        wav = mel_to_wav_numpy(mel=outputs["model_outputs"].cpu().numpy()[0].T, mel_basis=self.__mel_basis, **self.config.audio)
+        alignments = outputs["alignments"]
+        return_dict = {
+            "wav": wav[None, :],
+            "alignments": alignments,
+            "text_inputs": text_inputs,
+            "outputs": outputs,
+        }
+        return return_dict
+
     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
@@ -453,30 +805,147 @@ class ForwardTTSE2E(BaseTTSE2E):
         test_sentences = self.config.test_sentences
         for idx, s_info in enumerate(test_sentences):
             aux_inputs = self.get_aux_input_from_test_sentences(s_info)
-            wav, alignment, _, _ = synthesis(
-                self,
+            outputs = self.synthesize(
                 aux_inputs["text"],
-                self.config,
-                "cuda" in str(next(self.parameters()).device),
                 speaker_id=aux_inputs["speaker_id"],
                 d_vector=aux_inputs["d_vector"],
-                style_wav=aux_inputs["style_wav"],
                 language_id=aux_inputs["language_id"],
-                use_griffin_lim=True,
-                do_trim_silence=False,
-            ).values()
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
+            )
+            outputs_gl = self.synthesize_with_gl(
+                aux_inputs["text"],
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                language_id=aux_inputs["language_id"],
+            )
+            test_audios["{}-audio".format(idx)] = outputs["wav"].T
+            test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T
+            test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
         return {"figures": test_figures, "audios": test_audios}
 
     def test_log(
         self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
     ) -> None:
-        logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
+        logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate)
         logger.test_figures(steps, outputs["figures"])
 
+    def format_batch(self, batch: Dict) -> Dict:
+        """Compute speaker, langugage IDs and d_vector for the batch if necessary."""
+        speaker_ids = None
+        language_ids = None
+        d_vectors = None
+
+        # get numerical speaker ids from speaker names
+        if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
+            speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]
+
+        if speaker_ids is not None:
+            speaker_ids = torch.LongTensor(speaker_ids)
+            batch["speaker_ids"] = speaker_ids
+
+        # get d_vectors from audio file names
+        if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
+            d_vector_mapping = self.speaker_manager.d_vectors
+            d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
+            d_vectors = torch.FloatTensor(d_vectors)
+
+        # get language ids from language names
+        if (
+            self.language_manager is not None
+            and self.language_manager.language_id_mapping
+            and self.args.use_language_embedding
+        ):
+            language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
+
+        if language_ids is not None:
+            language_ids = torch.LongTensor(language_ids)
+
+        batch["language_ids"] = language_ids
+        batch["d_vectors"] = d_vectors
+        batch["speaker_ids"] = speaker_ids
+        return batch
+
+    def format_batch_on_device(self, batch):
+        """Compute spectrograms on the device."""
+        ac = self.config.audio
+
+        # compute spectrograms
+        batch["mel_input"] = wav_to_mel(
+            batch["waveform"],
+            hop_length=ac.hop_length,
+            win_length=ac.win_length,
+            n_fft=ac.fft_size,
+            num_mels=ac.num_mels,
+            sample_rate=ac.sample_rate,
+            fmin=ac.mel_fmin,
+            fmax=ac.mel_fmax,
+            center=False,
+        )
+
+        assert (
+            batch["pitch"].shape[2] == batch["mel_input"].shape[2]
+        ), f"{batch['pitch'].shape[2]}, {batch['mel'].shape[2]}"
+        batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int()
+
+        # zero the padding frames
+        batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1)
+        batch["mel_input"] = batch["mel_input"].transpose(1, 2)
+        return batch
+
+    def get_data_loader(
+        self,
+        config: Coqpit,
+        assets: Dict,
+        is_eval: bool,
+        samples: Union[List[Dict], List[List]],
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
+    ) -> "DataLoader":
+        if is_eval and not config.run_eval:
+            loader = None
+        else:
+            # init dataloader
+            dataset = ForwardTTSE2eDataset(
+                samples=samples,
+                audio_config=self.config.audio,
+                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
+                min_text_len=config.min_text_len,
+                max_text_len=config.max_text_len,
+                min_audio_len=config.min_audio_len,
+                max_audio_len=config.max_audio_len,
+                phoneme_cache_path=config.phoneme_cache_path,
+                precompute_num_workers=config.precompute_num_workers,
+                compute_f0=config.compute_f0,
+                f0_cache_path=config.f0_cache_path,
+                verbose=verbose,
+                tokenizer=self.tokenizer,
+                start_by_longest=config.start_by_longest,
+            )
+
+            # wait all the DDP process to be ready
+            if num_gpus > 1:
+                dist.barrier()
+
+            # sort input sequences from short to long
+            dataset.preprocess_samples()
+
+            # get samplers
+            sampler = self.get_sampler(config, dataset, num_gpus)
+
+            loader = DataLoader(
+                dataset,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                shuffle=False,  # shuffle is done in the dataset.
+                drop_last=False,  # setting this False might cause issues in AMP training.
+                sampler=sampler,
+                collate_fn=dataset.collate_fn,
+                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
+                pin_memory=False,
+            )
+        return loader
+
     def get_criterion(self):
-        return [VitsDiscriminatorLoss(self.config), ForwardTTSE2ELoss(self.config)]
+        return [VitsDiscriminatorLoss(self.config), ForwardTTSE2eLoss(self.config)]
 
     def get_optimizer(self) -> List:
         """Initiate and return the GAN optimizers based on the config parameters.
@@ -516,3 +985,12 @@ class ForwardTTSE2E(BaseTTSE2E):
         """Schedule binary loss weight."""
         self.encoder_model.config.binary_loss_warmup_epochs = self.config.binary_loss_warmup_epochs
         self.encoder_model.on_train_step_start(trainer)
+
+    def on_init_start(self, trainer: "Trainer"):
+        self.__mel_basis = build_mel_basis(
+            sample_rate=self.config.audio.sample_rate,
+            fft_size=self.config.audio.fft_size,
+            num_mels=self.config.audio.num_mels,
+            mel_fmax=self.config.audio.mel_fmax,
+            mel_fmin=self.config.audio.mel_fmin,
+        )