diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
new file mode 100644
index 00000000..e64944fe
--- /dev/null
+++ b/TTS/tts/configs/vits_config.py
@@ -0,0 +1,60 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+from TTS.tts.models.vits import VitsArgs
+
+
+@dataclass
+class VitsConfig(BaseTTSConfig):
+    """Defines parameters for VITS End2End TTS model.
+
+    Example:
+
+        >>> from TTS.tts.configs import VitsConfig
+        >>> config = VitsConfig()
+    """
+
+    model: str = "vits"
+    # model specific params
+    model_args: VitsArgs = field(default_factory=VitsArgs)
+
+    # optimizer
+    grad_clip: float = field(default_factory=lambda: [5, 5])
+    lr_gen: float = 0.0002
+    lr_disc: float = 0.0002
+    lr_scheduler_gen: str = "ExponentialLR"
+    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
+    lr_scheduler_disc: str = "ExponentialLR"
+    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
+    scheduler_after_epoch: bool = True
+    optimizer: str = "AdamW"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
+
+    # loss params
+    kl_loss_alpha: float = 1.0
+    disc_loss_alpha: float = 1.0
+    gen_loss_alpha: float = 1.0
+    feat_loss_alpha: float = 1.0
+    mel_loss_alpha: float = 45.0
+
+    # data loader params
+    return_wav: bool = True
+    compute_linear_spec: bool = True
+
+    # overrides
+    min_seq_len: int = 13
+    max_seq_len: int = 200
+    r: int = 1  # DO NOT CHANGE
+    add_blank: bool = True
+
+    # testing
+    test_sentences: List[str] = field(
+        default_factory=lambda: [
+            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+            "Be a voice, not an echo.",
+            "I'm sorry Dave. I'm afraid I can't do that.",
+            "This cake is great. It's so delicious and moist.",
+            "Prior to November 22, 1963.",
+        ]
+    )
diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py
index aaa0ba50..89326c9c 100644
--- a/TTS/tts/datasets/TTSDataset.py
+++ b/TTS/tts/datasets/TTSDataset.py
@@ -191,6 +191,7 @@ class TTSDataset(Dataset):
         else:
             text, wav_file, speaker_name = item
             attn = None
+        raw_text = text
 
         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
 
@@ -236,6 +237,7 @@ class TTSDataset(Dataset):
             return self.load_data(self.rescue_item_idx)
 
         sample = {
+            "raw_text": raw_text,
             "text": text,
             "wav": wav,
             "attn": attn,
@@ -360,6 +362,7 @@ class TTSDataset(Dataset):
             wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
             item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
             text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
+            raw_text = [batch[idx]["raw_text"] for idx in ids_sorted_decreasing]
 
             speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
             # get pre-computed d-vectors
@@ -450,6 +453,7 @@ class TTSDataset(Dataset):
                 attns = torch.FloatTensor(attns).unsqueeze(1)
             else:
                 attns = None
+            # TODO: return dictionary
             return (
                 text,
                 text_lenghts,
@@ -463,6 +467,7 @@ class TTSDataset(Dataset):
                 speaker_ids,
                 attns,
                 wav_padded,
+                raw_text,
             )
 
         raise TypeError(
diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py
index fd607b75..4766c77d 100644
--- a/TTS/tts/layers/generic/normalization.py
+++ b/TTS/tts/layers/generic/normalization.py
@@ -28,6 +28,31 @@ class LayerNorm(nn.Module):
         return x
 
 
+class LayerNorm2(nn.Module):
+    """Layer norm for the 2nd dimension of the input using torch primitive.
+    Args:
+        channels (int): number of channels (2nd dimension) of the input.
+        eps (float): to prevent 0 division
+
+    Shapes:
+        - input: (B, C, T)
+        - output: (B, C, T)
+    """
+
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        x = x.transpose(1, -1)
+        x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+        return x.transpose(1, -1)
+
+
 class TemporalBatchNorm1d(nn.BatchNorm1d):
     """Normalize each channel separately over time and batch."""
 
diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py
index e35aeb68..2c0303be 100644
--- a/TTS/tts/layers/glow_tts/duration_predictor.py
+++ b/TTS/tts/layers/glow_tts/duration_predictor.py
@@ -18,7 +18,7 @@ class DurationPredictor(nn.Module):
         dropout_p (float): Dropout rate used after each conv layer.
     """
 
-    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
+    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
         super().__init__()
         # class arguments
         self.in_channels = in_channels
@@ -33,13 +33,18 @@ class DurationPredictor(nn.Module):
         self.norm_2 = LayerNorm(hidden_channels)
         # output layer
         self.proj = nn.Conv1d(hidden_channels, 1, 1)
+        if cond_channels is not None and cond_channels != 0:
+            self.cond = nn.Conv1d(cond_channels, in_channels, 1)
 
-    def forward(self, x, x_mask):
+    def forward(self, x, x_mask, g=None):
         """
         Shapes:
             - x: :math:`[B, C, T]`
             - x_mask: :math:`[B, 1, T]`
+            - g: :math:`[B, C, 1]`
         """
+        if g is not None:
+            x = x + self.cond(g)
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py
index 33036537..392447de 100644
--- a/TTS/tts/layers/glow_tts/glow.py
+++ b/TTS/tts/layers/glow_tts/glow.py
@@ -16,7 +16,7 @@ class ResidualConv1dLayerNormBlock(nn.Module):
     ::
 
         x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
-          |---------------> conv1d_1x1 -----------------------|
+          |---------------> conv1d_1x1 ------------------|
 
     Args:
         in_channels (int): number of input tensor channels.
diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py
index 92cace78..ba6aa1e2 100644
--- a/TTS/tts/layers/glow_tts/transformer.py
+++ b/TTS/tts/layers/glow_tts/transformer.py
@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from TTS.tts.layers.glow_tts.glow import LayerNorm
+from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
 
 
 class RelativePositionMultiHeadAttention(nn.Module):
@@ -271,7 +271,7 @@ class FeedForwardNetwork(nn.Module):
         dropout_p (float, optional): dropout rate. Defaults to 0.
     """
 
-    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):
+    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False):
 
         super().__init__()
         self.in_channels = in_channels
@@ -280,17 +280,46 @@ class FeedForwardNetwork(nn.Module):
         self.kernel_size = kernel_size
         self.dropout_p = dropout_p
 
-        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size)
         self.dropout = nn.Dropout(dropout_p)
 
     def forward(self, x, x_mask):
-        x = self.conv_1(x * x_mask)
+        x = self.conv_1(self.padding(x * x_mask))
         x = torch.relu(x)
         x = self.dropout(x)
-        x = self.conv_2(x * x_mask)
+        x = self.conv_2(self.padding(x * x_mask))
         return x * x_mask
 
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, self._pad_shape(padding))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, self._pad_shape(padding))
+        return x
+
+    @staticmethod
+    def _pad_shape(padding):
+        l = padding[::-1]
+        pad_shape = [item for sublist in l for item in sublist]
+        return pad_shape
+
 
 class RelativePositionTransformer(nn.Module):
     """Transformer with Relative Potional Encoding.
@@ -310,20 +339,23 @@ class RelativePositionTransformer(nn.Module):
             If default, relative encoding is disabled and it is a regular transformer.
             Defaults to None.
         input_length (int, optional): input lenght to limit position encoding. Defaults to None.
+        layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm
+            primitive. Use type "2", type "1: is for backward compat. Defaults to "1".
     """
 
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        hidden_channels_ffn,
-        num_heads,
-        num_layers,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        hidden_channels_ffn: int,
+        num_heads: int,
+        num_layers: int,
         kernel_size=1,
         dropout_p=0.0,
-        rel_attn_window_size=None,
-        input_length=None,
+        rel_attn_window_size: int = None,
+        input_length: int = None,
+        layer_norm_type: str = "1",
     ):
         super().__init__()
         self.hidden_channels = hidden_channels
@@ -351,7 +383,12 @@ class RelativePositionTransformer(nn.Module):
                     input_length=input_length,
                 )
             )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            if layer_norm_type == "1":
+                self.norm_layers_1.append(LayerNorm(hidden_channels))
+            elif layer_norm_type == "2":
+                self.norm_layers_1.append(LayerNorm2(hidden_channels))
+            else:
+                raise ValueError(" [!] Unknown layer norm type")
 
             if hidden_channels != out_channels and (idx + 1) == self.num_layers:
                 self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@@ -366,7 +403,12 @@ class RelativePositionTransformer(nn.Module):
                 )
             )
 
