diff --git a/TTS/VERSION b/TTS/VERSION
index 144996ed..6dd46024 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.20.3
+0.20.4
diff --git a/TTS/cs_api.py b/TTS/cs_api.py
index c45f9d08..9dc6c30d 100644
--- a/TTS/cs_api.py
+++ b/TTS/cs_api.py
@@ -82,7 +82,6 @@ class CS_API:
         },
     }
 
-
     SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"]
 
     def __init__(self, api_token=None, model="XTTS"):
@@ -308,7 +307,11 @@ if __name__ == "__main__":
     print(api.list_speakers_as_tts_models())
 
     ts = time.time()
-    wav, sr = api.tts("It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name)
+    wav, sr = api.tts(
+        "It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name
+    )
     print(f" [i] XTTS took {time.time() - ts:.2f}s")
 
-    filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav")
+    filepath = api.tts_to_file(
+        text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav"
+    )
diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py
index 2d3edaf4..e8ab07da 100644
--- a/TTS/tts/configs/xtts_config.py
+++ b/TTS/tts/configs/xtts_config.py
@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
             Defaults to `16`.
 
         gpt_cond_len (int):
-            Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
+            Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
+
+        gpt_cond_chunk_len (int):
+            Audio chunk size in secs. The audio is split into chunks, latents are extracted for each chunk, and the
+            latents are averaged. Chunking improves stability. It must be <= `gpt_cond_len`.
+            If `gpt_cond_len == gpt_cond_chunk_len`, no chunking is applied. Defaults to `4`.
 
         max_ref_len (int):
             Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
     num_gpt_outputs: int = 1
 
     # cloning
-    gpt_cond_len: int = 3
+    gpt_cond_len: int = 12
+    gpt_cond_chunk_len: int = 4
     max_ref_len: int = 10
     sound_norm_refs: bool = False
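
The new `gpt_cond_chunk_len` option works by splitting the reference audio into fixed-size chunks, extracting a conditioning latent per chunk, and averaging the chunk latents; `Xtts.get_gpt_cond_latents` further down implements this. A minimal sketch of the chunk-and-average idea, assuming a hypothetical `extract_latent` callable standing in for the model's style encoder:

    import torch

    def averaged_cond_latent(audio, extract_latent, sr=22050, cond_len=12, chunk_len=4):
        # keep at most `cond_len` seconds of reference audio
        audio = audio[:, : sr * cond_len]
        latents = []
        # one latent per `chunk_len`-second chunk
        for i in range(0, audio.shape[1], sr * chunk_len):
            latents.append(extract_latent(audio[:, i : i + sr * chunk_len]))
        # average the chunk latents into a single conditioning latent
        return torch.stack(latents).mean(dim=0)

With `chunk_len == cond_len` the loop runs once and the mean is a no-op, i.e. the "no chunking" case from the docstring.
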
diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py
index 2166eebb..c70888df 100644
--- a/TTS/tts/layers/tortoise/dpm_solver.py
+++ b/TTS/tts/layers/tortoise/dpm_solver.py
@@ -562,15 +562,21 @@ class DPM_Solver:
         if order == 3:
             K = steps // 3 + 1
             if steps % 3 == 0:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 2
                 ) + [2, 1]
             elif steps % 3 == 1:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 1
                 ) + [1]
             else:
-                orders = [3,] * (
+                orders = [
+                    3,
+                ] * (
                     K - 1
                 ) + [2]
         elif order == 2:
@@ -581,7 +587,9 @@ class DPM_Solver:
                 ] * K
             else:
                 K = steps // 2 + 1
-                orders = [2,] * (
+                orders = [
+                    2,
+                ] * (
                     K - 1
                 ) + [1]
         elif order == 1:
@@ -1440,7 +1448,10 @@ class DPM_Solver:
                         model_prev_list[-1] = self.model_fn(x, t)
             elif method in ["singlestep", "singlestep_fixed"]:
                 if method == "singlestep":
-                    (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
+                    (
+                        timesteps_outer,
+                        orders,
+                    ) = self.get_orders_and_timesteps_for_singlestep_solver(
                         steps=steps,
                         order=order,
                         skip_type=skip_type,
@@ -1548,4 +1559,4 @@ def expand_dims(v, dims):
     Returns:
         a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
     """
-    return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
+    return v[(...,) + (None,) * (dims - 1)]
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 683104d8..612da260 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -128,6 +128,7 @@ class GPT(nn.Module):
         self.heads = heads
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
+        self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2
         self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
         self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
         self.max_prompt_tokens = max_prompt_tokens
@@ -598,7 +599,7 @@ class GPT(nn.Module):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens,
+            max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1],
             **hf_generate_kwargs,
         )
         if "return_dict_in_generate" in hf_generate_kwargs:
@@ -611,7 +612,7 @@ class GPT(nn.Module):
             bos_token_id=self.start_audio_token,
             pad_token_id=self.stop_audio_token,
             eos_token_id=self.stop_audio_token,
-            max_length=self.max_mel_tokens,
+            max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1],
             do_stream=True,
             **hf_generate_kwargs,
         )
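
The generation cap change above matters because Hugging Face `generate()` counts prompt tokens against `max_length`: with the old fixed `max_length=self.max_mel_tokens`, a longer text/conditioning prompt silently shrank the number of audio tokens that could still be produced, whereas `max_gen_mel_tokens + prompt length` keeps the generation budget constant. A back-of-the-envelope illustration with hypothetical sizes:

    old_max_length = 605         # fixed cap, roughly self.max_mel_tokens (hypothetical value)
    max_gen_mel_tokens = 571     # new per-call generation budget (hypothetical value)
    prompt_len = 180             # tokens already present in gpt_inputs / fake_inputs

    old_budget = old_max_length - prompt_len                      # shrinks as the prompt grows
    new_budget = (max_gen_mel_tokens + prompt_len) - prompt_len   # always max_gen_mel_tokens
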
diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py
index edb09042..211d0a93 100644
--- a/TTS/tts/layers/xtts/tokenizer.py
+++ b/TTS/tts/layers/xtts/tokenizer.py
@@ -1,6 +1,7 @@
 import json
 import os
 import re
+from functools import cached_property
 
 import pypinyin
 import torch
@@ -8,7 +9,6 @@ from hangul_romanize import Transliter
 from hangul_romanize.rule import academic
 from num2words import num2words
 from tokenizers import Tokenizer
-from functools import cached_property
 
 from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
 
@@ -560,19 +560,22 @@ class VoiceBpeTokenizer:
     @cached_property
     def katsu(self):
         import cutlet
+
         return cutlet.Cutlet()
-    
+
     def check_input_length(self, txt, lang):
         limit = self.char_limits.get(lang, 250)
         if len(txt) > limit:
-            print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")
+            print(
+                f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
+            )
 
     def preprocess_text(self, txt, lang):
         if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}:
             txt = multilingual_cleaners(txt, lang)
             if lang in {"zh", "zh-cn"}:
                 txt = chinese_transliterate(txt)
-        elif lang == "ja":                
+        elif lang == "ja":
             txt = japanese_cleaners(txt, self.katsu)
         elif lang == "ko":
             txt = korean_cleaners(txt)
diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py
index 8cb90ad0..2f958cb5 100644
--- a/TTS/tts/layers/xtts/trainer/dataset.py
+++ b/TTS/tts/layers/xtts/trainer/dataset.py
@@ -5,6 +5,7 @@ import sys
 import torch
 import torch.nn.functional as F
 import torch.utils.data
+
 from TTS.tts.models.xtts import load_audio
 
 torch.set_num_threads(1)
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index f41bcfb9..b277c3ac 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -255,39 +255,57 @@ class Xtts(BaseTTS):
         return next(self.parameters()).device
 
     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
         """Compute the conditioning latents for the GPT model from the given audio.
 
         Args:
             audio (tensor): audio tensor.
             sr (int): Sample rate of the audio.
-            length (int): Length of the audio in seconds. Defaults to 3.
+            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
+            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
+                is used without chunking. It must be <= `length`. Defaults to 6.
         """
         if sr != 22050:
             audio = torchaudio.functional.resample(audio, sr, 22050)
-        audio = audio[:, : 22050 * length]
+        if length > 0:
+            audio = audio[:, : 22050 * length]
         if self.args.gpt_use_perceiver_resampler:
-            n_fft = 2048
-            hop_length = 256
-            win_length = 1024
+            style_embs = []
+            for i in range(0, audio.shape[1], 22050 * chunk_length):
+                audio_chunk = audio[:, i : i + 22050 * chunk_length]
+                mel_chunk = wav_to_mel_cloning(
+                    audio_chunk,
+                    mel_norms=self.mel_stats.cpu(),
+                    n_fft=2048,
+                    hop_length=256,
+                    win_length=1024,
+                    power=2,
+                    normalized=False,
+                    sample_rate=22050,
+                    f_min=0,
+                    f_max=8000,
+                    n_mels=80,
+                )
+                style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
+                style_embs.append(style_emb)
+
+            # mean style embedding
+            cond_latent = torch.stack(style_embs).mean(dim=0)
         else:
-            n_fft = 4096
-            hop_length = 1024
-            win_length = 4096
-        mel = wav_to_mel_cloning(
-            audio,
-            mel_norms=self.mel_stats.cpu(),
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            power=2,
-            normalized=False,
-            sample_rate=22050,
-            f_min=0,
-            f_max=8000,
-            n_mels=80,
-        )
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
+            mel = wav_to_mel_cloning(
+                audio,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=4096,
+                hop_length=1024,
+                win_length=4096,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
+                f_min=0,
+                f_max=8000,
+                n_mels=80,
+            )
+            cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)
 
     @torch.inference_mode()
@@ -323,12 +341,24 @@ class Xtts(BaseTTS):
     def get_conditioning_latents(
         self,
         audio_path,
+        max_ref_length=30,
         gpt_cond_len=6,
-        max_ref_length=10,
+        gpt_cond_chunk_len=6,
         librosa_trim_db=None,
         sound_norm_refs=False,
-        load_sr=24000,
+        load_sr=22050,
     ):
+        """Get the conditioning latents for the GPT model from the given audio.
+
+        Args:
+            audio_path (str or List[str]): Path to reference audio file(s).
+            max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
+            gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
+            gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= `gpt_cond_len`. Defaults to 6.
+            librosa_trim_db (int, optional): Trim the audio using this value. If None, no trimming is applied. Defaults to None.
+            sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
+            load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
+        """
         # deal with multiples references
         if not isinstance(audio_path, list):
             audio_paths = [audio_path]
@@ -339,24 +369,24 @@ class Xtts(BaseTTS):
         audios = []
         speaker_embedding = None
         for file_path in audio_paths:
-            # load the audio in 24khz to avoid issued with multiple sr references
             audio = load_audio(file_path, load_sr)
             audio = audio[:, : load_sr * max_ref_length].to(self.device)
-            if audio.shape[0] > 1:
-                audio = audio.mean(0, keepdim=True)
             if sound_norm_refs:
                 audio = (audio / torch.abs(audio).max()) * 0.75
             if librosa_trim_db is not None:
                 audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
 
+            # compute latents for the decoder
             speaker_embedding = self.get_speaker_embedding(audio, load_sr)
             speaker_embeddings.append(speaker_embedding)
 
             audios.append(audio)
 
-        # use a merge of all references for gpt cond latents
+        # merge all the audios and compute the latents for the gpt
         full_audio = torch.cat(audios, dim=-1)
-        gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len)  # [1, 1024, T]
+        gpt_cond_latents = self.get_gpt_cond_latents(
+            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
+        )  # [1, 1024, T]
 
         if speaker_embeddings:
             speaker_embedding = torch.stack(speaker_embeddings)
@@ -397,6 +427,7 @@ class Xtts(BaseTTS):
             "top_k": config.top_k,
             "top_p": config.top_p,
             "gpt_cond_len": config.gpt_cond_len,
+            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
             "max_ref_len": config.max_ref_len,
             "sound_norm_refs": config.sound_norm_refs,
         }
@@ -417,7 +448,8 @@ class Xtts(BaseTTS):
         top_p=0.85,
         do_sample=True,
         # Cloning
-        gpt_cond_len=6,
+        gpt_cond_len=30,
+        gpt_cond_chunk_len=6,
         max_ref_len=10,
         sound_norm_refs=False,
         **hf_generate_kwargs,
@@ -448,7 +480,10 @@ class Xtts(BaseTTS):
                 (aka boring) outputs. Defaults to 0.8.
 
             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
+                else the first `gpt_cond_len` secs are used. Defaults to 30 seconds.
+
+            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
+                If `gpt_cond_len == gpt_cond_chunk_len`, no chunking is applied. Defaults to 6 seconds.
 
             hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                 transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
@@ -461,6 +496,7 @@ class Xtts(BaseTTS):
         (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
             audio_path=ref_audio_path,
             gpt_cond_len=gpt_cond_len,
+            gpt_cond_chunk_len=gpt_cond_chunk_len,
             max_ref_length=max_ref_len,
             sound_norm_refs=sound_norm_refs,
         )
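
For context, a hedged usage sketch of the updated cloning entry point with the new defaults from this diff (`xtts` stands for a loaded `Xtts` model and the file path is illustrative):

    gpt_cond_latent, speaker_embedding = xtts.get_conditioning_latents(
        audio_path=["reference.wav"],  # one or more reference clips
        max_ref_length=30,             # cap each reference clip at 30 s
        gpt_cond_len=30,               # seconds of audio used for the GPT latents
        gpt_cond_chunk_len=6,          # chunk size for latent averaging; must be <= gpt_cond_len
    )
    # gpt_cond_latent has shape [1, 1024, T]; speaker_embedding conditions the decoder
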
@@ -566,7 +602,7 @@ class Xtts(BaseTTS):
             if overlap_len > len(wav_chunk):
                 # wav_chunk is smaller than overlap_len, pass on last wav_gen
                 if wav_gen_prev is not None:
-                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
+                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
                 else:
                     # not expecting will hit here as problem happens on last chunk
                     wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +612,7 @@ class Xtts(BaseTTS):
                 crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
                 wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
                 wav_chunk[:overlap_len] += crossfade_wav
-                
+
         wav_overlap = wav_gen[-overlap_len:]
         wav_gen_prev = wav_gen
         return wav_chunk, wav_gen_prev, wav_overlap
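
The streaming path above joins consecutive generated chunks with a linear crossfade over `overlap_len` samples. A standalone sketch of that blend, where `prev_tail` is the tail of the previous chunk and `new_head` the head of the current one, both 1-D tensors of length `overlap_len`:

    import torch

    def linear_crossfade(prev_tail, new_head):
        overlap_len = prev_tail.shape[0]
        fade_out = torch.linspace(1.0, 0.0, overlap_len)  # previous chunk fades out
        fade_in = torch.linspace(0.0, 1.0, overlap_len)   # new chunk fades in
        return prev_tail * fade_out + new_head * fade_in
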
diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py
index 12c547d6..b8b9a4e3 100644
--- a/tests/xtts_tests/test_xtts_gpt_train.py
+++ b/tests/xtts_tests/test_xtts_gpt_train.py
@@ -60,7 +60,9 @@ XTTS_CHECKPOINT = None  # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s
 
 
 # Training sentences generations
-SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"]  # speaker reference to be used in training test sentences
+SPEAKER_REFERENCE = [
+    "tests/data/ljspeech/wavs/LJ001-0002.wav"
+]  # speaker reference to be used in training test sentences
 LANGUAGE = config_dataset.language
 
 
diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py
index b19b7210..6663433c 100644
--- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py
+++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py
@@ -58,7 +58,9 @@ XTTS_CHECKPOINT = None  # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_s
 
 
 # Training sentences generations
-SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"]  # speaker reference to be used in training test sentences
+SPEAKER_REFERENCE = [
+    "tests/data/ljspeech/wavs/LJ001-0002.wav"
+]  # speaker reference to be used in training test sentences
 LANGUAGE = config_dataset.language