diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index e03cf084..fca00f0b 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -570,8 +570,8 @@ class VitsGeneratorLoss(nn.Module):
     @staticmethod
     def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
         """
-        z_p, logs_q: [b, h, t_t]
-        m_p, logs_p: [b, h, t_t]
+        z_p, logs_q: [B, C, T_de]
+        m_p, logs_p: [B, C, T_de]
         """
         z_p = z_p.float()
         logs_q = logs_q.float()
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 613e4eae..55c39c3b 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -735,6 +735,7 @@ class Vits(BaseTTS):
         """
         if self.args.encoder_sample_rate:
             self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
+            assert self.interpolate_factor.is_integer(), " [!] Upsampling factor must be an integer."
             self.audio_resampler = torchaudio.transforms.Resample(
                 orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
             )  # pylint: disable=W0201
@@ -803,18 +804,18 @@ class Vits(BaseTTS):
 
     def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
         # find the alignment path
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
         with torch.no_grad():
             o_scale = torch.exp(-2 * logs_p)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
-            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [B, 1, T_en]
+            logp2 = torch.einsum("kln, klm -> knm", [-0.5 * (z_p**2), o_scale])  # [B, T_de, T_en]
+            logp3 = torch.einsum("kln, klm -> knm", [z_p, m_p * o_scale])  # [B, T_de, T_en]
+            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1], keepdim=True)  # [B, 1, T_en]
             logp = logp2 + logp3 + logp1 + logp4
-            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [B, 1, T_de, T_en]
 
         # duration predictor
-        attn_durations = attn.sum(3)
+        attn_durations = attn.sum(2)
         if self.args.use_sdp:
             loss_duration = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
@@ -845,7 +846,7 @@ class Vits(BaseTTS):
             # interpolate z if needed
             if self.args.interpolate_z:
                 z = torch.nn.functional.interpolate(
-                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
+                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="linear"
                 ).squeeze(0)
                 # recompute the mask if needed
                 if y_lengths is not None and y_mask is not None:
@@ -890,7 +891,7 @@ class Vits(BaseTTS):
 
         Return Shapes:
             - model_outputs: :math:`[B, 1, T_wav]`
-            - alignments: :math:`[B, T_seq, T_dec]`
+            - alignments: :math:`[B, T_dec, T_seq]`
             - z: :math:`[B, C, T_dec]`
             - z_p: :math:`[B, C, T_dec]`
             - m_p: :math:`[B, C, T_dec]`
@@ -924,8 +925,8 @@ class Vits(BaseTTS):
         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
 
         # expand prior
-        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
-        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+        m_p = torch.einsum("klnm, kjm -> kjn", [attn, m_p])  # [B, 1, T_de, T_en] -> [B, C, T_de]
+        logs_p = torch.einsum("klnm, kjm -> kjn", [attn, logs_p]) # [B, 1, T_de, T_en] * [B, C, T_en] -> [B, C, T_de]
 
         # select a random feature segment for the waveform decoder
         z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)
@@ -982,6 +983,7 @@ class Vits(BaseTTS):
             return aux_input["x_lengths"]
         return torch.tensor(x.shape[1:2]).to(x.device)
 
+    @torch.no_grad()
     def inference(
         self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
     ):  # pylint: disable=dangerous-default-value
@@ -1053,6 +1055,7 @@ class Vits(BaseTTS):
         outputs = {
             "model_outputs": o,
             "alignments": attn.squeeze(1),
+            "durations": w_ceil,
             "z": z,
             "z_p": z_p,
             "m_p": m_p,
@@ -1231,7 +1234,7 @@ class Vits(BaseTTS):
         audios = {f"{name_prefix}/audio": sample_voice}
 
         alignments = outputs[1]["alignments"]
-        align_img = alignments[0].data.cpu().numpy().T
+        align_img = alignments[0].data.cpu().numpy()
 
         figures.update(
             {
@@ -1390,16 +1393,16 @@ class Vits(BaseTTS):
         """Compute spectrograms on the device."""
         ac = self.config.audio
 
-        if self.args.encoder_sample_rate:
-            wav = self.audio_resampler(batch["waveform"])
-        else:
-            wav = batch["waveform"]
-
         # compute spectrograms
-        batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
+        if self.args.encoder_sample_rate:
+            # downsample audio to encoder sample rate
+            wav = self.audio_resampler(batch["waveform"])
+            batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)
+        else:
+            batch["spec"] = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
 
         if self.args.encoder_sample_rate:
-            # recompute spec with high sampling rate to the loss
+            # recompute spec with vocoder sampling rate
             spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
             # remove extra stft frame
             spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py
index c2e7f561..3c8bc2c8 100644
--- a/TTS/tts/utils/helpers.py
+++ b/TTS/tts/utils/helpers.py
@@ -180,10 +180,10 @@ def maximum_path(value, mask):
 def maximum_path_cython(value, mask):
     """Cython optimised version.
     Shapes:
-        - value: :math:`[B, T_en, T_de]`
-        - mask: :math:`[B, T_en, T_de]`
+        - value: :math:`[B, T_de, T_en]`
+        - mask: :math:`[B, T_de, T_en]`
     """
-    value = value * mask
+    # value = value * mask
     device = value.device
     dtype = value.dtype
     value = value.data.cpu().numpy().astype(np.float32)
diff --git a/TTS/tts/utils/monotonic_align/core.pyx b/TTS/tts/utils/monotonic_align/core.pyx
index 091fcc3a..227fc1ed 100644
--- a/TTS/tts/utils/monotonic_align/core.pyx
+++ b/TTS/tts/utils/monotonic_align/core.pyx
@@ -1,14 +1,10 @@
-import numpy as np
-
-cimport cython
-cimport numpy as np
-
+import cython
 from cython.parallel import prange
 
 
 @cython.boundscheck(False)
 @cython.wraparound(False)
-cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
+cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
   cdef int x
   cdef int y
   cdef float v_prev
@@ -21,27 +17,26 @@ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_
       if x == y:
         v_cur = max_neg_val
       else:
-        v_cur = value[x, y-1]
+        v_cur = value[y-1, x]
       if x == 0:
         if y == 0:
           v_prev = 0.
         else:
           v_prev = max_neg_val
       else:
-        v_prev = value[x-1, y-1]
-      value[x, y] = max(v_cur, v_prev) + value[x, y]
+        v_prev = value[y-1, x-1]
+      value[y, x] += max(v_prev, v_cur)
 
   for y in range(t_y - 1, -1, -1):
-    path[index, y] = 1
-    if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
+    path[y, index] = 1
+    if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
       index = index - 1
 
 
 @cython.boundscheck(False)
 @cython.wraparound(False)
-cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
-  cdef int b = values.shape[0]
-
+cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
+  cdef int b = paths.shape[0]
   cdef int i
   for i in prange(b, nogil=True):
-    maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
+    maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])