-            self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
+            if layer_norm_type == "1":
+                self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
+            elif layer_norm_type == "2":
+                self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels))
+            else:
+                raise ValueError(" [!] Unknown layer norm type")
 
     def forward(self, x, x_mask):
         """
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index 07b58974..171b0217 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -2,11 +2,13 @@ import math
 
 import numpy as np
 import torch
+from coqpit import Coqpit
 from torch import nn
 from torch.nn import functional
 
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.ssim import ssim
+from TTS.utils.audio import TorchSTFT
 
 
 # pylint: disable=abstract-method
@@ -514,3 +516,142 @@ class AlignTTSLoss(nn.Module):
             + self.mdn_alpha * mdn_loss
         )
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
+
+
+class VitsGeneratorLoss(nn.Module):
+    def __init__(self, c: Coqpit):
+        super().__init__()
+        self.kl_loss_alpha = c.kl_loss_alpha
+        self.gen_loss_alpha = c.gen_loss_alpha
+        self.feat_loss_alpha = c.feat_loss_alpha
+        self.mel_loss_alpha = c.mel_loss_alpha
+        self.stft = TorchSTFT(
+            c.audio.fft_size,
+            c.audio.hop_length,
+            c.audio.win_length,
+            sample_rate=c.audio.sample_rate,
+            mel_fmin=c.audio.mel_fmin,
+            mel_fmax=c.audio.mel_fmax,
+            n_mels=c.audio.num_mels,
+            use_mel=True,
+            do_amp_to_db=True,
+        )
+
+    @staticmethod
+    def feature_loss(feats_real, feats_generated):
+        loss = 0
+        for dr, dg in zip(feats_real, feats_generated):
+            for rl, gl in zip(dr, dg):
+                rl = rl.float().detach()
+                gl = gl.float()
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(scores_fake):
+        loss = 0
+        gen_losses = []
+        for dg in scores_fake:
+            dg = dg.float()
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l)
+            loss += l
+
+        return loss, gen_losses
+
+    @staticmethod
+    def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
+        """
+        z_p, logs_q: [b, h, t_t]
+        m_p, logs_p: [b, h, t_t]
+        """
+        z_p = z_p.float()
+        logs_q = logs_q.float()
+        m_p = m_p.float()
+        logs_p = logs_p.float()
+        z_mask = z_mask.float()
+
+        kl = logs_p - logs_q - 0.5
+        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+        kl = torch.sum(kl * z_mask)
+        l = kl / torch.sum(z_mask)
+        return l
+
+    def forward(
+        self,
+        waveform,
+        waveform_hat,
+        z_p,
+        logs_q,
+        m_p,
+        logs_p,
+        z_len,
+        scores_disc_fake,
+        feats_disc_fake,
+        feats_disc_real,
+    ):
+        """
+        Shapes:
+            - wavefrom: :math:`[B, 1, T]`
+            - waveform_hat: :math:`[B, 1, T]`
+            - z_p: :math:`[B, C, T]`
+            - logs_q: :math:`[B, C, T]`
+            - m_p: :math:`[B, C, T]`
+            - logs_p: :math:`[B, C, T]`
+            - z_len: :math:`[B]`
+            - scores_disc_fake[i]: :math:`[B, C]`
+            - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
+            - feats_disc_real[i][j]: :math:`[B, C, T', P]`
+        """
+        loss = 0.0
+        return_dict = {}
+        z_mask = sequence_mask(z_len).float()
+        # compute mel spectrograms from the waveforms
+        mel = self.stft(waveform)
+        mel_hat = self.stft(waveform_hat)
+        # compute losses
+        loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
+        loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
+        loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
+        loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
+        loss = loss_kl + loss_feat + loss_mel + loss_gen
+        # pass losses to the dict
+        return_dict["loss_gen"] = loss_gen
+        return_dict["loss_kl"] = loss_kl
+        return_dict["loss_feat"] = loss_feat
+        return_dict["loss_mel"] = loss_mel
+        return_dict["loss"] = loss
+        return return_dict
+
+
+class VitsDiscriminatorLoss(nn.Module):
+    def __init__(self, c: Coqpit):
+        super().__init__()
+        self.disc_loss_alpha = c.disc_loss_alpha
+
+    @staticmethod
+    def discriminator_loss(scores_real, scores_fake):
+        loss = 0
+        real_losses = []
+        fake_losses = []
+        for dr, dg in zip(scores_real, scores_fake):
+            dr = dr.float()
+            dg = dg.float()
+            real_loss = torch.mean((1 - dr) ** 2)
+            fake_loss = torch.mean(dg ** 2)
+            loss += real_loss + fake_loss
+            real_losses.append(real_loss.item())
+            fake_losses.append(fake_loss.item())
+
+        return loss, real_losses, fake_losses
+
+    def forward(self, scores_disc_real, scores_disc_fake):
+        loss = 0.0
+        return_dict = {}
+        loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
+        return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
+        loss = loss + loss_disc
+        return_dict["loss_disc"] = loss_disc
+        return_dict["loss"] = loss
+        return return_dict
diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py
new file mode 100644
index 00000000..650c9b61
--- /dev/null
+++ b/TTS/tts/layers/vits/discriminator.py
@@ -0,0 +1,77 @@
+import torch
+from torch import nn
+from torch.nn.modules.conv import Conv1d
+
+from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
+
+
+class DiscriminatorS(torch.nn.Module):
+    """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.
+
+    Args:
+        use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
+    """
+
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): input waveform.
+
+        Returns:
+            Tensor: discriminator scores.
+            List[Tensor]: list of features from the convolutiona layers.
+        """
+        feat = []
+        for l in self.convs:
+            x = l(x)
+            x = torch.nn.functional.leaky_relu(x, 0.1)
+            feat.append(x)
+        x = self.conv_post(x)
+        feat.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, feat
+
+
+class VitsDiscriminator(nn.Module):
+    """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.
+
+    ::
+        waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
+               |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
+
+    Args:
+        use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
+    """
+
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
+        self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): input waveform.
+
+        Returns:
+            List[Tensor]: discriminator scores.
+            List[List[Tensor]]: list of list of features from each layers of each discriminator.
+        """
+        scores, feats = self.mpd(x)
+        score_sd, feats_sd = self.sd(x)
+        return scores + [score_sd], feats + [feats_sd]
diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
new file mode 100644
index 00000000..cf9d6e41
--- /dev/null
+++ b/TTS/tts/layers/vits/networks.py
@@ -0,0 +1,271 @@
+import math
+
+import torch
+from torch import nn
+
+from TTS.tts.layers.glow_tts.glow import WN
+from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.utils.data import sequence_mask
+
+LRELU_SLOPE = 0.1
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        n_vocab: int,
+        out_channels: int,
+        hidden_channels: int,
+        hidden_channels_ffn: int,
+        num_heads: int,
+        num_layers: int,
+        kernel_size: int,
+        dropout_p: float,
+    ):
+        """Text Encoder for VITS model.
+
+        Args:
+            n_vocab (int): Number of characters for the embedding layer.
+            out_channels (int): Number of channels for the output.
+            hidden_channels (int): Number of channels for the hidden layers.
+            hidden_channels_ffn (int): Number of channels for the convolutional layers.
+            num_heads (int): Number of attention heads for the Transformer layers.
+            num_layers (int): Number of Transformer layers.
+            kernel_size (int): Kernel size for the FFN layers in Transformer network.
+            dropout_p (float): Dropout rate for the Transformer layers.
+        """
+        super().__init__()
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+
+        self.emb = nn.Embedding(n_vocab, hidden_channels)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
+
+        self.encoder = RelativePositionTransformer(
+            in_channels=hidden_channels,
+            out_channels=hidden_channels,
+            hidden_channels=hidden_channels,
+            hidden_channels_ffn=hidden_channels_ffn,
+            num_heads=num_heads,
+            num_layers=num_layers,
+            kernel_size=kernel_size,
+            dropout_p=dropout_p,
+            layer_norm_type="2",
+            rel_attn_window_size=4,
+        )
+
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths):
+        """
+        Shapes:
+            - x: :math:`[B, T]`
+            - x_length: :math:`[B]`
+        """
+        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.encoder(x * x_mask, x_mask)
+        stats = self.proj(x) * x_mask
+
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return x, m, logs, x_mask
+
+
+class ResidualCouplingBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        num_layers,
+        dropout_p=0,
+        cond_channels=0,
+        mean_only=False,
+    ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+        # input layer
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        # coupling layers
+        self.enc = WN(
+            hidden_channels,
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            num_layers,
+            dropout_p=dropout_p,
+            c_in_channels=cond_channels,
+        )
+        # output layer
+        # Initializing last layer to 0 makes the affine coupling layers
+        # do nothing at first.  This helps with training stability
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
+            - g: :math:`[B, C, 1]`
+        """
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            log_scale = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(log_scale) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(log_scale, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-log_scale) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
+class ResidualCouplingBlocks(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dilation_rate: int,
+        num_layers: int,
+        num_flows=4,
+        cond_channels=0,
+    ):
+        """Redisual Coupling blocks for VITS flow layers.
+
+        Args:
+            channels (int): Number of input and output tensor channels.
+            hidden_channels (int): Number of hidden network channels.
+            kernel_size (int): Kernel size of the WaveNet layers.
+            dilation_rate (int): Dilation rate of the WaveNet layers.
+            num_layers (int): Number of the WaveNet layers.
+            num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
+            cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
+        """
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.num_layers = num_layers
+        self.num_flows = num_flows
+        self.cond_channels = cond_channels
+
+        self.flows = nn.ModuleList()
+        for _ in range(num_flows):
+            self.flows.append(
+                ResidualCouplingBlock(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation_rate,
+                    num_layers,
+                    cond_channels=cond_channels,
+                    mean_only=True,
+                )
+            )
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
+            - g: :math:`[B, C, 1]`
+        """
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+                x = torch.flip(x, [1])
+        else:
+            for flow in reversed(self.flows):
+                x = torch.flip(x, [1])
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class PosteriorEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dilation_rate: int,
+        num_layers: int,
+        cond_channels=0,
+    ):
+        """Posterior Encoder of VITS model.
+
+        ::
+            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
+
+        Args:
+            in_channels (int): Number of input tensor channels.
+            out_channels (int): Number of output tensor channels.
+            hidden_channels (int): Number of hidden channels.
+            kernel_size (int): Kernel size of the WaveNet convolution layers.
+            dilation_rate (int): Dilation rate of the WaveNet layers.
+            num_layers (int): Number of the WaveNet layers.
+            cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
+        """
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.num_layers = num_layers
+        self.cond_channels = cond_channels
+
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.enc = WN(
+            hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths, g=None):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_lengths: :math:`[B, 1]`
+            - g: :math:`[B, C, 1]`
+        """
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x = self.pre(x) * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+        mean, log_scale = torch.split(stats, self.out_channels, dim=1)
+        z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
+        return z, mean, log_scale, x_mask
diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py
new file mode 100644
index 00000000..ae1edebb
--- /dev/null
+++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py
@@ -0,0 +1,276 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from TTS.tts.layers.generic.normalization import LayerNorm2
+from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform
+
+
+class DilatedDepthSeparableConv(nn.Module):
+    def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor:
+        """Dilated Depth-wise Separable Convolution module.
+
+        ::
+            x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o
+              |-------------------------------------------------------------------------------------^
+
+        Args:
+            channels ([type]): [description]
+            kernel_size ([type]): [description]
+            num_layers ([type]): [description]
+            dropout_p (float, optional): [description]. Defaults to 0.0.
+
+        Returns:
+            torch.tensor: Network output masked by the input sequence mask.
+        """
+        super().__init__()
+        self.num_layers = num_layers
+
+        self.convs_sep = nn.ModuleList()
+        self.convs_1x1 = nn.ModuleList()
+        self.norms_1 = nn.ModuleList()
+        self.norms_2 = nn.ModuleList()
+        for i in range(num_layers):
+            dilation = kernel_size ** i
+            padding = (kernel_size * dilation - dilation) // 2
+            self.convs_sep.append(
+                nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
+            )
+            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+            self.norms_1.append(LayerNorm2(channels))
+            self.norms_2.append(LayerNorm2(channels))
+        self.dropout = nn.Dropout(dropout_p)
+
+    def forward(self, x, x_mask, g=None):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
+        """
+        if g is not None:
+            x = x + g
+        for i in range(self.num_layers):
+            y = self.convs_sep[i](x * x_mask)
+            y = self.norms_1[i](y)
+            y = F.gelu(y)
+            y = self.convs_1x1[i](y)
+            y = self.norms_2[i](y)
+            y = F.gelu(y)
+            y = self.dropout(y)
+            x = x + y
+        return x * x_mask
+
+
+class ElementwiseAffine(nn.Module):
+    """Element-wise affine transform like no-population stats BatchNorm alternative.
+
+    Args:
+        channels (int): Number of input tensor channels.
+    """
+
+    def __init__(self, channels):
+        super().__init__()
+        self.translation = nn.Parameter(torch.zeros(channels, 1))
+        self.log_scale = nn.Parameter(torch.zeros(channels, 1))
+
+    def forward(self, x, x_mask, reverse=False, **kwargs):  # pylint: disable=unused-argument
+        if not reverse:
+            y = (x * torch.exp(self.log_scale) + self.translation) * x_mask
+            logdet = torch.sum(self.log_scale * x_mask, [1, 2])
+            return y, logdet
+        x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask
+        return x
+
+
+class ConvFlow(nn.Module):
+    """Dilated depth separable convolutional based spline flow.
+
+    Args:
+        in_channels (int): Number of input tensor channels.
+        hidden_channels (int): Number of in network channels.
+        kernel_size (int): Convolutional kernel size.
+        num_layers (int): Number of convolutional layers.
+        num_bins (int, optional): Number of spline bins. Defaults to 10.
+        tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        num_layers: int,
+        num_bins=10,
+        tail_bound=5.0,
+    ):
+        super().__init__()
+        self.num_bins = num_bins
+        self.tail_bound = tail_bound
+        self.hidden_channels = hidden_channels
+        self.half_channels = in_channels // 2
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0)
+        self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0)
+        h = self.convs(h, x_mask, g=g)
+        h = self.proj(h) * x_mask
+
+        b, c, t = x0.shape
+        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
+
+        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels)
+        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels)
+        unnormalized_derivatives = h[..., 2 * self.num_bins :]
+
+        x1, logabsdet = piecewise_rational_quadratic_transform(
+            x1,
+            unnormalized_widths,
+            unnormalized_heights,
+            unnormalized_derivatives,
+            inverse=reverse,
+            tails="linear",
+            tail_bound=self.tail_bound,
+        )
+
+        x = torch.cat([x0, x1], 1) * x_mask
+        logdet = torch.sum(logabsdet * x_mask, [1, 2])
+        if not reverse:
+            return x, logdet
+        return x
+
+
+class StochasticDurationPredictor(nn.Module):
+    """Stochastic duration predictor with Spline Flows.
+
+    It applies Variational Dequantization and Variationsl Data Augmentation.
+
+    Paper:
+        SDP: https://arxiv.org/pdf/2106.06103.pdf
+        Spline Flow: https://arxiv.org/abs/1906.04032
+
+    ::
+        ## Inference
+
+        x -> TextCondEncoder() -> Flow() -> dr_hat
+        noise ----------------------^
+
+        ## Training
+                                                                              |---------------------|
+        x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise
+        d -> DurCondEncoder()  -> ^                                                    |
+        |------------------------------------------------------------------------------|
+
+    Args:
+        in_channels (int): Number of input tensor channels.
+        hidden_channels (int): Number of hidden channels.
+        kernel_size (int): Kernel size of convolutional layers.
+        dropout_p (float): Dropout rate.
+        num_flows (int, optional): Number of flow blocks. Defaults to 4.
+        cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0.
+    """
+
+    def __init__(
+        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0
+    ):
+        super().__init__()
+
+        # condition encoder text
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
+        self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
+
+        # posterior encoder
+        self.flows = nn.ModuleList()
+        self.flows.append(ElementwiseAffine(2))
+        self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
+
+        # condition encoder duration
+        self.post_pre = nn.Conv1d(1, hidden_channels, 1)
+        self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p)
+        self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1)
+
+        # flow layers
+        self.post_flows = nn.ModuleList()
+        self.post_flows.append(ElementwiseAffine(2))
+        self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)]
+
+        if cond_channels != 0 and cond_channels is not None:
+            self.cond = nn.Conv1d(cond_channels, hidden_channels, 1)
+
+    def forward(self, x, x_mask, dr=None, g=None, reverse=False, noise_scale=1.0):
+        """
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_mask: :math:`[B, 1, T]`
+            - dr: :math:`[B, 1, T]`
+            - g: :math:`[B, C]`
+        """
+        # condition encoder text
+        x = self.pre(x)
+        if g is not None:
+            x = x + self.cond(g)
+        x = self.convs(x, x_mask)
+        x = self.proj(x) * x_mask
+
+        if not reverse:
+            flows = self.flows
+            assert dr is not None
+
+            # condition encoder duration
+            h = self.post_pre(dr)
+            h = self.post_convs(h, x_mask)
+            h = self.post_proj(h) * x_mask
+            noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
+            z_q = noise
+
+            # posterior encoder
+            logdet_tot_q = 0.0
+            for idx, flow in enumerate(self.post_flows):
+                z_q, logdet_q = flow(z_q, x_mask, g=(x + h))
+                logdet_tot_q = logdet_tot_q + logdet_q
+                if idx > 0:
+                    z_q = torch.flip(z_q, [1])
+
+            z_u, z_v = torch.split(z_q, [1, 1], 1)
+            u = torch.sigmoid(z_u) * x_mask
+            z0 = (dr - u) * x_mask
+
+            # posterior encoder - neg log likelihood
+            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
+            nll_posterior_encoder = (
+                torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
+            )
+
+            z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
+            logdet_tot = torch.sum(-z0, [1, 2])
+            z = torch.cat([z0, z_v], 1)
+
+            # flow layers
+            for idx, flow in enumerate(flows):
+                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
+                logdet_tot = logdet_tot + logdet
+                if idx > 0:
+                    z = torch.flip(z, [1])
+
+            # flow layers - neg log likelihood
+            nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
+            return nll_flow_layers + nll_posterior_encoder
+
+        flows = list(reversed(self.flows))
+        flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
+        z = torch.rand(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
+        for flow in flows:
+            z = torch.flip(z, [1])
+            z = flow(z, x_mask, g=x, reverse=reverse)
+
+        z0, _ = torch.split(z, [1, 1], 1)
+        logw = z0
+        return logw
diff --git a/TTS/tts/layers/vits/transforms.py b/TTS/tts/layers/vits/transforms.py
new file mode 100644
index 00000000..c1505554
--- /dev/null
+++ b/TTS/tts/layers/vits/transforms.py
@@ -0,0 +1,203 @@
+# adopted from https://github.com/bayesiains/nflows
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+def piecewise_rational_quadratic_transform(
+    inputs,
+    unnormalized_widths,
+    unnormalized_heights,
+    unnormalized_derivatives,
+    inverse=False,
+    tails=None,
+    tail_bound=1.0,
+    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+
+    if tails is None:
+        spline_fn = rational_quadratic_spline
+        spline_kwargs = {}
+    else:
+        spline_fn = unconstrained_rational_quadratic_spline
+        spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+    outputs, logabsdet = spline_fn(
+        inputs=inputs,
+        unnormalized_widths=unnormalized_widths,
+        unnormalized_heights=unnormalized_heights,
+        unnormalized_derivatives=unnormalized_derivatives,
+        inverse=inverse,
+        min_bin_width=min_bin_width,
+        min_bin_height=min_bin_height,
+        min_derivative=min_derivative,
+        **spline_kwargs,
+    )
+    return outputs, logabsdet
+
+
+def searchsorted(bin_locations, inputs, eps=1e-6):
+    bin_locations[..., -1] += eps
+    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
+
+
+def unconstrained_rational_quadratic_spline(
+    inputs,
+    unnormalized_widths,
+    unnormalized_heights,
+    unnormalized_derivatives,
+    inverse=False,
+    tails="linear",
+    tail_bound=1.0,
+    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+    outside_interval_mask = ~inside_interval_mask
+
+    outputs = torch.zeros_like(inputs)
+    logabsdet = torch.zeros_like(inputs)
+
+    if tails == "linear":
+        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+        constant = np.log(np.exp(1 - min_derivative) - 1)
+        unnormalized_derivatives[..., 0] = constant
+        unnormalized_derivatives[..., -1] = constant
+
+        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        logabsdet[outside_interval_mask] = 0
+    else:
+        raise RuntimeError("{} tails are not implemented.".format(tails))
+
+    outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
+        inputs=inputs[inside_interval_mask],
+        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+        inverse=inverse,
+        left=-tail_bound,
+        right=tail_bound,
+        bottom=-tail_bound,
+        top=tail_bound,
+        min_bin_width=min_bin_width,
+        min_bin_height=min_bin_height,
+        min_derivative=min_derivative,
+    )
+
+    return outputs, logabsdet
+
+
+def rational_quadratic_spline(
+    inputs,
+    unnormalized_widths,
+    unnormalized_heights,
+    unnormalized_derivatives,
+    inverse=False,
+    left=0.0,
+    right=1.0,
+    bottom=0.0,
+    top=1.0,
+    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+    min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+    if torch.min(inputs) < left or torch.max(inputs) > right:
+        raise ValueError("Input to a transform is not within its domain")
+
+    num_bins = unnormalized_widths.shape[-1]
+
+    if min_bin_width * num_bins > 1.0:
+        raise ValueError("Minimal bin width too large for the number of bins")
+    if min_bin_height * num_bins > 1.0:
+        raise ValueError("Minimal bin height too large for the number of bins")
+
+    widths = F.softmax(unnormalized_widths, dim=-1)
+    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+    cumwidths = torch.cumsum(widths, dim=-1)
+    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+    cumwidths = (right - left) * cumwidths + left
+    cumwidths[..., 0] = left
+    cumwidths[..., -1] = right
+    widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+    heights = F.softmax(unnormalized_heights, dim=-1)
+    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+    cumheights = torch.cumsum(heights, dim=-1)
+    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+    cumheights = (top - bottom) * cumheights + bottom
+    cumheights[..., 0] = bottom
+    cumheights[..., -1] = top
+    heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+    if inverse:
+        bin_idx = searchsorted(cumheights, inputs)[..., None]
+    else:
+        bin_idx = searchsorted(cumwidths, inputs)[..., None]
+
+    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+    delta = heights / widths
+    input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+    input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+    if inverse:
+        a = (inputs - input_cumheights) * (
+            input_derivatives + input_derivatives_plus_one - 2 * input_delta
+        ) + input_heights * (input_delta - input_derivatives)
+        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+            input_derivatives + input_derivatives_plus_one - 2 * input_delta
+        )
+        c = -input_delta * (inputs - input_cumheights)
+
+        discriminant = b.pow(2) - 4 * a * c
+        assert (discriminant >= 0).all()
+
+        root = (2 * c) / (-b - torch.sqrt(discriminant))
+        outputs = root * input_bin_widths + input_cumwidths
+
+        theta_one_minus_theta = root * (1 - root)
+        denominator = input_delta + (
+            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
+        )
+        derivative_numerator = input_delta.pow(2) * (
+            input_derivatives_plus_one * root.pow(2)
+            + 2 * input_delta * theta_one_minus_theta
+            + input_derivatives * (1 - root).pow(2)
+        )
+        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+        return outputs, -logabsdet
+    else:
+        theta = (inputs - input_cumwidths) / input_bin_widths
+        theta_one_minus_theta = theta * (1 - theta)
+
+        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
+        denominator = input_delta + (
+            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
+        )
+        outputs = input_cumheights + numerator / denominator
+
+        derivative_numerator = input_delta.pow(2) * (
+            input_derivatives_plus_one * theta.pow(2)
+            + 2 * input_delta * theta_one_minus_theta
+            + input_derivatives * (1 - theta).pow(2)
+        )
+        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+        return outputs, logabsdet
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index e441cc05..cd4c33d0 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -212,13 +212,22 @@ class BaseTTS(BaseModel):
                 else None,
             )
 
-            if (
-                config.use_phonemes
-                and config.compute_input_seq_cache
-                and not os.path.exists(dataset.phoneme_cache_path)
-            ):
-                # precompute phonemes to have a better estimate of sequence lengths.
-                dataset.compute_input_seq(config.num_loader_workers)
+            if config.use_phonemes and config.compute_input_seq_cache:
+                if hasattr(self, "eval_data_items") and is_eval:
+                    dataset.items = self.eval_data_items
+                elif hasattr(self, "train_data_items") and not is_eval:
+                    dataset.items = self.train_data_items
+                else:
+                    # precompute phonemes to have a better estimate of sequence lengths.
+                    dataset.compute_input_seq(config.num_loader_workers)
+
+                    # TODO: find a more efficient solution
+                    # cheap hack - store items in the model state to avoid recomputing when reinit the dataset
+                    if is_eval:
+                        self.eval_data_items = dataset.items
+                    else:
+                        self.train_data_items = dataset.items
+
             dataset.sort_items()
 
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
new file mode 100644
index 00000000..9a2eec89
--- /dev/null
+++ b/TTS/tts/models/vits.py
@@ -0,0 +1,758 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple
+
+import torch
+from coqpit import Coqpit
+from torch import nn
+from torch.cuda.amp.autocast_mode import autocast
+
+from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.layers.vits.discriminator import VitsDiscriminator
+from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
+from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
+
+# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.speakers import get_speaker_manager
+from TTS.tts.utils.synthesis import synthesis
+from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.trainer_utils import get_optimizer, get_scheduler
+from TTS.vocoder.models.hifigan_generator import HifiganGenerator
+from TTS.vocoder.utils.generic_utils import plot_results
+
+
+def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4):
+    """Segment each sample in a batch based on the provided segment indices"""
+    segments = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        index_start = segment_indices[i]
+        index_end = index_start + segment_size
+        segments[i] = x[i, :, index_start:index_end]
+    return segments
+
+
+def rand_segment(x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4):
+    """Create random segments based on the input lengths."""
+    B, _, T = x.size()
+    if x_lengths is None:
+        x_lengths = T
+    max_idxs = x_lengths - segment_size + 1
+    assert all(max_idxs > 0), " [!] At least one sample is shorter than the segment size."
+    ids_str = (torch.rand([B]).type_as(x) * max_idxs).long()
+    ret = segment(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+@dataclass
+class VitsArgs(Coqpit):
+    """VITS model arguments.
+
+    Args:
+
+        num_chars (int):
+            Number of characters in the vocabulary. Defaults to 100.
+
+        out_channels (int):
+            Number of output channels. Defaults to 513.
+
+        spec_segment_size (int):
+            Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`.
+
+        hidden_channels (int):
+            Number of hidden channels of the model. Defaults to 192.
+
+        hidden_channels_ffn_text_encoder (int):
+            Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256.
+
+        num_heads_text_encoder (int):
+            Number of attention heads of the text encoder transformer. Defaults to 2.
+
+        num_layers_text_encoder (int):
+            Number of transformer layers in the text encoder. Defaults to 6.
+
+        kernel_size_text_encoder (int):
+            Kernel size of the text encoder transformer FFN layers. Defaults to 3.
+
+        dropout_p_text_encoder (float):
+            Dropout rate of the text encoder. Defaults to 0.1.
+
+        kernel_size_posterior_encoder (int):
+            Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
+
+        dilatation_posterior_encoder (int):
+            Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
+
+        num_layers_posterior_encoder (int):
+            Number of posterior encoder's WaveNet layers. Defaults to 16.
+
+        kernel_size_flow (int):
+            Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
+
+        dilatation_flow (int):
+            Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
+
+        num_layers_flow (int):
+            Number of Residual Coupling WaveNet layers of the flow network. Defaults to 6.
+
+        resblock_type_decoder (str):
+            Type of the residual block in the decoder network. Defaults to "1".
+
+        resblock_kernel_sizes_decoder (List[int]):
+            Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
+
+        resblock_dilation_sizes_decoder (List[List[int]]):
+            Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
+
+        upsample_rates_decoder (List[int]):
+            Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these
+            values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
+
+        upsample_initial_channel_decoder (int):
+            Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
+
+        upsample_kernel_sizes_decoder (List[int]):
+            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
+
+        use_sdp (int):
+            Use Stochastic Duration Predictor. Defaults to True.
+
+        noise_scale (float):
+            Noise scale used for the sample noise tensor in training. Defaults to 1.0.
+
+        inference_noise_scale (float):
+            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
+
+        length_scale (int):
+            Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1.
+
+        noise_scale_dp (float):
+            Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
+
+        inference_noise_scale_dp (float):
+            Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8.
+
+        max_inference_len (int):
+            Maximum inference length to limit the memory use. Defaults to None.
+
+        init_discriminator (bool):
+            Initialize the disciminator network if set True. Set False for inference. Defaults to True.
+
+        use_spectral_norm_disriminator (bool):
+            Use spectral normalization over weight norm in the discriminator. Defaults to False.
+
+        use_speaker_embedding (bool):
+            Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
+
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
+
+        detach_dp_input (bool):
+            Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
+    """
+
+    num_chars: int = 100
+    out_channels: int = 513
+    spec_segment_size: int = 32
+    hidden_channels: int = 192
+    hidden_channels_ffn_text_encoder: int = 768
+    num_heads_text_encoder: int = 2
+    num_layers_text_encoder: int = 6
+    kernel_size_text_encoder: int = 3
+    dropout_p_text_encoder: int = 0.1
+    kernel_size_posterior_encoder: int = 5
+    dilation_rate_posterior_encoder: int = 1
+    num_layers_posterior_encoder: int = 16
+    kernel_size_flow: int = 5
+    dilation_rate_flow: int = 1
+    num_layers_flow: int = 4
+    resblock_type_decoder: int = "1"
+    resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
+    resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
+    upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
+    upsample_initial_channel_decoder: int = 512
+    upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
+    use_sdp: int = True
+    noise_scale: float = 1.0
+    inference_noise_scale: float = 0.667
+    length_scale: int = 1
+    noise_scale_dp: float = 1.0
+    inference_noise_scale_dp: float = 0.8
+    max_inference_len: int = None
+    init_discriminator: bool = True
+    use_spectral_norm_disriminator: bool = False
+    use_speaker_embedding: bool = False
+    num_speakers: int = 0
+    speakers_file: str = None
+    speaker_embedding_channels: int = 256
+    use_d_vector_file: bool = False
+    d_vector_dim: int = 0
+    detach_dp_input: bool = True
+
+
+class Vits(BaseTTS):
+    """VITS TTS model
+
+    Paper::
+        https://arxiv.org/pdf/2106.06103.pdf
+
+    Paper Abstract::
+        Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
+        sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
+        In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than
+        current two-stage models. Our method adopts variational inference augmented with normalizing flows and
+        an adversarial training process, which improves the expressive power of generative modeling. We also propose a
+        stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
+        uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
+        natural one-to-many relationship in which a text input can be spoken in multiple ways
+        with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
+        on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
+        available TTS systems and achieves a MOS comparable to ground truth.
+
+    Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.
+
+    Examples:
+        >>> from TTS.tts.configs import VitsConfig
+        >>> from TTS.tts.models.vits import Vits
+        >>> config = VitsConfig()
+        >>> model = Vits(config)
+    """
+
+    # pylint: disable=dangerous-default-value
+
+    def __init__(self, config: Coqpit):
+
+        super().__init__()
+
+        self.END2END = True
+
+        if config.__class__.__name__ == "VitsConfig":
+            # loading from VitsConfig
+            if "num_chars" not in config:
+                _, self.config, num_chars = self.get_characters(config)
+                config.model_args.num_chars = num_chars
+            else:
+                self.config = config
+                config.model_args.num_chars = config.num_chars
+            args = self.config.model_args
+        elif isinstance(config, VitsArgs):
+            # loading from VitsArgs
+            self.config = config
+            args = config
+        else:
+            raise ValueError("config must be either a VitsConfig or VitsArgs")
+
+        self.args = args
+
+        self.init_multispeaker(config)
+
+        self.length_scale = args.length_scale
+        self.noise_scale = args.noise_scale
+        self.inference_noise_scale = args.inference_noise_scale
+        self.inference_noise_scale_dp = args.inference_noise_scale_dp
+        self.noise_scale_dp = args.noise_scale_dp
+        self.max_inference_len = args.max_inference_len
+        self.spec_segment_size = args.spec_segment_size
+
+        self.text_encoder = TextEncoder(
+            args.num_chars,
+            args.hidden_channels,
+            args.hidden_channels,
+            args.hidden_channels_ffn_text_encoder,
+            args.num_heads_text_encoder,
+            args.num_layers_text_encoder,
+            args.kernel_size_text_encoder,
+            args.dropout_p_text_encoder,
+        )
+
+        self.posterior_encoder = PosteriorEncoder(
+            args.out_channels,
+            args.hidden_channels,
+            args.hidden_channels,
+            kernel_size=args.kernel_size_posterior_encoder,
+            dilation_rate=args.dilation_rate_posterior_encoder,
+            num_layers=args.num_layers_posterior_encoder,
+            cond_channels=self.embedded_speaker_dim,
+        )
+
+        self.flow = ResidualCouplingBlocks(
+            args.hidden_channels,
+            args.hidden_channels,
+            kernel_size=args.kernel_size_flow,
+            dilation_rate=args.dilation_rate_flow,
+            num_layers=args.num_layers_flow,
+            cond_channels=self.embedded_speaker_dim,
+        )
+
+        if args.use_sdp:
+            self.duration_predictor = StochasticDurationPredictor(
+                args.hidden_channels, 192, 3, 0.5, 4, cond_channels=self.embedded_speaker_dim
+            )
+        else:
+            self.duration_predictor = DurationPredictor(
+                args.hidden_channels, 256, 3, 0.5, cond_channels=self.embedded_speaker_dim
+            )
+
+        self.waveform_decoder = HifiganGenerator(
+            args.hidden_channels,
+            1,
+            args.resblock_type_decoder,
+            args.resblock_dilation_sizes_decoder,
+            args.resblock_kernel_sizes_decoder,
+            args.upsample_kernel_sizes_decoder,
+            args.upsample_initial_channel_decoder,
+            args.upsample_rates_decoder,
+            inference_padding=0,
+            cond_channels=self.embedded_speaker_dim,
+            conv_pre_weight_norm=False,
+            conv_post_weight_norm=False,
+            conv_post_bias=False,
+        )
+
+        if args.init_discriminator:
+            self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator)
+
+    def init_multispeaker(self, config: Coqpit, data: List = None):
+        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
+        or with external `d_vectors` computed from a speaker encoder model.
+
+        If you need a different behaviour, override this function for your model.
+
+        Args:
+            config (Coqpit): Model configuration.
+            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
+        """
+        if hasattr(config, "model_args"):
+            config = config.model_args
+        self.embedded_speaker_dim = 0
+        # init speaker manager
+        self.speaker_manager = get_speaker_manager(config, data=data)
+        if config.num_speakers > 0 and self.speaker_manager.num_speakers == 0:
+            self.speaker_manager.num_speakers = config.num_speakers
+        self.num_speakers = self.speaker_manager.num_speakers
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
+            self.embedded_speaker_dim = config.speaker_embedding_channels
+            self.emb_g = nn.Embedding(config.num_speakers, config.speaker_embedding_channels)
+        # init d-vector usage
+        if config.use_d_vector_file:
+            self.embedded_speaker_dim = config.d_vector_dim
+
+    @staticmethod
+    def _set_cond_input(aux_input: Dict):
+        """Set the speaker conditioning input based on the multi-speaker mode."""
+        sid, g = None, None
+        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
+            sid = aux_input["speaker_ids"]
+            if sid.ndim == 0:
+                sid = sid.unsqueeze_(0)
+        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
+            g = aux_input["d_vectors"]
+        return sid, g
+
+    def forward(
+        self,
+        x: torch.tensor,
+        x_lengths: torch.tensor,
+        y: torch.tensor,
+        y_lengths: torch.tensor,
+        aux_input={"d_vectors": None, "speaker_ids": None},
+    ) -> Dict:
+        """Forward pass of the model.
+
+        Args:
+            x (torch.tensor): Batch of input character sequence IDs.
+            x_lengths (torch.tensor): Batch of input character sequence lengths.
+            y (torch.tensor): Batch of input spectrograms.
+            y_lengths (torch.tensor): Batch of input spectrogram lengths.
+            aux_input (dict, optional): Auxiliary inputs for multi-speaker training. Defaults to {"d_vectors": None, "speaker_ids": None}.
+
+        Returns:
+            Dict: model outputs keyed by the output name.
+
+        Shapes:
+            - x: :math:`[B, T_seq]`
+            - x_lengths: :math:`[B]`
+            - y: :math:`[B, C, T_spec]`
+            - y_lengths: :math:`[B]`
+            - d_vectors: :math:`[B, C, 1]`
+            - speaker_ids: :math:`[B]`
+        """
+        outputs = {}
+        sid, g = self._set_cond_input(aux_input)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
+
+        # speaker embedding
+        if self.num_speakers > 1 and sid is not None:
+            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+
+        # posterior encoder
+        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
+
+        # flow layers
+        z_p = self.flow(z, y_mask, g=g)
+
+        # find the alignment path
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * logs_p)
+            # logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
+            # logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp2 + logp3
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+        # duration predictor
+        attn_durations = attn.sum(3)
+        if self.args.use_sdp:
+            nll_duration = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                attn_durations,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+            )
+            nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
+            outputs["nll_duration"] = nll_duration
+        else:
+            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
+            log_durations = self.duration_predictor(
+                x.detach() if self.args.detach_dp_input else x,
+                x_mask,
+                g=g.detach() if self.args.detach_dp_input and g is not None else g,
+            )
+            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
+            outputs["loss_duration"] = loss_duration
+
+        # expand prior
+        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
+        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+
+        # select a random feature segment for the waveform decoder
+        z_slice, slice_ids = rand_segment(z, y_lengths, self.spec_segment_size)
+        o = self.waveform_decoder(z_slice, g=g)
+        outputs.update(
+            {
+                "model_outputs": o,
+                "alignments": attn.squeeze(1),
+                "slice_ids": slice_ids,
+                "z": z,
+                "z_p": z_p,
+                "m_p": m_p,
+                "logs_p": logs_p,
+                "m_q": m_q,
+                "logs_q": logs_q,
+            }
+        )
+        return outputs
+
+    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
+        """
+        Shapes:
+            - x: :math:`[B, T_seq]`
+            - d_vectors: :math:`[B, C, 1]`
+            - speaker_ids: :math:`[B]`
+        """
+        sid, g = self._set_cond_input(aux_input)
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
+
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
+
+        if self.num_speakers > 0 and sid:
+            g = self.emb_g(sid).unsqueeze(-1)
+
+        if self.args.use_sdp:
+            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp)
+        else:
+            logw = self.duration_predictor(x, x_mask, g=g)
+
+        w = torch.exp(logw) * x_mask * self.length_scale
+        w_ceil = torch.ceil(w)
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
+        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
+
+        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
+        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
+
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
+
+        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
+        return outputs
+
+    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
+        """TODO: create an end-point for voice conversion"""
+        assert self.num_speakers > 0, "num_speakers have to be larger than 0."
+        g_src = self.emb_g(sid_src).unsqueeze(-1)
+        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
+        z, _, _, y_mask = self.enc_q(y, y_lengths, g=g_src)
+        z_p = self.flow(z, y_mask, g=g_src)
+        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
+        return o_hat, y_mask, (z, z_p, z_hat)
+
+    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
+        """Perform a single training step. Run the model forward pass and compute losses.
+
+        Args:
+            batch (Dict): Input tensors.
+            criterion (nn.Module): Loss layer designed for the model.
+            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.
+
+        Returns:
+            Tuple[Dict, Dict]: Model ouputs and computed losses.
+        """
+        # pylint: disable=attribute-defined-outside-init
+        if optimizer_idx not in [0, 1]:
+            raise ValueError(" [!] Unexpected `optimizer_idx`.")
+
+        if optimizer_idx == 0:
+            text_input = batch["text_input"]
+            text_lengths = batch["text_lengths"]
+            mel_lengths = batch["mel_lengths"]
+            linear_input = batch["linear_input"]
+            d_vectors = batch["d_vectors"]
+            speaker_ids = batch["speaker_ids"]
+            waveform = batch["waveform"]
+
+            # generator pass
+            outputs = self.forward(
+                text_input,
+                text_lengths,
+                linear_input.transpose(1, 2),
+                mel_lengths,
+                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
+            )
+
+            # cache tensors for the discriminator
+            self.y_disc_cache = None
+            self.wav_seg_disc_cache = None
+            self.y_disc_cache = outputs["model_outputs"]
+            wav_seg = segment(
+                waveform.transpose(1, 2),
+                outputs["slice_ids"] * self.config.audio.hop_length,
+                self.args.spec_segment_size * self.config.audio.hop_length,
+            )
+            self.wav_seg_disc_cache = wav_seg
+            outputs["waveform_seg"] = wav_seg
+
+            # compute discriminator scores and features
+            outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
+            _, outputs["feats_disc_real"] = self.disc(wav_seg)
+
+            # compute losses
+            with autocast(enabled=False):  # use float32 for the criterion
+                loss_dict = criterion[optimizer_idx](
+                    waveform_hat=outputs["model_outputs"].float(),
+                    waveform=wav_seg.float(),
+                    z_p=outputs["z_p"].float(),
+                    logs_q=outputs["logs_q"].float(),
+                    m_p=outputs["m_p"].float(),
+                    logs_p=outputs["logs_p"].float(),
+                    z_len=mel_lengths,
+                    scores_disc_fake=outputs["scores_disc_fake"],
+                    feats_disc_fake=outputs["feats_disc_fake"],
+                    feats_disc_real=outputs["feats_disc_real"],
+                )
+
+            # handle the duration loss
+            if self.args.use_sdp:
+                loss_dict["nll_duration"] = outputs["nll_duration"]
+                loss_dict["loss"] += outputs["nll_duration"]
+            else:
+                loss_dict["loss_duration"] = outputs["loss_duration"]
+                loss_dict["loss"] += outputs["nll_duration"]
+
+        elif optimizer_idx == 1:
+            # discriminator pass
+            outputs = {}
+
+            # compute scores and features
+            outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
+            outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
+
+            # compute loss
+            with autocast(enabled=False):  # use float32 for the criterion
+                loss_dict = criterion[optimizer_idx](
+                    outputs["scores_disc_real"],
+                    outputs["scores_disc_fake"],
+                )
+        return outputs, loss_dict
+
+    def train_log(
+        self, ap: AudioProcessor, batch: Dict, outputs: List, name_prefix="train"
+    ):  # pylint: disable=no-self-use
+        """Create visualizations and waveform examples.
+
+        For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to
+        be projected onto Tensorboard.
+
+        Args:
+            ap (AudioProcessor): audio processor used at training.
+            batch (Dict): Model inputs used at the previous training step.
+            outputs (Dict): Model outputs generated at the previoud training step.
+
+        Returns:
+            Tuple[Dict, np.ndarray]: training plots and output waveform.
+        """
+        y_hat = outputs[0]["model_outputs"]
+        y = outputs[0]["waveform_seg"]
+        figures = plot_results(y_hat, y, ap, name_prefix)
+        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
+        audios = {f"{name_prefix}/audio": sample_voice}
+
+        alignments = outputs[0]["alignments"]
+        align_img = alignments[0].data.cpu().numpy().T
+
+        figures.update(
+            {
+                "alignment": plot_alignment(align_img, output_fig=False),
+            }
+        )
+
+        return figures, audios
+
+    @torch.no_grad()
+    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
+        return self.train_step(batch, criterion, optimizer_idx)
+
+    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        return self.train_log(ap, batch, outputs, "eval")
+
+    @torch.no_grad()
+    def test_run(self, ap) -> Tuple[Dict, Dict]:
+        """Generic test run for `tts` models used by `Trainer`.
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self.get_aux_input()
+        for idx, sen in enumerate(test_sentences):
+            wav, alignment, _, _ = synthesis(
+                self,
+                sen,
+                self.config,
+                "cuda" in str(next(self.parameters()).device),
+                ap,
+                speaker_id=aux_inputs["speaker_id"],
+                d_vector=aux_inputs["d_vector"],
+                style_wav=aux_inputs["style_wav"],
+                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                use_griffin_lim=True,
+                do_trim_silence=False,
+            ).values()
+
+            test_audios["{}-audio".format(idx)] = wav
+            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+        return test_figures, test_audios
+
+    def get_optimizer(self) -> List:
+        """Initiate and return the GAN optimizers based on the config parameters.
+
+        It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
+
+        Returns:
+            List: optimizers.
+        """
+        self.disc.requires_grad_(False)
+        gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
+        self.disc.requires_grad_(True)
+        optimizer1 = get_optimizer(
+            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
+        )
+        optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
+        return [optimizer1, optimizer2]
+
+    def get_lr(self) -> List:
+        """Set the initial learning rates for each optimizer.
+
+        Returns:
+            List: learning rates for each optimizer.
+        """
+        return [self.config.lr_gen, self.config.lr_disc]
+
+    def get_scheduler(self, optimizer) -> List:
+        """Set the schedulers for each optimizer.
+
+        Args:
+            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
+
+        Returns:
+            List: Schedulers, one for each optimizer.
+        """
+        scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
+        scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
+        return [scheduler1, scheduler2]
+
+    def get_criterion(self):
+        """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
+        `train_step()`"""
+        from TTS.tts.layers.losses import (  # pylint: disable=import-outside-toplevel
+            VitsDiscriminatorLoss,
+            VitsGeneratorLoss,
+        )
+
+        return [VitsGeneratorLoss(self.config), VitsDiscriminatorLoss(self.config)]
+
+    @staticmethod
+    def make_symbols(config):
+        """Create a custom arrangement of symbols used by the model. The output list of symbols propagate along the
+        whole training and inference steps."""
+        _pad = config.characters["pad"]
+        _punctuations = config.characters["punctuations"]
+        _letters = config.characters["characters"]
+        _letters_ipa = config.characters["phonemes"]
+        symbols = [_pad] + list(_punctuations) + list(_letters)
+        if config.use_phonemes:
+            symbols += list(_letters_ipa)
+        return symbols
+
+    @staticmethod
+    def get_characters(config: Coqpit):
+        if config.characters is not None:
+            symbols = Vits.make_symbols(config)
+        else:
+            from TTS.tts.utils.text.symbols import (  # pylint: disable=import-outside-toplevel
+                parse_symbols,
+                phonemes,
+                symbols,
+            )
+
+            config.characters = parse_symbols()
+            if config.use_phonemes:
+                symbols = phonemes
+        num_chars = len(symbols) + getattr(config, "add_blank", False)
+        return symbols, config, num_chars
+
+    def load_checkpoint(
+        self, config, checkpoint_path, eval=False
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        """Load the model checkpoint and setup for training or inference"""
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            assert not self.training
diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py
index 48f69374..d4345b64 100644
--- a/TTS/tts/utils/text/__init__.py
+++ b/TTS/tts/utils/text/__init__.py
@@ -81,7 +81,6 @@ def text2phone(text, language, use_espeak_phonemes=False):
         # Fix a few phonemes
         ph = ph.translate(GRUUT_TRANS_TABLE)
 
-        # print(" > Phonemes: {}".format(ph))
         return ph
 
     raise ValueError(f" [!] Language {language} is not supported for phonemization.")
@@ -116,6 +115,7 @@ def phoneme_to_sequence(
     use_espeak_phonemes: bool = False,
 ) -> List[int]:
     """Converts a string of phonemes to a sequence of IDs.
+    If `custom_symbols` is provided, it will override the default symbols.
 
     Args:
       text (str): string to convert to a sequence
@@ -132,12 +132,11 @@ def phoneme_to_sequence(
     # pylint: disable=global-statement
     global _phonemes_to_id, _phonemes
 
-    if tp:
-        _, _phonemes = make_symbols(**tp)
-        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
-    elif custom_symbols is not None:
+    if custom_symbols is not None:
         _phonemes = custom_symbols
-        _phonemes_to_id = {s: i for i, s in enumerate(custom_symbols)}
+    elif tp:
+        _, _phonemes = make_symbols(**tp)
+    _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
 
     sequence = []
     clean_text = _clean_text(text, cleaner_names)
@@ -155,16 +154,19 @@ def phoneme_to_sequence(
     return sequence
 
 
-def sequence_to_phoneme(sequence, tp=None, add_blank=False):
+def sequence_to_phoneme(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List["str"] = None):
     # pylint: disable=global-statement
     """Converts a sequence of IDs back to a string"""
     global _id_to_phonemes, _phonemes
     if add_blank:
         sequence = list(filter(lambda x: x != len(_phonemes), sequence))
     result = ""
-    if tp:
+
+    if custom_symbols is not None:
+        _phonemes = custom_symbols
+    elif tp:
         _, _phonemes = make_symbols(**tp)
-        _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
+    _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
 
     for symbol_id in sequence:
         if symbol_id in _id_to_phonemes:
@@ -177,6 +179,7 @@ def text_to_sequence(
     text: str, cleaner_names: List[str], custom_symbols: List[str] = None, tp: Dict = None, add_blank: bool = False
 ) -> List[int]:
     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    If `custom_symbols` is provided, it will override the default symbols.
 
     Args:
       text (str): string to convert to a sequence
@@ -189,12 +192,12 @@ def text_to_sequence(
     """
     # pylint: disable=global-statement
     global _symbol_to_id, _symbols
-    if tp:
-        _symbols, _ = make_symbols(**tp)
-        _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
-    elif custom_symbols is not None:
+
+    if custom_symbols is not None:
         _symbols = custom_symbols
-        _symbol_to_id = {s: i for i, s in enumerate(custom_symbols)}
+    elif tp:
+        _symbols, _ = make_symbols(**tp)
+    _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
 
     sequence = []
 
@@ -213,16 +216,18 @@ def text_to_sequence(
     return sequence
 
 
-def sequence_to_text(sequence, tp=None, add_blank=False):
+def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_symbols: List[str] = None):
     """Converts a sequence of IDs back to a string"""
     # pylint: disable=global-statement
     global _id_to_symbol, _symbols
     if add_blank:
         sequence = list(filter(lambda x: x != len(_symbols), sequence))
 
-    if tp:
+    if custom_symbols is not None:
+        _symbols = custom_symbols
+    elif tp:
         _symbols, _ = make_symbols(**tp)
-        _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
+    _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
 
     result = ""
     for symbol_id in sequence:
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 1f21369f..40d82365 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -96,10 +96,12 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         )
         self.mel_basis = torch.from_numpy(mel_basis).float()
 
