diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index a1273f7f..46e1812d 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -213,18 +213,20 @@ class ForwardTTS(BaseTTS):
         )
 
         self.duration_predictor = DurationPredictor(
-            self.args.hidden_channels + self.embedded_speaker_dim,
+            self.args.hidden_channels,
             self.args.duration_predictor_hidden_channels,
             self.args.duration_predictor_kernel_size,
             self.args.duration_predictor_dropout_p,
+            cond_channels=self.embedded_speaker_dim,
         )
 
         if self.args.use_pitch:
             self.pitch_predictor = DurationPredictor(
-                self.args.hidden_channels + self.embedded_speaker_dim,
+                self.args.hidden_channels,
                 self.args.pitch_predictor_hidden_channels,
                 self.args.pitch_predictor_kernel_size,
                 self.args.pitch_predictor_dropout_p,
+                cond_channels=self.embedded_speaker_dim,
             )
             self.pitch_emb = nn.Conv1d(
                 1,
@@ -245,24 +247,54 @@ class ForwardTTS(BaseTTS):
             config (Coqpit): Model configuration.
         """
         self.embedded_speaker_dim = 0
-        # init speaker manager
-        if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
-            raise ValueError(
-                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
-            )
-        # set number of speakers
-        if self.speaker_manager is not None:
+        self.num_speakers = self.args.num_speakers
+        self.audio_transform = None
+
+        if self.speaker_manager:
             self.num_speakers = self.speaker_manager.num_speakers
-            # init d-vector embedding
-            if config.use_d_vector_file:
-                self.embedded_speaker_dim = config.d_vector_dim
-                if self.args.d_vector_dim != self.args.hidden_channels:
-                    self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
-            # init speaker embedding layer
-            if config.use_speaker_embedding and not config.use_d_vector_file:
-                print(" > Init speaker_embedding layer.")
-                self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
-                nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
+        if self.args.use_speaker_embedding:
+            self._init_speaker_embedding()
+
+        if self.args.use_d_vector_file:
+            self._init_d_vector()
+
+    def _init_speaker_embedding(self):
+        # pylint: disable=attribute-defined-outside-init
+        if self.num_speakers > 0:
+            print(" > initialization of speaker-embedding layers.")
+            self.embedded_speaker_dim = self.args.speaker_embedding_channels
+            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
+
+    def _init_d_vector(self):
+        # pylint: disable=attribute-defined-outside-init
+        if hasattr(self, "emb_g"):
+            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
+        self.embedded_speaker_dim = self.args.d_vector_dim
+
+    @staticmethod
+    def _set_cond_input(aux_input: Dict):
+        """Set the speaker conditioning input based on the multi-speaker mode."""
+        sid, g, lid = None, None, None
+        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
+            sid = aux_input["speaker_ids"]
+            if sid.ndim == 0:
+                sid = sid.unsqueeze_(0)
+        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
+            g = torch.nn.functional.normalize(aux_input["d_vectors"]).unsqueeze(-1)
+            if g.ndim == 2:
+                g = g.unsqueeze_(0)
+
+        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
+            lid = aux_input["language_ids"]
+            if lid.ndim == 0:
+                lid = lid.unsqueeze_(0)
+
+        return sid, g, lid
+
+    def get_aux_input(self, aux_input: Dict):
+        sid, g, lid = self._set_cond_input(aux_input)
+        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}
 
     @staticmethod
     def generate_attn(dr, x_mask, y_mask=None):
@@ -362,10 +394,7 @@ class ForwardTTS(BaseTTS):
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
         # speaker conditioning
-        # TODO: try different ways of conditioning
-        if g is not None:
-            o_en = o_en + g
-        return o_en, x_mask, g, x_emb
+        return x_emb, x_mask, g, o_en
 
     def _forward_decoder(
         self,
@@ -395,7 +424,7 @@ class ForwardTTS(BaseTTS):
         """
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
-        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
+        o_en_ex, attn = self.expand_encoder_outputs(en=o_en, dr=dr, x_mask=x_mask, y_mask=y_mask)
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
@@ -409,6 +438,7 @@ class ForwardTTS(BaseTTS):
         x_mask: torch.IntTensor,
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
+        g: torch.FloatTensor = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         """Pitch predictor forward pass.
 
@@ -421,6 +451,7 @@ class ForwardTTS(BaseTTS):
             x_mask (torch.IntTensor): Input sequence mask.
             pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
             dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+            g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
 
         Returns:
             Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
@@ -431,7 +462,7 @@ class ForwardTTS(BaseTTS):
             - pitch: :math:`(B, 1, T_{de})`
             - dr: :math:`(B, T_{en})`
         """
-        o_pitch = self.pitch_predictor(o_en, x_mask)
+        o_pitch = self.pitch_predictor(o_en, x_mask, g=g)
         if pitch is not None:
             avg_pitch = average_over_durations(pitch, dr)
             o_pitch_emb = self.pitch_emb(avg_pitch)
@@ -466,19 +497,19 @@ class ForwardTTS(BaseTTS):
             - x_mask: :math:`[B, 1, T_en]`
             - y_mask: :math:`[B, 1, T_de]`
 
-            - o_alignment_dur: :math:`[B, T_en]`
-            - alignment_soft: :math:`[B, T_en, T_de]`
-            - alignment_logprob: :math:`[B, 1, T_de, T_en]`
-            - alignment_mas: :math:`[B, T_en, T_de]`
+            - aligner_durations: :math:`[B, T_en]`
+            - aligner_soft: :math:`[B, T_en, T_de]`
+            - aligner_logprob: :math:`[B, 1, T_de, T_en]`
+            - aligner_mas: :math:`[B, T_en, T_de]`
         """
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
-        alignment_mas = maximum_path(
-            alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
+        aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
+        aligner_mas = maximum_path(
+            aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
         )
-        o_alignment_dur = torch.sum(alignment_mas, -1).int()
-        alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
-        return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
+        aligner_durations = torch.sum(aligner_mas, -1).int()
+        aligner_soft = aligner_soft.squeeze(1).transpose(1, 2)
+        return aligner_durations, aligner_soft, aligner_logprob, aligner_mas
 
     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
@@ -525,52 +556,54 @@ class ForwardTTS(BaseTTS):
         """
         g = self._set_speaker_input(aux_input)
         # compute sequence masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()  # [B, 1, T_max2]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()  # [B, 1, T_max]
         # encoder pass
-        o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
+            o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g)  # [B, 1, T_max]
         else:
-            o_dr_log = self.duration_predictor(o_en, x_mask)
+            o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g)  # [B, 1, T_max]
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
-        o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
+        dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask)  # [B, T_max, T_max2']
        # aligner
-        o_alignment_dur = None
-        alignment_soft = None
-        alignment_logprob = None
-        alignment_mas = None
+        aligner_durations = None
+        aligner_soft = None
+        aligner_logprob = None
+        aligner_mas = None
         if self.use_aligner:
-            o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
-                x_emb, y, x_mask, y_mask
+            # TODO: try passing o_en instead of x_emb
+            aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner(
+                x=x_emb, y=y, x_mask=x_mask, y_mask=y_mask
             )
-            alignment_soft = alignment_soft.transpose(1, 2)
-            alignment_mas = alignment_mas.transpose(1, 2)
-            dr = o_alignment_dur
+            aligner_soft = aligner_soft.transpose(1, 2)  # [B, T_max, T_max2] -> [B, T_max2, T_max]
+            aligner_mas = aligner_mas.transpose(1, 2)  # [B, T_max, T_max2] -> [B, T_max2, T_max]
+            dr = aligner_durations  # [B, T_max]
         # pitch predictor pass
         o_pitch = None
         avg_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
+            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(
+                o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
+            )
             o_en = o_en + o_pitch_emb
         # decoder pass
-        o_de, attn = self._forward_decoder(
-            o_en, dr, x_mask, y_lengths, g=None
-        )  # TODO: maybe pass speaker embedding (g) too
+        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)  # [B, T_max2, C_de], [B, T_max2, T_max]
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
+            "g": g,  # [B, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
             "durations": o_dr.squeeze(1),  # [B, T]
-            "attn_durations": o_attn,  # for visualization [B, T_en, T_de']
+            "attn_durations": dur_predictor_attn,  # for visualization [B, T_en, T_de']
             "pitch_avg": o_pitch,
             "pitch_avg_gt": avg_pitch,
             "alignments": attn,  # [B, T_de, T_en]
-            "alignment_soft": alignment_soft,
-            "alignment_mas": alignment_mas,
-            "o_alignment_dur": o_alignment_dur,
-            "alignment_logprob": alignment_logprob,
+            "aligner_soft": aligner_soft,
+            "aligner_mas": aligner_mas,
+            "aligner_durations": aligner_durations,
+            "aligner_logprob": aligner_logprob,
             "x_mask": x_mask,
             "y_mask": y_mask,
         }
@@ -593,7 +626,7 @@ class ForwardTTS(BaseTTS):
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
+        _, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
@@ -610,6 +643,7 @@ class ForwardTTS(BaseTTS):
             "alignments": attn,
             "pitch": o_pitch,
             "durations_log": o_dr_log,
+            "g": g,
         }
         return outputs
 
@@ -630,7 +664,7 @@ class ForwardTTS(BaseTTS):
         )
         # use aligner's output as the duration target
         if self.use_aligner:
-            durations = outputs["o_alignment_dur"]
+            durations = outputs["aligner_durations"]
         # use float32 in AMP
         with autocast(enabled=False):
             # compute loss
@@ -643,9 +677,9 @@ class ForwardTTS(BaseTTS):
                 pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
                 pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
                 input_lens=text_lengths,
-                alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
-                alignment_soft=outputs["alignment_soft"],
-                alignment_hard=outputs["alignment_mas"],
+                aligner_logprob=outputs["aligner_logprob"] if self.use_aligner else None,
+                aligner_soft=outputs["aligner_soft"],
+                aligner_hard=outputs["aligner_mas"],
                 binary_loss_weight=self.binary_loss_weight,
             )
             # compute duration error
@@ -655,7 +689,7 @@ class ForwardTTS(BaseTTS):
 
         return outputs, loss_dict
 
-    def _create_logs(self, batch, outputs, ap):
+    def create_logs(self, batch, outputs, ap):
         """Create common logger outputs."""
         model_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
@@ -694,7 +728,7 @@ class ForwardTTS(BaseTTS):
     def train_log(
         self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ) -> None:  # pylint: disable=no-self-use
-        figures, audios = self._create_logs(batch, outputs, self.ap)
+        figures, audios = self.create_logs(batch, outputs, self.ap)
         logger.train_figures(steps, figures)
         logger.train_audios(steps, audios, self.ap.sample_rate)
 
@@ -702,7 +736,7 @@ class ForwardTTS(BaseTTS):
         return self.train_step(batch, criterion)
 
     def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
-        figures, audios = self._create_logs(batch, outputs, self.ap)
+        figures, audios = self.create_logs(batch, outputs, self.ap)
         logger.eval_figures(steps, figures)
         logger.eval_audios(steps, audios, self.ap.sample_rate)
 
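
The constructor changes in the first hunk assume a companion update to DurationPredictor (not part of this diff) that takes a cond_channels argument at construction time and a conditioning tensor g at call time, replacing the old scheme of widening the input by the speaker-embedding size. A minimal sketch of that conditioning pattern, with hypothetical names and a simplified convolution stack: the conditioning vector is projected with a 1x1 convolution and added to the predictor input, broadcast over time.

    import torch
    from torch import nn

    class CondDurationPredictorSketch(nn.Module):
        """Duration predictor with additive conditioning (illustrative only)."""

        def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2),
                nn.ReLU(),
                nn.Dropout(dropout_p),
            )
            self.proj = nn.Conv1d(hidden_channels, 1, 1)
            # 1x1 projection of the conditioning vector into the input space
            self.cond = nn.Conv1d(cond_channels, in_channels, 1) if cond_channels else None

        def forward(self, x, x_mask, g=None):
            # x: [B, C_in, T_en], x_mask: [B, 1, T_en], g: [B, C_g, 1]
            if g is not None and self.cond is not None:
                x = x + self.cond(g)  # broadcasts over the time axis
            return self.proj(self.conv(x * x_mask)) * x_mask  # log-durations, [B, 1, T_en]

Keeping the conditioning additive means the predictor's input width no longer depends on the speaker-embedding size, which is why the first positional argument in the patched constructor calls drops back to self.args.hidden_channels.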
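
The new _set_cond_input staticmethod can be exercised on its own. Per the hunk above, d-vectors are L2-normalized and given a trailing time axis so they broadcast against [B, C, T] feature maps, and 0-dim id tensors are promoted to batched ones. A small usage sketch (tensor sizes are illustrative):

    import torch
    from TTS.tts.models.forward_tts import ForwardTTS

    aux_input = {"d_vectors": torch.randn(2, 512), "speaker_ids": None, "language_ids": None}
    sid, g, lid = ForwardTTS._set_cond_input(aux_input)
    print(g.shape)    # torch.Size([2, 512, 1]): unit-norm rows with a time axis for broadcasting
    print(sid, lid)   # None, None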