-    def _amp_to_db(self, x, spec_gain=1.0):
+    @staticmethod
+    def _amp_to_db(x, spec_gain=1.0):
         return torch.log(torch.clamp(x, min=1e-5) * spec_gain)
 
-    def _db_to_amp(self, x, spec_gain=1.0):
+    @staticmethod
+    def _db_to_amp(x, spec_gain=1.0):
         return torch.exp(x) / spec_gain
 
 
diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py
index a20c17b4..ca5eaf40 100644
--- a/TTS/vocoder/models/hifigan_discriminator.py
+++ b/TTS/vocoder/models/hifigan_discriminator.py
@@ -33,10 +33,10 @@ class DiscriminatorP(torch.nn.Module):
         norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+                norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
                 norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
             ]
         )
@@ -81,15 +81,15 @@ class MultiPeriodDiscriminator(torch.nn.Module):
     Periods are suggested to be prime numbers to reduce the overlap between each discriminator.
     """
 
-    def __init__(self):
+    def __init__(self, use_spectral_norm=False):
         super().__init__()
         self.discriminators = nn.ModuleList(
             [
-                DiscriminatorP(2),
-                DiscriminatorP(3),
-                DiscriminatorP(5),
-                DiscriminatorP(7),
-                DiscriminatorP(11),
+                DiscriminatorP(2, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(3, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(5, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(7, use_spectral_norm=use_spectral_norm),
+                DiscriminatorP(11, use_spectral_norm=use_spectral_norm),
             ]
         )
 
@@ -99,7 +99,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
             x (Tensor): input waveform.
 
         Returns:
-            [List[Tensor]]: list of scores from each discriminator.
+        [List[Tensor]]: list of scores from each discriminator.
             [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.
 
         Shapes:
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index 2260b781..a1e16150 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -170,6 +170,10 @@ class HifiganGenerator(torch.nn.Module):
         upsample_initial_channel,
         upsample_factors,
         inference_padding=5,
+        cond_channels=0,
+        conv_pre_weight_norm=True,
+        conv_post_weight_norm=True,
+        conv_post_bias=True,
     ):
         r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
 
@@ -218,12 +222,21 @@ class HifiganGenerator(torch.nn.Module):
             for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))
         # post convolution layer
-        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
+        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
+        if cond_channels > 0:
+            self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
 
-    def forward(self, x):
+        if not conv_pre_weight_norm:
+            remove_weight_norm(self.conv_pre)
+
+        if not conv_post_weight_norm:
+            remove_weight_norm(self.conv_post)
+
+    def forward(self, x, g=None):
         """
         Args:
-            x (Tensor): conditioning input tensor.
+            x (Tensor): feature input tensor.
+            g (Tensor): global conditioning input tensor.
 
         Returns:
             Tensor: output waveform.
@@ -233,6 +246,8 @@ class HifiganGenerator(torch.nn.Module):
             Tensor: [B, 1, T]
         """
         o = self.conv_pre(x)
+        if hasattr(self, "cond_layer"):
+            o = o + self.cond_layer(g)
         for i in range(self.num_upsamples):
             o = F.leaky_relu(o, LRELU_SLOPE)
             o = self.ups[i](o)
diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py
index 0a6bd4c8..8a66c537 100644
--- a/TTS/vocoder/models/univnet_generator.py
+++ b/TTS/vocoder/models/univnet_generator.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -10,18 +12,35 @@ LRELU_SLOPE = 0.1
 class UnivnetGenerator(torch.nn.Module):
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        hidden_channels,
-        cond_channels,
-        upsample_factors,
-        lvc_layers_each_block,
-        lvc_kernel_size,
-        kpnet_hidden_channels,
-        kpnet_conv_size,
-        dropout,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        cond_channels: int,
+        upsample_factors: List[int],
+        lvc_layers_each_block: int,
+        lvc_kernel_size: int,
+        kpnet_hidden_channels: int,
+        kpnet_conv_size: int,
+        dropout: float,
         use_weight_norm=True,
     ):
+        """Univnet Generator network.
+
+        Paper: https://arxiv.org/pdf/2106.07889.pdf
+
+        Args:
+            in_channels (int): Number of input tensor channels.
+            out_channels (int): Number of channels of the output tensor.
+            hidden_channels (int): Number of hidden network channels.
+            cond_channels (int): Number of channels of the conditioning tensors.
+            upsample_factors (List[int]): List of uplsample factors for the upsampling layers.
+            lvc_layers_each_block (int): Number of LVC layers in each block.
+            lvc_kernel_size (int): Kernel size of the LVC layers.
+            kpnet_hidden_channels (int): Number of hidden channels in the key-point network.
+            kpnet_conv_size (int): Number of convolution channels in the key-point network.
+            dropout (float): Dropout rate.
+            use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True.
+        """
 
         super().__init__()
         self.in_channels = in_channels
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index eeabbea5..63a0af44 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -1,8 +1,11 @@
+from typing import Dict
+
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
 
 from TTS.tts.utils.visual import plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 
 
 def interpolate_vocoder_input(scale_factor, spec):
@@ -26,12 +29,24 @@ def interpolate_vocoder_input(scale_factor, spec):
     return spec
 
 
-def plot_results(y_hat, y, ap, name_prefix):
-    """Plot vocoder model results"""
+def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
+    """Plot the predicted and the real waveform and their spectrograms.
+
+    Args:
+        y_hat (torch.tensor): Predicted waveform.
+        y (torch.tensor): Real waveform.
+        ap (AudioProcessor): Audio processor used to process the waveform.
+        name_prefix (str, optional): Name prefix used to name the figures. Defaults to None.
+
+    Returns:
+        Dict: output figures keyed by the name of the figures.
+    """ """Plot vocoder model results"""
+    if name_prefix is None:
+        name_prefix = ""
 
     # select an instance from batch
-    y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
-    y = y[0].squeeze(0).detach().cpu().numpy()
+    y_hat = y_hat[0].squeeze().detach().cpu().numpy()
+    y = y[0].squeeze().detach().cpu().numpy()
 
     spec_fake = ap.melspectrogram(y_hat).T
     spec_real = ap.melspectrogram(y).T
diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py
new file mode 100644
index 00000000..db9d2fc1
--- /dev/null
+++ b/tests/tts_tests/test_vits_train.py
@@ -0,0 +1,54 @@
+import glob
+import os
+import shutil
+
+from tests import get_device_id, get_tests_output_path, run_cli
+from TTS.tts.configs import VitsConfig
+
+config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
+output_path = os.path.join(get_tests_output_path(), "train_outputs")
+
+
+config = VitsConfig(
+    batch_size=2,
+    eval_batch_size=2,
+    num_loader_workers=0,
+    num_eval_loader_workers=0,
+    text_cleaner="english_cleaners",
+    use_phonemes=True,
+    use_espeak_phonemes=True,
+    phoneme_language="en-us",
+    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
+    run_eval=True,
+    test_delay_epochs=-1,
+    epochs=1,
+    print_step=1,
+    print_eval=True,
+    test_sentences=[
+        "Be a voice, not an echo.",
+    ],
+)
+config.audio.do_trim_silence = True
+config.audio.trim_db = 60
+config.save_json(config_path)
+
+# train the model for one epoch
+command_train = (
+    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
+    f"--coqpit.output_path {output_path} "
+    "--coqpit.datasets.0.name ljspeech "
+    "--coqpit.datasets.0.meta_file_train metadata.csv "
+    "--coqpit.datasets.0.meta_file_val metadata.csv "
+    "--coqpit.datasets.0.path tests/data/ljspeech "
+    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
+    "--coqpit.test_delay_epochs 0"
+)
+run_cli(command_train)
+
+# Find latest folder
+continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
+
+# restore the model and continue training for one more epoch
+command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
+run_cli(command_train)
+shutil.rmtree(continue_path)