From b4c82685a73f136fb8ecc0ca2da33eacae31ac29 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?= <egolge@coqui.ai>
Date: Mon, 24 Jul 2023 12:33:05 +0200
Subject: [PATCH 01/37] Add model entries

---
 TTS/.models.json | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/TTS/.models.json b/TTS/.models.json
index 69ac7514..02873e7b 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -715,6 +715,18 @@
                     "license": "Apache 2.0"
                 }
             }
+        },
+        "be": {
+            "common-voice": {
+                "glow-tts":{
+                    "description": "Belarusian GlowTTS model created by @alex73 (Github).",
+                    "hf_url":"",
+                    "default_vocoder": "vocoder_models/be/common-voice/hifigan",
+                    "commit": "c0aabb85",
+                    "license": "CC-BY-SA 4.0",
+                    "contact": "alex73mail@gmail.com"
+                }
+            }
         }
     },
     "vocoder_models": {
@@ -866,6 +878,17 @@
                     "commit": null
                 }
             }
+        },
+        "be": {
+            "common-voice": {
+                "hifigan": {
+                    "hf_url": "https://huggingface.co/coqui/hifigan-be",
+                    "description": "Belarusian HiFiGAN model created by @alex73 (Github).",
+                    "author": "@alex73",
+                    "license": "CC-BY-SA 4.0",
+                    "commit": "c0aabb85"
+                }
+            }
         }
     },
     "voice_conversion_models": {

From 562a9509f253a7a40d8769940a94096aacbb3fc2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?= <egolge@coqui.ai>
Date: Mon, 4 Sep 2023 13:57:03 +0200
Subject: [PATCH 02/37] Add BE model

---
 TTS/.models.json    |  4 ++--
 TTS/utils/manage.py | 15 +++++++++------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/TTS/.models.json b/TTS/.models.json
index 02873e7b..c39c39fc 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -720,7 +720,7 @@
             "common-voice": {
                 "glow-tts":{
                     "description": "Belarusian GlowTTS model created by @alex73 (Github).",
-                    "hf_url":"",
+                    "github_rls_url":"https://coqui.gateway.scarf.sh/v0.16.6/tts_models--be--common-voice--glow-tts.zip",
                     "default_vocoder": "vocoder_models/be/common-voice/hifigan",
                     "commit": "c0aabb85",
                     "license": "CC-BY-SA 4.0",
@@ -882,7 +882,7 @@
         "be": {
             "common-voice": {
                 "hifigan": {
-                    "hf_url": "https://huggingface.co/coqui/hifigan-be",
+                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.16.6/vocoder_models--be--common-voice--hifigan.zip",
                     "description": "Belarusian HiFiGAN model created by @alex73 (Github).",
                     "author": "@alex73",
                     "license": "CC-BY-SA 4.0",
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index 70d35228..be393adb 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -468,13 +468,16 @@ class ModelManager(object):
             print(f" > Error: Bad zip file - {file_url}")
             raise zipfile.BadZipFile  # pylint: disable=raise-missing-from
         # move the files to the outer path
-        for file_path in z.namelist()[1:]:
+        for file_path in z.namelist():
             src_path = os.path.join(output_folder, file_path)
-            dst_path = os.path.join(output_folder, os.path.basename(file_path))
-            if src_path != dst_path:
-                copyfile(src_path, dst_path)
-        # remove the extracted folder
-        rmtree(os.path.join(output_folder, z.namelist()[0]))
+            if os.path.isfile(src_path):
+                dst_path = os.path.join(output_folder, os.path.basename(file_path))
+                if src_path != dst_path:
+                    copyfile(src_path, dst_path)
+        # remove redundant (hidden or not) folders
+        for file_path in z.namelist():
+            if os.path.isdir(os.path.join(output_folder, file_path)):
+                rmtree(os.path.join(output_folder, file_path))
 
     @staticmethod
     def _download_tar_file(file_url, output_folder, progress_bar):
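
The manager change above stops assuming that release zips contain a single top-level folder: every extracted file is copied to the root of the output folder and any leftover directories (hidden or not) are removed. A standalone sketch of the same flow, using only standard-library calls and hypothetical paths:

```python
# Sketch of the post-change extraction flow in ModelManager._download_zip_file.
import os
import zipfile
from shutil import copyfile, rmtree

def flatten_zip(zip_path: str, output_folder: str) -> None:
    with zipfile.ZipFile(zip_path) as z:
        z.extractall(output_folder)
        names = z.namelist()
    # move every extracted file to the top of output_folder
    for name in names:
        src_path = os.path.join(output_folder, name)
        if os.path.isfile(src_path):
            dst_path = os.path.join(output_folder, os.path.basename(name))
            if src_path != dst_path:
                copyfile(src_path, dst_path)
    # remove redundant (hidden or not) folders left over from extraction
    for name in names:
        path = os.path.join(output_folder, name)
        if os.path.isdir(path):
            rmtree(path)
```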

From 9533f8656cc93ce6fb103d18cb8cf2f8fc0f22bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?= <egolge@coqui.ai>
Date: Mon, 4 Sep 2023 13:58:37 +0200
Subject: [PATCH 03/37] Make style

---
 TTS/bin/synthesize.py                                   | 2 +-
 TTS/tts/utils/text/belarusian/phonemizer.py             | 5 ++++-
 TTS/tts/utils/text/phonemizers/__init__.py              | 2 +-
 TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py | 2 +-
 recipes/bel-alex73/train_glowtts.py                     | 2 +-
 tests/text_tests/test_belarusian_phonemizer.py          | 5 +++--
 6 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 5ded3067..6adb9f03 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -392,7 +392,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
     if args.encoder_path is not None:
         encoder_path = args.encoder_path
         encoder_config_path = args.encoder_config_path
-    
+
     device = args.device
     if args.use_cuda:
         device = "cuda"
diff --git a/TTS/tts/utils/text/belarusian/phonemizer.py b/TTS/tts/utils/text/belarusian/phonemizer.py
index 3c07a209..1922577e 100644
--- a/TTS/tts/utils/text/belarusian/phonemizer.py
+++ b/TTS/tts/utils/text/belarusian/phonemizer.py
@@ -8,7 +8,9 @@ def init():
         import jpype
         import jpype.imports
     except ModuleNotFoundError:
-        raise ModuleNotFoundError("Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`.")
+        raise ModuleNotFoundError(
+            "Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`."
+        )
 
     try:
         jar_path = os.environ["BEL_FANETYKA_JAR"]
@@ -31,4 +33,5 @@ def belarusian_text_to_phonemes(text: str) -> str:
         init()
 
     from org.alex73.fanetyka.impl import FanetykaText
+
     return str(FanetykaText(finder, text).ipa)
diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py
index 638184fd..f9a0340c 100644
--- a/TTS/tts/utils/text/phonemizers/__init__.py
+++ b/TTS/tts/utils/text/phonemizers/__init__.py
@@ -1,6 +1,6 @@
 from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer
-from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
 from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
+from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
 from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
 from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
 from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer
diff --git a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
index fb620766..e5fcab6e 100644
--- a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
+++ b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py
@@ -1,7 +1,7 @@
 from typing import Dict
 
-from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
 from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
+from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
 
 _DEF_BE_PUNCS = ",!."  # TODO
 
diff --git a/recipes/bel-alex73/train_glowtts.py b/recipes/bel-alex73/train_glowtts.py
index 24b62d79..74866be7 100644
--- a/recipes/bel-alex73/train_glowtts.py
+++ b/recipes/bel-alex73/train_glowtts.py
@@ -60,7 +60,7 @@ config = GlowTTSConfig(
     output_path=output_path,
     add_blank=True,
     datasets=[dataset_config],
-#    characters=characters,
+    #    characters=characters,
     enable_eos_bos_chars=True,
     mixed_precision=False,
     save_step=10000,
diff --git a/tests/text_tests/test_belarusian_phonemizer.py b/tests/text_tests/test_belarusian_phonemizer.py
index 278ee8be..76ba4667 100644
--- a/tests/text_tests/test_belarusian_phonemizer.py
+++ b/tests/text_tests/test_belarusian_phonemizer.py
@@ -1,6 +1,6 @@
 import os
-import warnings
 import unittest
+import warnings
 
 from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
 
@@ -17,7 +17,8 @@ class TestText(unittest.TestCase):
         except KeyError:
             warnings.warn(
                 "You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file to test Belarusian phonemizer",
-                Warning)
+                Warning,
+            )
             return
 
         for line in _TEST_CASES.strip().split("\n"):
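
For context, the phonemizer reformatted above is a thin wrapper around the fanetyka JVM library, started through jpype on first use. A minimal call sketch, assuming `pip install jpype1` and a local copy of the jar (the path below is hypothetical):

```python
# Sketch: Belarusian text to IPA via the JVM-backed phonemizer.
import os

os.environ.setdefault("BEL_FANETYKA_JAR", "/path/to/fanetyka.jar")  # hypothetical path

from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes

print(belarusian_text_to_phonemes("тэст"))
```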

From 9d0b76ce2393455dbe9523288d5c341953350fb5 Mon Sep 17 00:00:00 2001
From: Eren Gölge <egolge@coqui.ai>
Date: Thu, 14 Sep 2023 17:51:40 +0200
Subject: [PATCH 04/37] Check env var for COQUI_TOS_AGREED

---
 TTS/utils/manage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index 6e082297..f38d23f8 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -315,7 +315,7 @@ class ModelManager(object):
         """Check if the user has agreed to the terms of service"""
         if "tos_required" in model_item and model_item["tos_required"]:
             tos_path = os.path.join(model_full_path, "tos_agreed.txt")
-            if os.path.exists(tos_path):
+            if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
                 return True
             return False
         return True
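
This lets non-interactive environments (CI jobs, containers) accept the terms of service up front instead of answering the prompt. A hedged sketch of the pattern; the model name is only an illustration of a ToS-gated download:

```python
# Sketch: agree to the Coqui ToS non-interactively before a gated model download.
import os

os.environ["COQUI_TOS_AGREED"] = "1"  # the variable checked in ModelManager above

from TTS.api import TTS

tts = TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v1")  # illustrative model name
```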

From aa8fa4756e9130e536f0dfc610617db103a43efd Mon Sep 17 00:00:00 2001
From: Eren Gölge <egolge@coqui.ai>
Date: Thu, 14 Sep 2023 17:52:44 +0200
Subject: [PATCH 05/37] Bump up to v0.17.4

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index c3d16c16..884e9604 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.17.2
+0.17.3

From f829bf50f8cd261a2ef23356cb83f0ac6c3ff6a8 Mon Sep 17 00:00:00 2001
From: Reuben Morais <reuben.morais@gmail.com>
Date: Fri, 15 Sep 2023 16:40:34 +0200
Subject: [PATCH 06/37] Bump version to v0.17.4 (really)

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index 884e9604..44e33a41 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.17.3
+0.17.4

From 6916aa37ab202d8db4dae4479bcacf38b2980fe5 Mon Sep 17 00:00:00 2001
From: Julian Weber <julian.weber@hotmail.fr>
Date: Tue, 19 Sep 2023 20:54:12 +0200
Subject: [PATCH 07/37] Fix fsspec requirement (#2970)

* Fix requirement for fsspec

* Use the right version this time
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index ae22b333..7de6119b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ inflect==5.6.0
 tqdm
 anyascii
 pyyaml
-fsspec>=2021.04.0
+fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
 aiohttp
 packaging
 # deps for examples

From 335ae63e01c60fa1f0fb11152ad13426fff6f561 Mon Sep 17 00:00:00 2001
From: Omar Sanseviero <osanseviero@gmail.com>
Date: Wed, 20 Sep 2023 00:57:09 +0200
Subject: [PATCH 08/37] Add coqui blog post (#2949)

* Update README.md

* Update README.md

---------

Co-authored-by: Edresson Casanova <edresson1@gmail.com>
---
 README.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 474f5499..934e9443 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
 
 ## 🐸Coqui.ai News
-- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with uncontrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
+- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html)
+- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
 - 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
 - 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://tts.readthedocs.io/en/dev/models/tortoise.html)
 - 📣 **Coqui Studio API** is landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
@@ -111,7 +112,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
 - Delightful TTS: [paper](https://arxiv.org/abs/2110.12612)
 
 ### End-to-End Models
-- ⓍTTS: [blog]()
+- ⓍTTS: [blog](https://coqui.ai/blog/tts/open_xtts)
 - VITS: [paper](https://arxiv.org/pdf/2106.06103)
 - 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418)
 - 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts)

From da8b6bbce1040ce8a4a58b959faec14e969294a0 Mon Sep 17 00:00:00 2001
From: loupzeur <20752997+loupzeur@users.noreply.github.com>
Date: Wed, 20 Sep 2023 09:57:02 +0200
Subject: [PATCH 09/37] fix: xtts not taking into account device flag (#2951)

* fix: xtts not taking into account device flag

* Style changes

---------

Co-authored-by: Julian Weber <julian.weber@hotmail.fr>
---
 TTS/tts/models/xtts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 0836870e..a23a0f5f 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -642,7 +642,7 @@ class Xtts(BaseTTS):
         self.init_models()
         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
-        self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
+        self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
 
         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
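
The fix simply forwards the requested device to checkpoint loading so tensors are not first materialized on the default device. The same pattern with plain torch.load (load_fsspec wraps the equivalent call); the checkpoint path is a placeholder:

```python
# Sketch: load a checkpoint directly onto the caller's device.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load("model.pth", map_location=torch.device(device))  # placeholder path
# the state dict is then applied with model.load_state_dict(checkpoint["model"], strict=True)
```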

From a2a15392e02bf3e430e21c62923c8a5daa7828a5 Mon Sep 17 00:00:00 2001
From: WeberJulian <julian.weber@hotmail.fr>
Date: Mon, 25 Sep 2023 11:01:36 +0200
Subject: [PATCH 10/37] fix package versions

---
 requirements.txt | 35 ++++++++++++++++++-----------------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 7de6119b..843aaa53 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,26 +7,27 @@ torch>=1.7
 torchaudio
 soundfile
 librosa==0.10.0.*
+scikit-learn==1.3.0
 numba==0.55.1;python_version<"3.9"
 numba==0.57.0;python_version>="3.9"
 inflect==5.6.0
-tqdm
-anyascii
-pyyaml
+tqdm==4.64.1
+anyascii==0.3.1
+pyyaml==6.0.1
 fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
-aiohttp
-packaging
+aiohttp==3.8.5
+packaging==23.1
 # deps for examples
-flask
+flask==2.2.2
 # deps for inference
-pysbd
+pysbd==0.3.4
 # deps for notebooks
-umap-learn==0.5.1
-pandas
+#umap-learn==0.5.1
+pandas==1.4.3
 # deps for training
-matplotlib
+matplotlib==3.7.3
 # coqui stack
-trainer
+trainer==0.0.31
 # config management
 coqpit>=0.0.16
 # chinese g2p deps
@@ -35,18 +36,18 @@ pypinyin
 # gruut+supported langs
 gruut[de,es,fr]==2.2.3
 # deps for korean
-jamo
-nltk
+jamo==0.4.1
+nltk==3.8.1
 g2pkk>=0.1.1
 # deps for bangla
 bangla==0.0.2
-bnnumerizer
+bnnumerizer==0.0.2
 bnunicodenormalizer==0.1.1
 #deps for tortoise
 k_diffusion
-einops
-transformers
+einops==0.6.1
+transformers==4.30.2
 #deps for bark
-encodec
+encodec==0.1.1
 # deps for XTTS
 unidecode

From f1c1d14c541c1ff110cbaa461563085bfaaee7e0 Mon Sep 17 00:00:00 2001
From: WeberJulian <julian.weber@hotmail.fr>
Date: Mon, 25 Sep 2023 11:12:01 +0200
Subject: [PATCH 11/37] Add back umap

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 843aaa53..09651c2a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ flask==2.2.2
 # deps for inference
 pysbd==0.3.4
 # deps for notebooks
-#umap-learn==0.5.1
+umap-learn==0.5.4
 pandas==1.4.3
 # deps for training
 matplotlib==3.7.3

From bbfdfbffdf5b4b1c7befb4e6a49d900621f701d5 Mon Sep 17 00:00:00 2001
From: WeberJulian <julian.weber@hotmail.fr>
Date: Mon, 25 Sep 2023 11:46:38 +0200
Subject: [PATCH 12/37] Update transformers to latest

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 09651c2a..c27a3bc3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -46,7 +46,7 @@ bnunicodenormalizer==0.1.1
 #deps for tortoise
 k_diffusion
 einops==0.6.1
-transformers==4.30.2
+transformers==4.33.2
 #deps for bark
 encodec==0.1.1
 # deps for XTTS

From 089ad66df2fec5e7bbc16ace6d8eeb2605c2ef7b Mon Sep 17 00:00:00 2001
From: WeberJulian <julian.weber@hotmail.fr>
Date: Mon, 25 Sep 2023 17:00:41 +0200
Subject: [PATCH 13/37] Lower the versions constraints

---
 requirements.txt | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index c27a3bc3..76071439 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,29 +5,29 @@ cython==0.29.30
 scipy>=1.11.2
 torch>=1.7
 torchaudio
-soundfile
-librosa==0.10.0.*
+soundfile==0.12.*
+librosa==0.10.*
 scikit-learn==1.3.0
 numba==0.55.1;python_version<"3.9"
 numba==0.57.0;python_version>="3.9"
-inflect==5.6.0
-tqdm==4.64.1
-anyascii==0.3.1
-pyyaml==6.0.1
+inflect==5.6.*
+tqdm==4.64.*
+anyascii==0.3.*
+pyyaml==6.*
 fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
-aiohttp==3.8.5
+aiohttp==3.8.*
 packaging==23.1
 # deps for examples
-flask==2.2.2
+flask==2.*
 # deps for inference
 pysbd==0.3.4
 # deps for notebooks
-umap-learn==0.5.4
-pandas==1.4.3
+umap-learn==0.5.*
+pandas==1.4.*
 # deps for training
-matplotlib==3.7.3
+matplotlib==3.7.*
 # coqui stack
-trainer==0.0.31
+trainer
 # config management
 coqpit>=0.0.16
 # chinese g2p deps
@@ -36,18 +36,18 @@ pypinyin
 # gruut+supported langs
 gruut[de,es,fr]==2.2.3
 # deps for korean
-jamo==0.4.1
-nltk==3.8.1
+jamo
+nltk
 g2pkk>=0.1.1
 # deps for bangla
-bangla==0.0.2
-bnnumerizer==0.0.2
-bnunicodenormalizer==0.1.1
+bangla
+bnnumerizer
+bnunicodenormalizer
 #deps for tortoise
 k_diffusion
-einops==0.6.1
-transformers==4.33.2
+einops==0.6.*
+transformers==4.33.*
 #deps for bark
-encodec==0.1.1
+encodec==0.1.*
 # deps for XTTS
-unidecode
+unidecode==1.3.*
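
The `==x.y.*` pins above accept any release inside a minor series while still blocking the next minor. This can be checked with the `packaging` library (already a dependency); a small sketch:

```python
# Sketch: what a wildcard pin such as `transformers==4.33.*` admits.
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("==4.33.*")
print(spec.contains("4.33.2"))  # True
print(spec.contains("4.34.0"))  # False
```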

From 0b95b88f138b58982fe078b13ab56efcb0f751f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?= <erogol@hotmail.com>
Date: Mon, 25 Sep 2023 18:16:45 +0200
Subject: [PATCH 14/37] Bump up to v0.17.5

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index 44e33a41..8b5334dc 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.17.4
+0.17.5

From 5c047cf30466620ba98a57b145276821b08df7ea Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Mon, 25 Sep 2023 15:03:51 +0300
Subject: [PATCH 15/37] Ensure `tts` CLI tool readme and usage help are in sync

---
 README.md                      | 136 +++++++++++++---------
 TTS/bin/synthesize.py          | 201 +++++++++++++++++++--------------
 scripts/sync_readme.py         |  32 ++++++
 tests/aux_tests/test_readme.py |   9 ++
 4 files changed, 236 insertions(+), 142 deletions(-)
 create mode 100644 scripts/sync_readme.py
 create mode 100644 tests/aux_tests/test_readme.py

diff --git a/README.md b/README.md
index 934e9443..720585db 100644
--- a/README.md
+++ b/README.md
@@ -294,99 +294,123 @@ api.tts_with_vc_to_file(
 ```
 
 ### Command-line `tts`
+
+<!-- begin-tts-readme -->
+
+Synthesize speech on command line.
+
+You can either use your trained model or choose a model from the provided list.
+
+If you don't specify any models, then it uses LJSpeech based English model.
+
 #### Single Speaker Models
 
 - List provided models:
 
-    ```
-    $ tts --list_models
-    ```
+  ```
+  $ tts --list_models
+  ```
+
 - Get model info (for both tts_models and vocoder_models):
-    - Query by type/name:
-        The model_info_by_name uses the name as it from the --list_models.
-        ```
-        $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
-        ```
-        For example:
 
-        ```
-        $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
-        ```
-        ```
-        $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
-        ```
-    - Query by type/idx:
-        The model_query_idx uses the corresponding idx from --list_models.
-        ```
-        $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
-        ```
-        For example:
+  - Query by type/name:
+    The model_info_by_name uses the name as it from the --list_models.
+    ```
+    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
+    ```
+    For example:
+    ```
+    $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
+    $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
+    ```
+  - Query by type/idx:
+    The model_query_idx uses the corresponding idx from --list_models.
 
-        ```
-        $ tts --model_info_by_idx tts_models/3
-        ```
+    ```
+    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
+    ```
+
+    For example:
+
+    ```
+    $ tts --model_info_by_idx tts_models/3
+    ```
+
+  - Query info for model info by full name:
+    ```
+    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
+    ```
 
 - Run TTS with default models:
 
-    ```
-    $ tts --text "Text for TTS" --out_path output/path/speech.wav
-    ```
+  ```
+  $ tts --text "Text for TTS" --out_path output/path/speech.wav
+  ```
 
 - Run a TTS model with its default vocoder model:
 
-    ```
-    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
-    ```
+  ```
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
+  ```
+
   For example:
 
-    ```
-    $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
-    ```
+  ```
+  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
+  ```
 
 - Run with specific TTS and vocoder models from the list:
 
-    ```
-    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
-    ```
+  ```
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
+  ```
 
   For example:
 
-    ```
-    $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
-    ```
-
+  ```
+  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
+  ```
 
 - Run your own TTS model (Using Griffin-Lim Vocoder):
 
-    ```
-    $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
-    ```
+  ```
+  $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
+  ```
 
 - Run your own TTS and Vocoder models:
-    ```
-    $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
-        --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
-    ```
+
+  ```
+  $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
+      --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
+  ```
 
 #### Multi-speaker Models
 
 - List the available speakers and choose a <speaker_id> among them:
 
-    ```
-    $ tts --model_name "<language>/<dataset>/<model_name>"  --list_speaker_idxs
-    ```
+  ```
+  $ tts --model_name "<language>/<dataset>/<model_name>"  --list_speaker_idxs
+  ```
 
 - Run the multi-speaker TTS model with the target speaker ID:
 
-    ```
-    $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>"  --speaker_idx <speaker_id>
-    ```
+  ```
+  $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>"  --speaker_idx <speaker_id>
+  ```
 
 - Run your own multi-speaker TTS model:
 
-    ```
-    $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
-    ```
+  ```
+  $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
+  ```
+
+### Voice Conversion Models
+
+```
+$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
+```
+
+<!-- end-tts-readme -->
 
 ## Directory Structure
 ```
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index e8de18b0..99fc2a58 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -12,6 +12,121 @@ from TTS.api import TTS
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
 
+description = """
+Synthesize speech on command line.
+
+You can either use your trained model or choose a model from the provided list.
+
+If you don't specify any models, then it uses LJSpeech based English model.
+
+#### Single Speaker Models
+
+- List provided models:
+
+  ```
+  $ tts --list_models
+  ```
+
+- Get model info (for both tts_models and vocoder_models):
+
+  - Query by type/name:
+    The model_info_by_name uses the name as it from the --list_models.
+    ```
+    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
+    ```
+    For example:
+    ```
+    $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
+    $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
+    ```
+  - Query by type/idx:
+    The model_query_idx uses the corresponding idx from --list_models.
+
+    ```
+    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
+    ```
+
+    For example:
+
+    ```
+    $ tts --model_info_by_idx tts_models/3
+    ```
+
+  - Query info for model info by full name:
+    ```
+    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
+    ```
+
+- Run TTS with default models:
+
+  ```
+  $ tts --text "Text for TTS" --out_path output/path/speech.wav
+  ```
+
+- Run a TTS model with its default vocoder model:
+
+  ```
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
+  ```
+
+  For example:
+
+  ```
+  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
+  ```
+
+- Run with specific TTS and vocoder models from the list:
+
+  ```
+  $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
+  ```
+
+  For example:
+
+  ```
+  $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
+  ```
+
+- Run your own TTS model (Using Griffin-Lim Vocoder):
+
+  ```
+  $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
+  ```
+
+- Run your own TTS and Vocoder models:
+
+  ```
+  $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
+      --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
+  ```
+
+#### Multi-speaker Models
+
+- List the available speakers and choose a <speaker_id> among them:
+
+  ```
+  $ tts --model_name "<language>/<dataset>/<model_name>"  --list_speaker_idxs
+  ```
+
+- Run the multi-speaker TTS model with the target speaker ID:
+
+  ```
+  $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>"  --speaker_idx <speaker_id>
+  ```
+
+- Run your own multi-speaker TTS model:
+
+  ```
+  $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
+  ```
+
+### Voice Conversion Models
+
+```
+$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
+```
+"""
+
 
 def str2bool(v):
     if isinstance(v, bool):
@@ -24,92 +139,6 @@ def str2bool(v):
 
 
 def main():
-    description = """Synthesize speech on command line.
-
-You can either use your trained model or choose a model from the provided list.
-
-If you don't specify any models, then it uses LJSpeech based English model.
-
-## Example Runs
-
-### Single Speaker Models
-
-- List provided models:
-
-    ```
-    $ tts --list_models
-    ```
-
-- Query info for model info by idx:
-
-    ```
-    $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
-    ```
-
-- Query info for model info by full name:
-
-    ```
-    $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
-    ```
-
-- Run TTS with default models:
-
-    ```
-    $ tts --text "Text for TTS"
-    ```
-
-- Run a TTS model with its default vocoder model:
-
-    ```
-    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>
-    ```
-
-- Run with specific TTS and vocoder models from the list:
-
-    ```
-    $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
-    ```
-
-- Run your own TTS model (Using Griffin-Lim Vocoder):
-
-    ```
-    $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
-    ```
-
-- Run your own TTS and Vocoder models:
-    ```
-    $ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth --out_path output/path/speech.wav
-        --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
-    ```
-
-### Multi-speaker Models
-
-- List the available speakers and choose as <speaker_id> among them:
-
-    ```
-    $ tts --model_name "<language>/<dataset>/<model_name>"  --list_speaker_idxs
-    ```
-
-- Run the multi-speaker TTS model with the target speaker ID:
-
-    ```
-    $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>"  --speaker_idx <speaker_id>
-    ```
-
-- Run your own multi-speaker TTS model:
-
-    ```
-    $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
-    ```
-
-### Voice Conversion Models
-
-    ```
-    $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
-    ```
-    """
-    # We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
-    # documentation in sync more easily.
     parser = argparse.ArgumentParser(
         description=description.replace("    ```\n", ""),
         formatter_class=RawTextHelpFormatter,
diff --git a/scripts/sync_readme.py b/scripts/sync_readme.py
new file mode 100644
index 00000000..58428681
--- /dev/null
+++ b/scripts/sync_readme.py
@@ -0,0 +1,32 @@
+import argparse
+from pathlib import Path
+
+
+def replace_between_markers(content, marker: str, replacement: str) -> str:
+    start_marker = f"<!-- begin-{marker} -->\n\n"
+    end_marker = f"\n\n<!-- end-{marker} -->\n"
+    start_index = content.index(start_marker) + len(start_marker)
+    end_index = content.index(end_marker)
+    content = content[:start_index] + replacement + content[end_index:]
+    return content
+
+
+def sync_readme():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--check", action="store_true", default=False)
+    args = ap.parse_args()
+    readme_path = Path(__file__).parent.parent / "README.md"
+    orig_content = readme_path.read_text()
+    from TTS.bin.synthesize import description
+
+    new_content = replace_between_markers(orig_content, "tts-readme", description.strip())
+    if args.check:
+        if orig_content != new_content:
+            print("README.md is out of sync; please edit TTS/bin/TTS_README.md and run scripts/sync_readme.py")
+            exit(42)
+    readme_path.write_text(new_content)
+    print("Updated README.md")
+
+
+if __name__ == "__main__":
+    sync_readme()
diff --git a/tests/aux_tests/test_readme.py b/tests/aux_tests/test_readme.py
new file mode 100644
index 00000000..32b26fc6
--- /dev/null
+++ b/tests/aux_tests/test_readme.py
@@ -0,0 +1,9 @@
+import subprocess
+import sys
+from pathlib import Path
+
+
+def test_readme_up_to_date():
+    root = Path(__file__).parent.parent.parent
+    sync_readme = root / "scripts" / "sync_readme.py"
+    subprocess.check_call([sys.executable, str(sync_readme), "--check"], cwd=root)
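
The check can also be run by hand with `python scripts/sync_readme.py --check`, which exits with code 42 when the README has drifted. The helper it is built on swaps only the text between the begin/end markers; a toy example, assuming the function is importable (or copied into a shell):

```python
# Sketch: replace_between_markers keeps everything outside the markers untouched.
from scripts.sync_readme import replace_between_markers  # assumes the repo root is on sys.path

doc = "intro\n<!-- begin-tts-readme -->\n\nold\n\n<!-- end-tts-readme -->\ntail\n"
print(replace_between_markers(doc, "tts-readme", "new"))
# -> same string with "old" replaced by "new" between the markers
```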

From 0a82f063cc19a6dbc79b41c81f643718ad875c03 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Mon, 25 Sep 2023 15:04:08 +0300
Subject: [PATCH 16/37] Late-import main TTS libraries in `tts` CLI

---
 TTS/bin/synthesize.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 99fc2a58..5ff1181f 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -8,10 +8,6 @@ from argparse import RawTextHelpFormatter
 # pylint: disable=redefined-outer-name, unused-argument
 from pathlib import Path
 
-from TTS.api import TTS
-from TTS.utils.manage import ModelManager
-from TTS.utils.synthesizer import Synthesizer
-
 description = """
 Synthesize speech on command line.
 
@@ -339,6 +335,11 @@ def main():
     if not any(check_args):
         parser.parse_args(["-h"])
 
+    # Late-import to make things load faster
+    from TTS.api import TTS
+    from TTS.utils.manage import ModelManager
+    from TTS.utils.synthesizer import Synthesizer
+
     # load model manager
     path = Path(__file__).parent / "../.models.json"
     manager = ModelManager(path, progress_bar=args.progress_bar)

From 94c5fd07651af6164f16ae28a79cb7862fa2f659 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Tue, 26 Sep 2023 16:02:55 +0300
Subject: [PATCH 17/37] Remove unnecessary black exclude config

It seems to have been copy-pasted from the Black docs.
---
 pyproject.toml | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8544bb20..4f47dc10 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,25 +7,6 @@ max-line-length=120
 [tool.black]
 line-length = 120
 target-version = ['py39']
-exclude = '''
-
-(
-  /(
-      \.eggs         # exclude a few common directories in the
-    | \.git          # root of the project
-    | \.hg
-    | \.mypy_cache
-    | \.tox
-    | \.venv
-    | _build
-    | buck-out
-    | build
-    | dist
-  )/
-  | foo.py           # also separately exclude a file named foo.py in
-                     # the root of the project
-)
-'''
 
 [tool.isort]
 line_length = 120

From 8bb2d652cadea9a2ea08d6f8a4b785e4ca378a25 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Tue, 26 Sep 2023 20:41:26 +0300
Subject: [PATCH 18/37] pyproject.toml: loosen dependencies to avoid building
 from source

---
 pyproject.toml | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8544bb20..6e0afacc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,11 @@
 [build-system]
-requires = ["setuptools", "wheel", "cython==0.29.30", "numpy==1.22.0", "packaging"]
+requires = [
+    "setuptools",
+    "wheel",
+    "cython~=0.29.30",
+    "numpy>=1.22.0",
+    "packaging",
+]
 
 [flake8]
 max-line-length=120

From 6277f09c5f8f695b21877aa795cd458615fed479 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Tue, 26 Sep 2023 20:43:59 +0300
Subject: [PATCH 19/37] requirements.txt: loosen pandas pin (1.4 would need to
 be compiled from source on macs)

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 76071439..2837c36e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ flask==2.*
 pysbd==0.3.4
 # deps for notebooks
 umap-learn==0.5.*
-pandas==1.4.*
+pandas>=1.4,<2.0
 # deps for training
 matplotlib==3.7.*
 # coqui stack

From 59f85a7122d85cb4a9d5afc0318774bf48d691a4 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 27 Sep 2023 00:54:19 +0300
Subject: [PATCH 20/37] Remove duplicate code from xtts.tokenizer

---
 TTS/tts/layers/xtts/tokenizer.py | 37 --------------------------------
 1 file changed, 37 deletions(-)

diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py
index 0fad8133..8dd81fac 100644
--- a/TTS/tts/layers/xtts/tokenizer.py
+++ b/TTS/tts/layers/xtts/tokenizer.py
@@ -171,17 +171,6 @@ def multilingual_cleaners(text, lang):
     return text
 
 
-def english_cleaners(text):
-    """Pipeline for English text, including number and abbreviation expansion."""
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_numbers(text)
-    text = expand_abbreviations(text)
-    text = collapse_whitespace(text)
-    text = text.replace('"', "")
-    return text
-
-
 def remove_extraneous_punctuation(word):
     replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "—": "-", "`": "'", "ʼ": "'"}
     replace = re.compile(
@@ -195,32 +184,6 @@ def remove_extraneous_punctuation(word):
     return word
 
 
-def expand_numbers(text):
-    return normalize_numbers(text)
-
-
-def lowercase(text):
-    return text.lower()
-
-
-_whitespace_re = re.compile(r"\s+")
-
-
-def collapse_whitespace(text):
-    return re.sub(_whitespace_re, " ", text)
-
-
-def convert_to_ascii(text):
-    return unidecode(text)
-
-
-def basic_cleaners(text):
-    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
-    text = lowercase(text)
-    text = collapse_whitespace(text)
-    return text
-
-
 def arabic_cleaners(text):
     text = lowercase(text)
     text = collapse_whitespace(text)

From 09e14e68db08bba569c3c45e1a57de965e2c2d47 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 27 Sep 2023 01:03:03 +0300
Subject: [PATCH 21/37] Remove duplicate get_named_beta_schedules

---
 TTS/tts/layers/tortoise/diffusion.py | 25 -------------------------
 TTS/tts/layers/xtts/diffusion.py     | 25 -------------------------
 2 files changed, 50 deletions(-)

diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py
index eb9e90df..cb350af7 100644
--- a/TTS/tts/layers/tortoise/diffusion.py
+++ b/TTS/tts/layers/tortoise/diffusion.py
@@ -1085,31 +1085,6 @@ class GaussianDiffusion:
         }
 
 
-def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
-    """
-    Get a pre-defined beta schedule for the given name.
-
-    The beta schedule library consists of beta schedules which remain similar
-    in the limit of num_diffusion_timesteps.
-    Beta schedules may be added, but should not be removed or changed once
-    they are committed to maintain backwards compatibility.
-    """
-    if schedule_name == "linear":
-        # Linear schedule from Ho et al, extended to work for any number of
-        # diffusion steps.
-        scale = 1000 / num_diffusion_timesteps
-        beta_start = scale * 0.0001
-        beta_end = scale * 0.02
-        return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
-    elif schedule_name == "cosine":
-        return betas_for_alpha_bar(
-            num_diffusion_timesteps,
-            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-        )
-    else:
-        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
-
-
 class SpacedDiffusion(GaussianDiffusion):
     """
     A diffusion process which can skip steps in a base diffusion process.
diff --git a/TTS/tts/layers/xtts/diffusion.py b/TTS/tts/layers/xtts/diffusion.py
index a0b93add..37665bc6 100644
--- a/TTS/tts/layers/xtts/diffusion.py
+++ b/TTS/tts/layers/xtts/diffusion.py
@@ -1170,31 +1170,6 @@ class GaussianDiffusion:
         }
 
 
-def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
-    """
-    Get a pre-defined beta schedule for the given name.
-
-    The beta schedule library consists of beta schedules which remain similar
-    in the limit of num_diffusion_timesteps.
-    Beta schedules may be added, but should not be removed or changed once
-    they are committed to maintain backwards compatibility.
-    """
-    if schedule_name == "linear":
-        # Linear schedule from Ho et al, extended to work for any number of
-        # diffusion steps.
-        scale = 1000 / num_diffusion_timesteps
-        beta_start = scale * 0.0001
-        beta_end = scale * 0.02
-        return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
-    elif schedule_name == "cosine":
-        return betas_for_alpha_bar(
-            num_diffusion_timesteps,
-            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-        )
-    else:
-        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
-
-
 class SpacedDiffusion(GaussianDiffusion):
     """
     A diffusion process which can skip steps in a base diffusion process.

From 861c68b0b887e26a3c7a74a83e789310343a54d0 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 27 Sep 2023 01:07:02 +0300
Subject: [PATCH 22/37] Rename misnamed setter

---
 TTS/tts/models/delightful_tts.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py
index c0a00c66..0ee99930 100644
--- a/TTS/tts/models/delightful_tts.py
+++ b/TTS/tts/models/delightful_tts.py
@@ -726,8 +726,8 @@ class DelightfulTTS(BaseTTSE2E):
     def pitch_std(self):
         return self.acoustic_model.pitch_std
 
-    @pitch_mean.setter
-    def pitch_std(self, value):  # pylint: disable=function-redefined
+    @pitch_std.setter
+    def pitch_std(self, value):
         self.acoustic_model.pitch_std = value
 
     @property
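
The bug was a setter registered against the wrong property: `@pitch_mean.setter` decorated a function named `pitch_std`, so the rebound `pitch_std` attribute ended up carrying `pitch_mean`'s getter. The corrected pairing, as a minimal sketch:

```python
# Sketch of the corrected property/setter pairing.
class AcousticStats:
    def __init__(self):
        self._pitch_std = 1.0

    @property
    def pitch_std(self):
        return self._pitch_std

    @pitch_std.setter  # previously `@pitch_mean.setter`, which wires up the wrong property
    def pitch_std(self, value):
        self._pitch_std = value
```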

From 33a7c722f6fa7b9aaa8cb605c5c0480007164cf9 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 27 Sep 2023 01:10:44 +0300
Subject: [PATCH 23/37] Merge duplicate on_train_step_start functions in
 delightful_tts

---
 TTS/tts/models/delightful_tts.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py
index 0ee99930..b1cf886b 100644
--- a/TTS/tts/models/delightful_tts.py
+++ b/TTS/tts/models/delightful_tts.py
@@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E):
         scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
         return [scheduler_D, scheduler_G]
 
-    def on_train_step_start(self, trainer):
-        """Schedule binary loss weight."""
-        self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
-
     def on_epoch_end(self, trainer):  # pylint: disable=unused-argument
         # stop updating mean and var
         # TODO: do the same for F0
@@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E):
         Args:
             trainer (Trainer): Trainer object.
         """
+        self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
         self.train_disc = (  # pylint: disable=attribute-defined-outside-init
             trainer.total_steps_done >= self.config.steps_to_start_discriminator
         )
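
The surviving hook ramps the binary alignment loss weight linearly from 0 to 1 over `binary_loss_warmup_epochs` and then holds it at 1. A quick check of the formula with an illustrative warmup value:

```python
# Sketch: binary loss weight schedule as computed in on_train_step_start above.
binary_loss_warmup_epochs = 10  # illustrative; read from the model config in practice
for epochs_done in (0, 5, 10, 20):
    weight = min(epochs_done / binary_loss_warmup_epochs, 1.0) * 1.0
    print(epochs_done, weight)  # 0 -> 0.0, 5 -> 0.5, 10 -> 1.0, 20 -> 1.0
```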

From 0dbe7cbcc4df767d63584684b428172b7bf846b2 Mon Sep 17 00:00:00 2001
From: Aarni Koskela <akx@iki.fi>
Date: Wed, 27 Sep 2023 01:08:10 +0300
Subject: [PATCH 24/37] Remove duplicate convert_pad_shape

---
 TTS/vc/modules/freevc/commons.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py
index 5684a88e..e799cc2a 100644
--- a/TTS/vc/modules/freevc/commons.py
+++ b/TTS/vc/modules/freevc/commons.py
@@ -116,12 +116,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     return acts
 
 
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
 def shift_1d(x):
     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
     return x
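
The copy kept in the shared commons reverses the per-dimension pad list and flattens it into the order `F.pad` expects (last dimension first). For the call in `shift_1d` above:

```python
# Sketch: output of the retained convert_pad_shape for the shift_1d call.
def convert_pad_shape(pad_shape):
    layers = pad_shape[::-1]
    return [item for sublist in layers for item in sublist]

print(convert_pad_shape([[0, 0], [0, 0], [1, 0]]))  # [1, 0, 0, 0, 0, 0]: pad the last dim by 1 on the left
```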

From 4c3c11c958344dd4b91d0d238a4d7d7f35803f86 Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Fri, 29 Sep 2023 08:40:57 -0300
Subject: [PATCH 25/37] Tortoise inference fix and fix zoo unit tests (#3010)

---
 .github/workflows/zoo_tests_tortoise.yml |   52 ++
 TTS/tts/layers/tortoise/tokenizer.py     |    6 +-
 TTS/tts/layers/xtts/gpt_encoder_eren.py  |  658 --------------
 TTS/tts/layers/xtts/gpt_encoder_old.py   | 1057 ----------------------
 tests/zoo_tests/test_models.py           |   63 +-
 5 files changed, 108 insertions(+), 1728 deletions(-)
 create mode 100644 .github/workflows/zoo_tests_tortoise.yml
 delete mode 100644 TTS/tts/layers/xtts/gpt_encoder_eren.py
 delete mode 100644 TTS/tts/layers/xtts/gpt_encoder_old.py

diff --git a/.github/workflows/zoo_tests_tortoise.yml b/.github/workflows/zoo_tests_tortoise.yml
new file mode 100644
index 00000000..31442877
--- /dev/null
+++ b/.github/workflows/zoo_tests_tortoise.yml
@@ -0,0 +1,52 @@
+name: zoo-tests-tortoise
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    types: [opened, synchronize, reopened]
+jobs:
+  check_skip:
+    runs-on: ubuntu-latest
+    if: "! contains(github.event.head_commit.message, '[ci skip]')"
+    steps:
+      - run: echo "${{ github.event.head_commit.message }}"
+
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.9, "3.10", "3.11"]
+        experimental: [false]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+          architecture: x64
+          cache: 'pip'
+          cache-dependency-path: 'requirements*'
+      - name: check OS
+        run: cat /etc/os-release
+      - name: set ENV
+        run: export TRAINER_TELEMETRY=0
+      - name: Install dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y git make gcc
+          sudo apt-get install espeak espeak-ng
+          make system-deps
+      - name: Install/upgrade Python setup deps
+        run: python3 -m pip install --upgrade pip setuptools wheel
+      - name: Replace scarf urls
+        run: |
+          sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
+      - name: Install TTS
+        run: |
+          python3 -m pip install .[all]
+          python3 setup.py egg_info
+      - name: Unit tests
+        run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_tortoise
diff --git a/TTS/tts/layers/tortoise/tokenizer.py b/TTS/tts/layers/tortoise/tokenizer.py
index 3969b2cc..d243d655 100644
--- a/TTS/tts/layers/tortoise/tokenizer.py
+++ b/TTS/tts/layers/tortoise/tokenizer.py
@@ -5,9 +5,13 @@ from tokenizers import Tokenizer
 
 from TTS.tts.utils.text.cleaners import english_cleaners
 
+DEFAULT_VOCAB_FILE = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
+)
+
 
 class VoiceBpeTokenizer:
-    def __init__(self, vocab_file=None, vocab_str=None):
+    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None):
         self.tokenizer = None
         if vocab_file is not None:
             self.tokenizer = Tokenizer.from_file(vocab_file)
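
With the bundled vocabulary wired in as the default, the Tortoise tokenizer can now be constructed without arguments. A minimal sketch; the sample sentence is arbitrary:

```python
# Sketch: VoiceBpeTokenizer() now falls back to the bundled tortoise tokenizer.json.
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer

tok = VoiceBpeTokenizer()           # no explicit vocab_file needed any more
ids = tok.encode("Hello world.")    # assumed entry point for text -> token ids
```
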
diff --git a/TTS/tts/layers/xtts/gpt_encoder_eren.py b/TTS/tts/layers/xtts/gpt_encoder_eren.py
deleted file mode 100644
index b5e7158d..00000000
--- a/TTS/tts/layers/xtts/gpt_encoder_eren.py
+++ /dev/null
@@ -1,658 +0,0 @@
-import functools
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-
-
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-
-
-class GPT2InferenceModel(GPT2PreTrainedModel):
-    """Override GPT2LMHeadModel to allow for prefix conditioning."""
-
-    def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
-        super().__init__(config)
-        self.transformer = gpt
-        self.pos_embedding = pos_emb
-        self.embeddings = embeddings
-        self.final_norm = norm
-        self.lm_head = nn.Sequential(norm, linear)
-        self.kv_cache = kv_cache
-
-    def store_prefix_emb(self, prefix_emb):
-        self.cached_prefix_emb = prefix_emb
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
-        if not self.kv_cache:
-            past_key_values = None
-
-        # only last token for inputs_ids if past is defined in kwargs
-        if past_key_values is not None:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-            if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values is not None:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
-
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        assert self.cached_prefix_emb is not None
-        assert inputs_embeds is None  # Not supported by this inference model.
-        assert labels is None  # Training not supported by this inference model.
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
-
-        # Create embedding
-        prefix_len = self.cached_prefix_emb.shape[1]
-        if input_ids.shape[1] != 1:
-            gen_inputs = input_ids[:, prefix_len:]
-            gen_emb = self.embeddings(gen_inputs)
-            gen_emb = gen_emb + self.pos_embedding(gen_emb)
-            if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
-                prefix_emb = self.cached_prefix_emb.repeat_interleave(
-                    gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
-                )
-            else:
-                prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
-            emb = torch.cat([prefix_emb, gen_emb], dim=1)
-        else:
-            emb = self.embeddings(input_ids)
-            emb = emb + self.pos_embedding.get_fixed_embedding(
-                attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
-            )
-        transformer_outputs = self.transformer(
-            inputs_embeds=emb,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-        lm_logits = self.lm_head(hidden_states)
-
-        if not return_dict:
-            return (lm_logits,) + transformer_outputs[1:]
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=None,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-            cross_attentions=transformer_outputs.cross_attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """
-        This function is used to re-order the :obj:`past_key_values` cache if
-        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
-        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
-
-
-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
-        super().__init__()
-        self.emb = nn.Embedding(seq_len, model_channels)
-        nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
-        self.relative = relative
-
-    def forward(self, x):
-        seq_len = x.shape[1]
-        if self.relative:
-            start = torch.randint(seq_len, (1,), device=x.device).item()
-            positions = torch.arange(start, start + seq_len, device=x.device)
-        else:
-            positions = torch.arange(seq_len, device=x.device)
-        return self.emb(positions)
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-
-def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
-    """
-    Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
-
-    Args:
-        layers (int): Number of layers in the GPT-2 model.
-        model_channels (int): Dimension of the GPT-2 model.
-        heads (int): Number of heads in the GPT-2 model.
-        max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
-        max_text_seq_len (int): Maximum sequence length for the text.
-        max_prompt_len (int): Maximum length of the prompt.
-        checkpointing (bool): Whether to use gradient checkpointing.
-
-    Returns:
-        gpt (GPT2Model): GPT-2 model.
-        audio_pos_emb (LearnedPositionEmbeddings): Position embeddings for the audio (mel) codes.
-        text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
-    """
-    gpt_config = GPT2Config(
-        vocab_size=123,  # unused; the built-in token embeddings are removed below
-        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_embd=model_channels,
-        n_layer=layers,
-        n_head=heads,
-        gradient_checkpointing=checkpointing,
-        use_cache=not checkpointing,
-    )
-    gpt = GPT2Model(gpt_config)
-
-    # Override the built-in positional embeddings and drop the unused token embeddings.
-    del gpt.wpe
-    del gpt.wte
-
-    gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
-
-    audio_pos_emb = (
-        LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_channels)
-    )
-    text_pos_emb = (
-        LearnedPositionEmbeddings(max_text_seq_len, model_channels)
-        if max_text_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_channels)
-    )
-
-    return gpt, audio_pos_emb, text_pos_emb
-
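For reference, a minimal usage sketch of init_gpt (hypothetical sizes and token ids, not part of the original file; it assumes the torch/transformers imports at the top of this module). The returned GPT-2 is driven purely through inputs_embeds, with token and positional embeddings supplied externally:

    gpt, audio_pos_emb, text_pos_emb = init_gpt(
        layers=2, model_channels=64, heads=2,
        max_mel_seq_len=32, max_text_seq_len=16, max_prompt_len=8, checkpointing=False,
    )
    tokens = torch.randint(0, 100, (1, 10))                      # hypothetical token ids
    emb = nn.Embedding(100, 64)(tokens) + text_pos_emb(tokens)   # external token + position embeddings
    out = gpt(inputs_embeds=emb, return_dict=True)
    assert out.last_hidden_state.shape == (1, 10, 64)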
-
-class XTTSGPTEncoder(nn.Module):
-    """XTTS GPT Encoder model implementation.
-    Args:
-        start_text_token (int): Index of the start token in the text vocabulary.
-        stop_text_token (int): Index of the stop token in the text vocabulary.
-        n_layers (int): Number of layers in the GPT-2 model.
-        n_model_channels (int): Dimension of the GPT-2 model.
-        n_heads (int): Number of heads in the GPT-2 model.
-        max_text_tokens (int): Maximum number of text tokens.
-        max_audio_tokens (int): Maximum number of audio tokens.
-        max_prompt_tokens (int): Maximum number of prompt tokens.
-        audio_len_compression (int): Compression factor for the audio length.
-        number_text_tokens (int): Number of text tokens.
-        number_audio_codes (int): Number of audio codes.
-        start_mel_token (int): Index of the start token in the mel code vocabulary.
-        stop_mel_token (int): Index of the stop token in the mel code vocabulary.
-        checkpointing (bool): Whether or not to use gradient checkpointing during training.
-    """
-
-    _inference_flag = False
-
-    def __init__(
-        self,
-        start_text_token=261,
-        stop_text_token=0,
-        n_layers=8,
-        n_model_channels=512,
-        n_heads=8,
-        max_text_tokens=120,
-        max_audio_tokens=250,
-        max_prompt_tokens=70,
-        audio_len_compression=1024,
-        number_text_tokens=256,
-        number_audio_codes=8194,
-        start_mel_token=8192,
-        stop_mel_token=8193,
-        checkpointing=True,
-        label_smoothing=0.0,
-    ):
-        super().__init__()
-
-        self.label_smoothing = label_smoothing
-        self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
-        self.number_audio_codes = number_audio_codes
-        self.start_mel_token = start_mel_token
-        self.stop_mel_token = stop_mel_token
-        self.start_prompt_token = start_mel_token
-        self.stop_prompt_token = stop_mel_token
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.n_model_channels = n_model_channels
-        self.max_conditioning_inputs = 1  # this encoder supports a single conditioning input
-        self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
-        self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
-        self.max_prompt_tokens = max_prompt_tokens
-        self.audio_len_compression = audio_len_compression
-
-        # embedding layers
-        self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
-        self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
-        self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
-        self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
-
-        # initialize the GPT-2 model
-        (
-            self.gpt,
-            self.audio_pos_embedding,
-            self.text_pos_embedding,
-        ) = init_gpt(
-            n_layers,
-            n_model_channels,
-            n_heads,
-            self.max_audio_tokens,
-            self.max_text_tokens,
-            self.max_prompt_tokens,
-            checkpointing,
-        )
-
-        # output layers
-        self.final_norm = nn.LayerNorm(n_model_channels)
-        self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
-        self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
-
-    def get_grad_norm_parameter_groups(self):
-        return {
-            "conditioning_encoder": list(self.conditioning_encoder.parameters()),
-            "gpt": list(self.gpt.parameters()),
-            "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
-        }
-
-    def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
-        self._inference_flag = True
-        seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
-        gpt_config = GPT2Config(
-            vocab_size=self.max_audio_tokens,
-            n_positions=seq_length,
-            n_ctx=seq_length,
-            n_embd=self.n_model_channels,
-            n_layer=self.n_layers,
-            n_head=self.n_heads,
-            gradient_checkpointing=False,
-            use_cache=True,
-        )
-        self.inference_model = GPT2InferenceModel(
-            gpt_config,
-            self.gpt,
-            self.audio_pos_embedding,
-            self.audio_embedding,
-            self.final_norm,
-            self.mel_head,
-            kv_cache=kv_cache,
-        )
-        self.gpt.wte = self.audio_embedding
-
-    def set_inputs_and_targets(self, input, start_token, stop_token):
-        inp = F.pad(input, (1, 0), value=start_token)
-        tar = F.pad(input, (0, 1), value=stop_token)
-        return inp, tar
-
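A worked example of the alignment above (illustrative values only; 1 and 2 stand in for the start and stop tokens):

    x = torch.tensor([[5, 6, 7]])
    inp = F.pad(x, (1, 0), value=1)   # -> [[1, 5, 6, 7]]  (start token prepended)
    tar = F.pad(x, (0, 1), value=2)   # -> [[5, 6, 7, 2]]  (stop token appended)
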
-    def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
-        # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
-        for b in range(len(audio_token_lens)):
-            actual_end = audio_token_lens[b]
-            if actual_end < audio_tokens.shape[-1]:
-                audio_tokens[b, actual_end:] = self.stop_mel_token
-        return audio_tokens
-
-    def get_logits(
-        self,
-        speech_conditioning_inputs,
-        first_inputs,
-        first_head,
-        second_inputs=None,
-        second_head=None,
-        prompt=None,
-        get_attns=False,
-        return_latent=False,
-        attn_mask_text=None,
-        attn_mask_mel=None,
-    ):
-        if prompt is not None and speech_conditioning_inputs is not None:
-            offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
-            if second_inputs is not None:
-                emb = torch.cat(
-                    [speech_conditioning_inputs, prompt, first_inputs, second_inputs],
-                    dim=1,
-                )
-            else:
-                emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
-        elif speech_conditioning_inputs is not None:
-            offset = speech_conditioning_inputs.shape[1]
-            if second_inputs is not None:
-                emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
-            else:
-                emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
-        elif prompt is not None:
-            offset = prompt.shape[1]
-            if second_inputs is not None:
-                emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
-            else:
-                emb = torch.cat([prompt, first_inputs], dim=1)
-
-        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
-        attn_mask = None
-        if attn_mask_text is not None:
-            attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
-            if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
-
-        gpt_out = self.gpt(
-            inputs_embeds=emb,
-            return_dict=True,
-            output_attentions=get_attns,
-            attention_mask=attn_mask,
-        )
-
-        if get_attns:
-            return gpt_out.attentions
-
-        enc = gpt_out.last_hidden_state[:, offset:]
-        enc = self.final_norm(enc)
-
-        if return_latent:
-            return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
-
-        first_logits = enc[:, : first_inputs.shape[1]]
-        first_logits = first_head(first_logits)
-        first_logits = first_logits.permute(0, 2, 1)
-        if second_inputs is not None:
-            second_logits = enc[:, -second_inputs.shape[1] :]
-            second_logits = second_head(second_logits)
-            second_logits = second_logits.permute(0, 2, 1)
-            return first_logits, second_logits
-        else:
-            return first_logits
-
-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
-    def get_prompts(self, prompt_codes):
-        # frame the prompt with start and stop prompt tokens
-        prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
-        prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
-        return prompt
-
-    def forward(
-        self,
-        text_inputs,
-        text_lengths,
-        audio_codes,
-        wav_lengths,
-        prompt_codes,
-        return_attentions=False,
-        return_latent=False,
-    ):
-        max_text_len = text_lengths.max()
-
-        # Due to the convolution in the DVAE, codes do not end with silence at the right place. Rather, it predicts
-        # some intermediate values like [..., 186, 45, 45, 83] where it should actually end with 186.
-        # We keep the last 3 codes to prevent an abrupt ending of the audio.
-        # TODO: This might need some testing.
-        mel_lengths = torch.ceil(wav_lengths / self.audio_len_compression).long() + 3
-
-        # If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
-        max_mel_len = mel_lengths.max()
-
-        if max_mel_len > audio_codes.shape[-1]:
-            audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
-
-        # Silence-aware lengths: skip the silence tokens (code 83) at the end of the mel codes.
-        for idx, l in enumerate(mel_lengths):
-            length = l.item()
-            while audio_codes[idx, length - 1] == 83:
-                length -= 1
-            mel_lengths[idx] = length
-
-        # Lovely assertions
-        assert (
-            max_mel_len <= audio_codes.shape[-1]
-        ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
-        assert (
-            max_text_len <= text_inputs.shape[-1]
-        ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
-
-        # Append stop token to text inputs
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
-
-        # Append silence token to mel codes
-        audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
-
-        # Pad mel codes with STOP_MEL_TOKEN
-        audio_codes = self.set_audio_tokens_padding(audio_codes, mel_lengths)
-
-        # Compute speech conditioning input
-        # NOTE: this encoder takes no speech_conditioning_input argument, so conditioning is disabled here.
-        speech_conditioning_input = None
-        conds = None
-        if speech_conditioning_input is not None:
-            if not return_latent:
-                # Compute speech conditioning input
-                speech_conditioning_input = (
-                    speech_conditioning_input.unsqueeze(1)
-                    if len(speech_conditioning_input.shape) == 3
-                    else speech_conditioning_input
-                )
-
-                conds = []
-                for j in range(speech_conditioning_input.shape[1]):
-                    conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-                conds = torch.stack(conds, dim=1)
-                if self.average_conditioning_embeddings:
-                    conds = conds.mean(dim=1).unsqueeze(1)
-            else:
-                # already computed
-                conds = speech_conditioning_input.unsqueeze(1)
-
-        # Build input and target tensors
-        # Prepend start token to inputs and append stop token to targets
-        text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
-
-        # Set attn_mask
-        attn_mask_text = None
-        attn_mask_mel = None
-        if not return_latent:
-            attn_mask_text = torch.ones(
-                text_inputs.shape[0],
-                text_inputs.shape[1],
-                dtype=torch.bool,
-                device=text_inputs.device,
-            )
-            attn_mask_mel = torch.ones(
-                audio_codes.shape[0],
-                audio_codes.shape[1],
-                dtype=torch.bool,
-                device=audio_codes.device,
-            )
-
-            for idx, l in enumerate(text_lengths):
-                attn_mask_text[idx, l + 1 :] = 0.0
-
-            for idx, l in enumerate(mel_lengths):
-                attn_mask_mel[idx, l + 1 :] = 0.0
-
-        # Compute text embeddings + positional embeddings
-        # print(" > text input latent:", text_inputs)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
-
-        # Compute mel embeddings + positional embeddings
-        audio_emb = self.audio_embedding(audio_codes) + self.audio_pos_embedding(audio_codes)
-
-        # Compute prompt embeddings + positional embeddings
-        prompt = self.get_prompts(prompt_codes)
-
-        # prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
-        prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
-
-        # dropout prompt embeddings
-        prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
-
-        # Get logits
-        sub = -4  # don't ask me why 😄
-        if self.training:
-            sub = -1
-        _, audio_logits = self.get_logits(
-            conds,
-            text_emb,
-            self.text_head,
-            audio_emb,
-            self.mel_head,
-            prompt=prompt_emb,
-            get_attns=return_attentions,
-            return_latent=return_latent,
-            attn_mask_text=attn_mask_text,
-            attn_mask_mel=attn_mask_mel,
-        )
-        return audio_logits[:, :sub]  # trim the trailing positions (see `sub` above)
-
-    def compute_embeddings(
-        self,
-        speech_conditioning_latent,
-        text_inputs,
-        input_tokens=None,
-        prompt_codes=None,
-        pad_input_text=False,
-    ):
-        """Compute all the embeddings needed for inference."""
-        if pad_input_text and text_inputs.shape[1] < 250:
-            text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
-        else:
-            text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
-        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
-
-        emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
-
-        print(" > Text inputs:", text_inputs)
-        if prompt_codes is not None:
-            prompt_codes = self.get_prompts(prompt_codes)
-            # prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
-            prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
-
-            print(" > Prompt inputs:", prompt_codes)
-            print(" > Prompt inputs shape:", prompt_codes.shape)
-            emb = torch.cat([prompt_emb, emb], dim=1)
-
-        if speech_conditioning_latent is not None:
-            conds = speech_conditioning_latent.unsqueeze(1)
-            emb = torch.cat([conds, emb], dim=1)
-
-        self.inference_model.store_prefix_emb(emb)
-
-        fake_inputs = torch.full(
-            (
-                emb.shape[0],
-                emb.shape[1] + 1,  # +1 for the start_mel_token
-            ),
-            fill_value=1,
-            dtype=torch.long,
-            device=text_inputs.device,
-        )
-        fake_inputs[:, -1] = self.start_mel_token
-
-        if input_tokens is not None:
-            fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
-        return fake_inputs
-
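To make the prefix bookkeeping above concrete, a minimal sketch with hypothetical sizes (8192 stands in for start_mel_token; the prefix length is whatever was passed to store_prefix_emb):

    prefix_len = 37                                        # hypothetical: prompt + text embedding length
    fake_inputs = torch.full((1, prefix_len + 1), 1, dtype=torch.long)
    fake_inputs[:, -1] = 8192                              # start_mel_token in the last column
    # GPT2InferenceModel.forward later skips the first prefix_len ids and substitutes the cached
    # prefix embeddings, so generation effectively starts right after the conditioning prefix.
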
-    def inference(
-        self,
-        text_inputs,
-        input_tokens=None,
-        prompt_codes=None,
-        pad_input_text=False,
-        **hf_generate_kwargs,
-    ):
-        if pad_input_text and text_inputs.shape[1] < 250:
-            text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
-        else:
-            text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
-        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
-
-        emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
-
-        if prompt_codes is not None:
-            prompt_codes = self.get_prompts(prompt_codes)
-            prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
-            emb = torch.cat([prompt_emb, emb], dim=1)
-
-        self.inference_model.store_prefix_emb(emb)
-
-        fake_inputs = torch.full(
-            (
-                emb.shape[0],
-                emb.shape[1] + 1,  # +1 for the start_mel_token
-            ),
-            fill_value=1,
-            dtype=torch.long,
-            device=text_inputs.device,
-        )
-        fake_inputs[:, -1] = self.start_mel_token
-
-        if input_tokens is not None:
-            fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
-
-        gen = self.inference_model.generate(
-            fake_inputs,
-            bos_token_id=self.start_mel_token,
-            pad_token_id=self.stop_mel_token,
-            eos_token_id=self.stop_mel_token,
-            max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
-            **hf_generate_kwargs,
-        )
-        if "return_dict_in_generate" in hf_generate_kwargs:
-            return gen.sequences[:, fake_inputs.shape[1] :], gen
-        return gen[:, fake_inputs.shape[1] :]
diff --git a/TTS/tts/layers/xtts/gpt_encoder_old.py b/TTS/tts/layers/xtts/gpt_encoder_old.py
deleted file mode 100644
index 46739aa2..00000000
--- a/TTS/tts/layers/xtts/gpt_encoder_old.py
+++ /dev/null
@@ -1,1057 +0,0 @@
-import functools
-import math
-import random
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-try:
-    import deepspeed
-    from deepspeed.ops.transformer.inference import DeepSpeedTransformerInferenceKernel
-except ImportError:
-    pass
-
-import dlas.codes.torch_intermediary as ml
-from dlas.codes.models.arch_util import AttentionBlock
-from dlas.codes.trainer.networks import register_model
-from dlas.codes.utils.transformers.stream_generator import init_stream_support
-from dlas.codes.utils.util import opt_get
-from transformers import GPT2Config, GPT2PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-
-init_stream_support()
-
-
-def null_position_embeddings(range, dim):
-    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
-
-
-class ResBlock(nn.Module):
-    """
-    Basic residual convolutional block that uses GroupNorm.
-    """
-
-    def __init__(self, chan):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
-            nn.GroupNorm(chan // 8, chan),
-            nn.ReLU(),
-            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
-            nn.GroupNorm(chan // 8, chan),
-        )
-
-    def forward(self, x):
-        return F.relu(self.net(x) + x)
-
-
-class GPT2InferenceModel(GPT2PreTrainedModel):
-    """Override GPT2LMHeadModel to allow for prefix conditioning."""
-
-    def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
-        super().__init__(config)
-        self.transformer = gpt
-        self.pos_embedding = pos_emb
-        self.embeddings = embeddings
-        self.final_norm = norm
-        self.lm_head = nn.Sequential(norm, linear)
-        self.kv_cache = kv_cache
-
-    def store_prefix_emb(self, prefix_emb):
-        self.cached_prefix_emb = prefix_emb
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        token_type_ids = kwargs.get("token_type_ids", None)  # usually None
-        if not self.kv_cache:
-            past_key_values = None
-
-        # only last token for inputs_ids if past is defined in kwargs
-        if past_key_values is not None:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-            if token_type_ids is not None:
-                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values is not None:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
-
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        assert self.cached_prefix_emb is not None
-        assert inputs_embeds is None  # Not supported by this inference model.
-        assert labels is None  # Training not supported by this inference model.
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
-
-        # Create embedding
-        prefix_len = self.cached_prefix_emb.shape[1]
-        if input_ids.shape[1] != 1:
-            gen_inputs = input_ids[:, prefix_len:]
-            gen_emb = self.embeddings(gen_inputs)
-            gen_emb = gen_emb + self.pos_embedding(gen_emb)
-            if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
-                prefix_emb = self.cached_prefix_emb.repeat_interleave(
-                    gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
-                )
-            else:
-                prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
-            emb = torch.cat([prefix_emb, gen_emb], dim=1)
-        else:
-            emb = self.embeddings(input_ids)
-            emb = emb + self.pos_embedding.get_fixed_embedding(
-                attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
-            )
-        transformer_outputs = self.transformer(
-            inputs_embeds=emb,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-        lm_logits = self.lm_head(hidden_states)
-
-        if not return_dict:
-            return (lm_logits,) + transformer_outputs[1:]
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=None,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-            cross_attentions=transformer_outputs.cross_attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """
-        This function is used to re-order the :obj:`past_key_values` cache if
-        :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
-        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
-
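As a rough illustration of the cache reordering above (hypothetical shapes, not part of the original file):

    past_state = torch.arange(4.0).view(2, 1, 2, 1)    # (batch=2, heads=1, seq=2, head_dim=1)
    beam_idx = torch.tensor([1, 0])
    reordered = past_state.index_select(0, beam_idx)    # row 0 now carries what was beam 1's cache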
-
-class ConditioningEncoder(nn.Module):
-    def __init__(
-        self,
-        spec_dim,
-        embedding_dim,
-        attn_blocks=6,
-        num_attn_heads=4,
-        do_checkpointing=False,
-        mean=False,
-    ):
-        super().__init__()
-        attn = []
-        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
-        for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
-        self.attn = nn.Sequential(*attn)
-        self.dim = embedding_dim
-        self.do_checkpointing = do_checkpointing
-        self.mean = mean
-
-    def forward(self, x):
-        h = self.init(x)
-        h = self.attn(h)
-        if self.mean:
-            return h.mean(dim=2)
-        else:
-            return h[:, :, 0]
-
-
-class LearnedPositionEmbeddings(nn.Module):
-    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
-        super().__init__()
-        # nn.Embedding
-        self.emb = torch.nn.Embedding(seq_len, model_dim)
-        # Initializing this way is standard for GPT-2
-        self.emb.weight.data.normal_(mean=0.0, std=init)
-        self.relative = relative
-        self.seq_len = seq_len
-
-    def forward(self, x):
-        sl = x.shape[1]
-        if self.relative:
-            start = random.randint(sl, self.seq_len) - sl
-            return self.emb(torch.arange(start, start + sl, device=x.device))
-        else:
-            return self.emb(torch.arange(0, sl, device=x.device))
-
-    def get_fixed_embedding(self, ind, dev):
-        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
-
-
-def build_hf_gpt_transformer(
-    layers,
-    model_dim,
-    heads,
-    max_mel_seq_len,
-    max_text_seq_len,
-    max_prompt_len,
-    checkpointing,
-):
-    """
-    GPT-2 implemented by the HuggingFace library.
-    """
-    from transformers import GPT2Config, GPT2Model
-
-    gpt_config = GPT2Config(
-        vocab_size=256,  # Unused.
-        n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
-        n_embd=model_dim,
-        n_layer=layers,
-        n_head=heads,
-        gradient_checkpointing=checkpointing,
-        use_cache=not checkpointing,
-    )
-    gpt = GPT2Model(gpt_config)
-    # Override the built in positional embeddings
-    del gpt.wpe
-    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
-    # Built-in token embeddings are unused.
-    del gpt.wte
-
-    # def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-    #     attn_output = torch.nn.functional.scaled_dot_product_attention(
-    #         query, key, value, dropout_p=self.attn_dropout.p, is_causal=True
-    #     )
-    #     return attn_output, None
-
-    # for i in range(len(gpt.h)):
-    #     gpt.h[i].attn._attn = types.MethodType(
-    #         _attn, gpt.h[i].attn
-    #     )
-
-    mel_pos_emb = (
-        LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
-        if max_mel_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    text_pos_emb = (
-        LearnedPositionEmbeddings(max_text_seq_len, model_dim)
-        if max_text_seq_len != -1
-        else functools.partial(null_position_embeddings, dim=model_dim)
-    )
-    # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
-    return gpt, mel_pos_emb, text_pos_emb, None, None
-
-
-class MelEncoder(nn.Module):
-    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
-        super().__init__()
-        self.channels = channels
-        self.encoder = nn.Sequential(
-            nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
-            nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
-            nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
-            nn.GroupNorm(channels // 16, channels // 2),
-            nn.ReLU(),
-            nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
-            nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
-            nn.GroupNorm(channels // 8, channels),
-            nn.ReLU(),
-            nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
-        )
-        self.reduction = 4
-
-    def forward(self, x):
-        for e in self.encoder:
-            x = e(x)
-        return x.permute(0, 2, 1)
-
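Shape-wise, the two stride-2 convolutions reduce the time axis by a factor of 4 (hence self.reduction = 4). A quick sketch with hypothetical sizes:

    enc = MelEncoder(channels=128)
    mel = torch.randn(2, 80, 100)   # (batch, mel_channels, frames)
    out = enc(mel)                  # -> (2, 25, 128): frames reduced 4x, channels moved last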
-
-class UnifiedVoice(nn.Module):
-    def __init__(
-        self,
-        start_text_token=261,
-        stop_text_token=0,
-        layers=8,
-        model_dim=512,
-        heads=8,
-        max_text_tokens=120,
-        max_mel_tokens=250,
-        max_prompt_tokens=70,
-        max_conditioning_inputs=1,
-        mel_length_compression=1024,
-        number_text_tokens=256,
-        number_mel_codes=8194,
-        start_mel_token=8192,
-        stop_mel_token=8193,
-        train_solo_embeddings=False,
-        use_mel_codes_as_input=True,
-        checkpointing=True,
-        average_conditioning_embeddings=False,
-        freeze_everything_but_position_embeddings=False,
-        freeze_conditioning_encoder=False,
-        tortoise_compat=True,
-        label_smoothing=0.0,
-    ):
-        """
-        Args:
-            layers: Number of layers in transformer stack.
-            model_dim: Operating dimensions of the transformer
-            heads: Number of transformer heads. model_dim must be divisible by heads. Recommend model_dim // 64.
-            max_text_tokens: Maximum number of text tokens that will be encountered by the model.
-            max_mel_tokens: Maximum number of MEL tokens that will be encountered by the model.
-            max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
-            mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
-            number_text_tokens:
-            start_text_token:
-            stop_text_token:
-            number_mel_codes:
-            start_mel_token:
-            stop_mel_token:
-            train_solo_embeddings:
-            use_mel_codes_as_input:
-            checkpointing:
-            average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
-        """
-        super().__init__()
-
-        self.label_smoothing = label_smoothing
-        self.number_text_tokens = number_text_tokens
-        self.start_text_token = start_text_token
-        self.stop_text_token = stop_text_token
-        self.number_mel_codes = number_mel_codes
-        self.start_mel_token = start_mel_token
-        self.stop_mel_token = stop_mel_token
-        self.start_prompt_token = start_mel_token
-        self.stop_prompt_token = stop_mel_token
-        self.layers = layers
-        self.heads = heads
-        self.model_dim = model_dim
-        self.max_conditioning_inputs = max_conditioning_inputs
-        self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
-        self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
-        self.max_prompt_tokens = max_prompt_tokens
-        self.mel_length_compression = mel_length_compression
-        # self.conditioning_encoder = ConditioningEncoder(
-        #     80, model_dim, num_attn_heads=heads
-        # )
-        self.average_conditioning_embeddings = average_conditioning_embeddings
-        self.tortoise_compat = tortoise_compat  # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
-        # nn.Embedding
-        self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
-        if use_mel_codes_as_input:
-            # nn.Embedding
-            self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
-        else:
-            self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
-        (
-            self.gpt,
-            self.mel_pos_embedding,
-            self.text_pos_embedding,
-            self.mel_layer_pos_embedding,
-            self.text_layer_pos_embedding,
-        ) = build_hf_gpt_transformer(
-            layers,
-            model_dim,
-            heads,
-            self.max_mel_tokens,
-            self.max_text_tokens,
-            self.max_prompt_tokens,
-            checkpointing,
-        )
-        if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
-            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
-        else:
-            self.mel_solo_embedding = 0
-            self.text_solo_embedding = 0
-
-        self.final_norm = nn.LayerNorm(model_dim)
-        self.text_head = ml.Linear(model_dim, self.number_text_tokens)
-        self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
-
-        # Initialize the embeddings per the GPT-2 scheme
-        embeddings = [self.text_embedding]
-        if use_mel_codes_as_input:
-            embeddings.append(self.mel_embedding)
-        for module in embeddings:
-            module.weight.data.normal_(mean=0.0, std=0.02)
-
-        if freeze_conditioning_encoder:
-            print(" > Freezing conditioning encoder.")
-            for p in self.conditioning_encoder.parameters():
-                p.requires_grad = False
-                p.DO_NOT_TRAIN = True
-
-        if freeze_everything_but_position_embeddings:
-            for p in self.parameters():
-                p.requires_grad = False
-                p.DO_NOT_TRAIN = True
-            for m in [self.mel_pos_embedding, self.text_pos_embedding]:
-                for p in m.parameters():
-                    del p.DO_NOT_TRAIN
-                    p.requires_grad = True
-
-    def get_grad_norm_parameter_groups(self):
-        return {
-            "conditioning_encoder": list(self.conditioning_encoder.parameters()),
-            "gpt": list(self.gpt.parameters()),
-            "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
-        }
-
-    def post_init_gpt2_config(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
-        seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
-        gpt_config = GPT2Config(
-            vocab_size=self.max_mel_tokens,
-            n_positions=seq_length,
-            n_ctx=seq_length,
-            n_embd=self.model_dim,
-            n_layer=self.layers,
-            n_head=self.heads,
-            gradient_checkpointing=False,
-            use_cache=True,
-        )
-        self.inference_model = GPT2InferenceModel(
-            gpt_config,
-            self.gpt,
-            self.mel_pos_embedding,
-            self.mel_embedding,
-            self.final_norm,
-            self.mel_head,
-            kv_cache=kv_cache,
-        )
-        # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
-        self.gpt.wte = self.mel_embedding
-
-        if use_deepspeed:
-            # init deepspeed inference engine
-            if use_deepspeed_f16:
-                self.gpt.wte = self.mel_embedding.half()
-                self.gpt.wpe = self.mel_pos_embedding.half()
-            self.ds_engine = deepspeed.init_inference(
-                model=self.inference_model.half(),  # Transformers model
-                mp_size=1,  # number of GPUs
-                dtype=torch.float16 if use_deepspeed_f16 else torch.float32,  # desired data type of the output
-                replace_method="auto",  # let DeepSpeed automatically identify the layers to replace
-                replace_with_kernel_inject=True,  # replace the model with the kernel injector
-            )
-            self.inference_model = self.ds_engine.module.eval()
-
-    def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
-        inp = F.pad(input, (1, 0), value=start_token)
-        tar = F.pad(input, (0, 1), value=stop_token)
-        return inp, tar
-
-    def set_mel_padding(self, mel_input_tokens, mel_lengths):
-        """
-        Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
-        that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
-        preformatting to create a working TTS model.
-        """
-        # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
-        for b in range(len(mel_lengths)):
-            actual_end = mel_lengths[b]
-            if actual_end < mel_input_tokens.shape[-1]:
-                mel_input_tokens[b, actual_end:] = self.stop_mel_token
-        return mel_input_tokens
-
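A worked example of the padding rewrite above (illustrative values; 8193 stands in for stop_mel_token):

    mel_tokens = torch.tensor([[41, 57, 0, 0]])   # zero-padded mel codes
    mel_lengths = torch.tensor([2])               # actual length of this clip
    mel_tokens[0, mel_lengths[0]:] = 8193         # -> [[41, 57, 8193, 8193]]
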
-    def get_logits(
-        self,
-        speech_conditioning_inputs,
-        first_inputs,
-        first_head,
-        second_inputs=None,
-        second_head=None,
-        prompt=None,
-        get_attns=False,
-        return_latent=False,
-        attn_mask_text=None,
-        attn_mask_mel=None,
-    ):
-        if prompt is not None and speech_conditioning_inputs is not None:
-            offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
-            if second_inputs is not None:
-                emb = torch.cat(
-                    [speech_conditioning_inputs, prompt, first_inputs, second_inputs],
-                    dim=1,
-                )
-            else:
-                emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
-        elif speech_conditioning_inputs is not None:
-            offset = speech_conditioning_inputs.shape[1]
-            if second_inputs is not None:
-                emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
-            else:
-                emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
-        elif prompt is not None:
-            offset = prompt.shape[1]
-            if second_inputs is not None:
-                emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
-            else:
-                emb = torch.cat([prompt, first_inputs], dim=1)
-
-        # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
-        attn_mask = None
-        if attn_mask_text is not None:
-            attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
-            if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
-
-        gpt_out = self.gpt(
-            inputs_embeds=emb,
-            return_dict=True,
-            output_attentions=get_attns,
-            attention_mask=attn_mask,
-        )
-
-        if get_attns:
-            return gpt_out.attentions
-
-        enc = gpt_out.last_hidden_state[:, offset:]
-        enc = self.final_norm(enc)
-
-        if return_latent:
-            return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
-
-        first_logits = enc[:, : first_inputs.shape[1]]
-        first_logits = first_head(first_logits)
-        first_logits = first_logits.permute(0, 2, 1)
-        if second_inputs is not None:
-            second_logits = enc[:, -second_inputs.shape[1] :]
-            second_logits = second_head(second_logits)
-            second_logits = second_logits.permute(0, 2, 1)
-            return first_logits, second_logits
-        else:
-            return first_logits
-
-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
-    def get_prompts(self, prompt_codes):
-        """
-        Create a prompt from the mel codes. This is used to condition the model on the mel codes.
-        Pad the prompt with start and stop mel tokens.
-        """
-        prompt = prompt_codes
-        if self.training:
-            prompt_len = random.randint(1, 9)  # in secs
-            prompt_len = prompt_len * 24  # in frames
-
-            if prompt_codes.shape[1] < prompt_len:
-                prompt_len = prompt_codes.shape[-1]
-                start = 0
-            else:
-                start = random.randint(0, prompt_codes.shape[-1] - prompt_len)
-
-            prompt = prompt_codes[:, start : start + prompt_len]
-
-        # add start and stop tokens
-        prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token)
-        prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
-        return prompt
-
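A worked example of the prompt construction above (illustrative values; 8192 and 8193 stand in for the start/stop prompt tokens). At train time a random 1-9 second window (24 codes per second) is cropped first, then framed with start and stop tokens:

    codes = torch.randint(0, 8192, (1, 300))      # hypothetical mel codes
    window = codes[:, 10:10 + 96]                 # e.g. a 4 second crop (4 * 24 codes)
    prompt = F.pad(window, (1, 0), value=8192)    # prepend start_prompt_token
    prompt = F.pad(prompt, (0, 1), value=8193)    # append stop_prompt_token -> shape (1, 98)
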
-    # def get_prompts(self, prompt_codes):
-    #     """
-    #     Create a prompt from the mel codes. This is used to condition the model on the mel codes.
-    #     Pad the prompt with start and stop mel tokens.
-    #     """
-    #     prompt = prompt_codes
-    #     if self.training:
-    #         max_prompt_len = 9 * 24
-    #         if prompt_codes.shape[1] < max_prompt_len:
-    #             prompt = prompt_codes
-    #         else:
-    #             start = random.randint(0, prompt_codes.shape[1] - max_prompt_len)
-    #             prompt = prompt_codes[:, start : start + max_prompt_len]
-
-    #     # add start and stop tokens
-    #     prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token)
-    #     prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
-    #     return prompt
-
-    def forward(
-        self,
-        speech_conditioning_input,
-        text_inputs,
-        text_lengths,
-        mel_codes,
-        wav_lengths,
-        prompt_codes,
-        loss_weights=None,
-        text_first=True,
-        return_attentions=False,
-        return_latent=False,
-    ):
-        """
-        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
-        (actuated by `text_first`).
-
-        speech_conditioning_input: MEL float tensor, (b,80,s)
-        text_inputs: long tensor, (b,t)
-        text_lengths: long tensor, (b,)
-        mel_codes: long tensor, (b,m)
-        wav_lengths: long tensor, (b,)
-
-        If return_attentions is specified, only the attention maps are returned.
-        If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
-        """
-
-        # ❗ FIXIT
-        speech_conditioning_input = None
-        if self.max_conditioning_inputs == 0:
-            assert (
-                speech_conditioning_input is None
-            ), " ❗ speech_conditioning_input is not None, but max_conditioning_inputs == 0"
-
-        max_text_len = text_lengths.max()
-        # Due to the convolution in the DVAE, codes do not end with silence at the right place. Rather, it predicts
-        # some intermediate values like [..., 186, 45, 45, 83] where it should actually end with 186.
-        # We keep the last 3 codes to prevent an abrupt ending of the audio.
-        # TODO: This might need some testing.
-        mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
-
-        # If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
-        max_mel_len = mel_lengths.max()
-
-        if max_mel_len > mel_codes.shape[-1]:
-            mel_codes = F.pad(mel_codes, (0, max_mel_len - mel_codes.shape[-1]))
-
-        # mel_lengths[mel_lengths >= max_mel_len] = max_mel_len
-
-        # Silence-aware lengths: skip the silence tokens (code 83) at the end of the mel codes.
-        for idx, l in enumerate(mel_lengths):
-            length = l.item()
-            while mel_codes[idx, length - 1] == 83:
-                length -= 1
-            mel_lengths[idx] = length
-
-        # Lovely assertions
-        assert (
-            max_mel_len <= mel_codes.shape[-1]
-        ), f" ❗ max_mel_len ({max_mel_len}) > mel_codes.shape[-1] ({mel_codes.shape[-1]})"
-        assert (
-            max_text_len <= text_inputs.shape[-1]
-        ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
-
-        # Append stop token to text inputs
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
-
-        # Append silence token to mel codes
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
-
-        # Pad mel codes with STOP_MEL_TOKEN
-        mel_codes = self.set_mel_padding(mel_codes, mel_lengths)
-
-        # Compute speech conditioning input
-        conds = None
-        if speech_conditioning_input is not None:
-            if not return_latent:
-                # Compute speech conditioning input
-                speech_conditioning_input = (
-                    speech_conditioning_input.unsqueeze(1)
-                    if len(speech_conditioning_input.shape) == 3
-                    else speech_conditioning_input
-                )
-
-                conds = []
-                for j in range(speech_conditioning_input.shape[1]):
-                    conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-                conds = torch.stack(conds, dim=1)
-                if self.average_conditioning_embeddings:
-                    conds = conds.mean(dim=1).unsqueeze(1)
-            else:
-                # already computed
-                conds = speech_conditioning_input.unsqueeze(1)
-
-        # Build input and target tensors
-        # Prepend start token to inputs and append stop token to targets
-        text_inputs, text_targets = self.build_aligned_inputs_and_targets(
-            text_inputs, self.start_text_token, self.stop_text_token
-        )
-        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
-            mel_codes, self.start_mel_token, self.stop_mel_token
-        )
-
-        # Set attn_mask
-        attn_mask_text = None
-        attn_mask_mel = None
-        if not return_latent:
-            attn_mask_text = torch.ones(
-                text_inputs.shape[0],
-                text_inputs.shape[1],
-                dtype=torch.bool,
-                device=text_inputs.device,
-            )
-            attn_mask_mel = torch.ones(
-                mel_codes.shape[0],
-                mel_codes.shape[1],
-                dtype=torch.bool,
-                device=mel_codes.device,
-            )
-
-            for idx, l in enumerate(text_lengths):
-                attn_mask_text[idx, l + 1 :] = 0.0
-
-            for idx, l in enumerate(mel_lengths):
-                attn_mask_mel[idx, l + 1 :] = 0.0
-
-        # Compute text embeddings + positional embeddings
-        # print(" > text input latent:", text_inputs)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
-
-        # Compute mel embeddings + positional embeddings
-        mel_emb = self.mel_embedding(mel_codes) + self.mel_pos_embedding(mel_codes)
-
-        # Compute prompt embeddings + positional embeddings
-        prompt = self.get_prompts(prompt_codes)
-
-        prompt_emb = self.mel_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
-
-        # Get logits
-        sub = -4  # don't ask me why 😄
-        if self.training:
-            sub = -1
-        text_logits, mel_logits = self.get_logits(
-            conds,
-            text_emb,
-            self.text_head,
-            mel_emb,
-            self.mel_head,
-            prompt=prompt_emb,
-            get_attns=return_attentions,
-            return_latent=return_latent,
-            attn_mask_text=attn_mask_text,
-            attn_mask_mel=attn_mask_mel,
-        )
-        if return_latent:
-            return mel_logits[:, :sub]  # trim the trailing positions (see `sub` above)
-
-        if return_attentions:
-            return mel_logits
-
-        # Set paddings to -1 to ignore them in loss
-        for idx, l in enumerate(text_lengths):
-            text_targets[idx, l + 1 :] = -1
-
-        for idx, l in enumerate(mel_lengths):
-            mel_targets[idx, l + 1 :] = -1
-
-        # check that the stop token is present in every row of mel_targets
-        assert (
-            (mel_targets == self.stop_mel_token).any(dim=-1).all()
-        ), f" ❗ mel_targets does not contain stop token ({self.stop_mel_token}) in every row."
-
-        # Compute losses
-        loss_text = F.cross_entropy(
-            text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
-        )
-        loss_mel = F.cross_entropy(
-            mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
-        )
-
-        # if loss_weights is not None:
-        #     loss_text = loss_text * loss_weights[:, None]
-        #     loss_mel = loss_mel * loss_weights[:, None]
-        return loss_text.mean(), loss_mel.mean(), mel_logits
-
-    def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
-        """
-        Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
-        model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
-        """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
-
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
-
-        text_inputs, text_targets = self.build_aligned_inputs_and_targets(
-            text_inputs, self.start_text_token, self.stop_text_token
-        )
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
-        text_logits = self.get_logits(conds, text_emb, self.text_head)
-        loss_text = F.cross_entropy(text_logits, text_targets.long())
-        return loss_text.mean()
-
-    def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
-        """
-        Performs autoregressive modeling on only speech data.
-        """
-        assert self.max_mel_tokens >= mel_codes.shape[1], f"{mel_codes.shape[1]}"
-
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
-        mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, : max_mel_len * 4]
-
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        if self.average_conditioning_embeddings:
-            conds = conds.mean(dim=1).unsqueeze(1)
-
-        mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
-            mel_codes, self.start_mel_token, self.stop_mel_token
-        )
-        if raw_mels is not None:
-            mel_inp = F.pad(raw_mels, (0, 4))
-        else:
-            mel_inp = mel_codes
-        mel_emb = self.mel_embedding(mel_inp)
-        mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
-        mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
-        loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
-        return loss_mel.mean()
-
-    def get_generator(self, fake_inputs, **hf_generate_kwargs):
-        return self.inference_model.generate_stream(
-            fake_inputs,
-            bos_token_id=self.start_mel_token,
-            pad_token_id=self.stop_mel_token,
-            eos_token_id=self.stop_mel_token,
-            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
-            do_stream=True,
-            **hf_generate_kwargs,
-        )
-
-    def compute_embeddings(
-        self,
-        speech_conditioning_latent,
-        text_inputs,
-        input_tokens=None,
-        prompt_codes=None,
-        pad_input_text=False,
-    ):
-        if pad_input_text and text_inputs.shape[1] < 250:
-            text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
-        else:
-            text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
-        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
-
-        emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
-
-        print(" > Text inputs:", text_inputs)
-        if prompt_codes is not None:
-            prompt_codes = self.get_prompts(prompt_codes)
-            prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
-            print(" > Prompt inputs:", prompt_codes)
-            print(" > Prompt inputs shape:", prompt_codes.shape)
-            emb = torch.cat([prompt_emb, emb], dim=1)
-
-        if speech_conditioning_latent is not None:
-            conds = speech_conditioning_latent.unsqueeze(1)
-            emb = torch.cat([conds, emb], dim=1)
-
-        self.inference_model.store_prefix_emb(emb)
-
-        fake_inputs = torch.full(
-            (
-                emb.shape[0],
-                emb.shape[1] + 1,  # +1 for the start_mel_token
-            ),
-            fill_value=1,
-            dtype=torch.long,
-            device=text_inputs.device,
-        )
-        fake_inputs[:, -1] = self.start_mel_token
-
-        if input_tokens is not None:
-            fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
-        return fake_inputs
-
-    def inference_speech(
-        self,
-        speech_conditioning_latent,
-        text_inputs,
-        input_tokens=None,
-        prompt_codes=None,
-        pad_input_text=False,
-        **hf_generate_kwargs,
-    ):
-        if pad_input_text and text_inputs.shape[1] < 250:
-            text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
-        else:
-            text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
-        text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
-
-        emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
-
-        print(" > Text inputs:", text_inputs)
-        if prompt_codes is not None:
-            prompt_codes = self.get_prompts(prompt_codes)
-            prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
-            print(" > Prompt inputs:", prompt_codes)
-            print(" > Prompt inputs shape:", prompt_codes.shape)
-            emb = torch.cat([prompt_emb, emb], dim=1)
-
-        if speech_conditioning_latent is not None:
-            conds = speech_conditioning_latent.unsqueeze(1)
-            emb = torch.cat([conds, emb], dim=1)
-
-        self.inference_model.store_prefix_emb(emb)
-
-        fake_inputs = torch.full(
-            (
-                emb.shape[0],
-                emb.shape[1] + 1,  # +1 for the start_mel_token
-            ),
-            fill_value=1,
-            dtype=torch.long,
-            device=text_inputs.device,
-        )
-        fake_inputs[:, -1] = self.start_mel_token
-
-        if input_tokens is not None:
-            fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
-
-        gen = self.inference_model.generate(
-            fake_inputs,
-            bos_token_id=self.start_mel_token,
-            pad_token_id=self.stop_mel_token,
-            eos_token_id=self.stop_mel_token,
-            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
-            **hf_generate_kwargs,
-        )
-        if "return_dict_in_generate" in hf_generate_kwargs:
-            return gen.sequences[:, fake_inputs.shape[1] :], gen
-        return gen[:, fake_inputs.shape[1] :]
-
-    # Turns the (utterly insane) output of HF.generate() into a far more sane output:
-    # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
-    def make_hf_generate_attentions_sane(self, attentions):
-        layers = [[] for _ in range(len(attentions[0]))]
-        full_attention_size = attentions[-1][0].shape[-1]
-        for i, gen in enumerate(attentions):
-            for j, lyr in enumerate(gen):
-                layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
-        catted = []
-        for lyr in layers:
-            catted.append(torch.cat(lyr, dim=2))
-        return catted
-
-    def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
-        """
-        This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
-        """
-        text_padding = num_conds + 2
-        num_text = text.shape[-1]
-        num_context = num_text + text_padding
-        assert num_context + 1 == attentions[0][0].shape[-1]
-        attentions = self.make_hf_generate_attentions_sane(attentions)
-        results = [torch.empty_like(codes) for _ in range(len(attentions))]
-        for l, layer in enumerate(attentions):
-            dec_context = layer[:, :, num_context:, :]
-            # Mask out everything that isn't text (including the start token, which gets a LOT of attention)
-            dec_context[:, :, :, : text_padding + 1] = 0
-            dec_context[:, :, :, num_context:] = 0
-            for h in range(dec_context.shape[1]):
-                dec_context_indices = torch.argmax(dec_context[0, h], dim=-1)
-                print(f"layer_{l};head_{h}: " + str(dec_context_indices))
-        for t, att_tok in enumerate(attentions):
-            combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
-            for lyr in att_tok:
-                token_to_text_attentions = lyr[:, :, -1, text_padding : (text_padding + num_text)].sum(dim=1)
-                combined_attention_weights = combined_attention_weights + token_to_text_attentions
-                break
-            most_attended_text_token = combined_attention_weights.argmax(dim=-1)
-            results[:, t] = most_attended_text_token
-        eos_token_mask = codes != self.stop_mel_token
-        return results * eos_token_mask
-
-
-@register_model
-def register_unified_voice_prompt(opt_net, opt):
-    return UnifiedVoice(**opt_get(opt_net, ["kwargs"], {}))
-
-
-if __name__ == "__main__":
-    gpt = UnifiedVoice(
-        model_dim=256,
-        heads=4,
-        train_solo_embeddings=True,
-        use_mel_codes_as_input=True,
-        max_conditioning_inputs=4,
-        freeze_everything_but_position_embeddings=True,
-    )
-    l = gpt(
-        torch.randn(2, 3, 80, 800),
-        torch.randint(high=256, size=(2, 120)),
-        torch.tensor([32, 120]),
-        torch.randint(high=8192, size=(2, 250)),
-        torch.tensor([250 * 256, 195 * 256]),
-    )
-    # gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index b7945d6e..9c628276 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -3,13 +3,19 @@ import glob
 import os
 import shutil
 
+import torch
+
 from tests import get_tests_data_path, get_tests_output_path, run_cli
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.generic_utils import get_user_data_dir
 from TTS.utils.manage import ModelManager
 
-MODELS_WITH_SEP_TESTS = ["bark", "xtts"]
+MODELS_WITH_SEP_TESTS = [
+    "tts_models/multilingual/multi-dataset/bark",
+    "tts_models/en/multi-dataset/tortoise-v2",
+    "tts_models/multilingual/multi-dataset/xtts_v1",
+]
 
 
 def run_models(offset=0, step=1):
@@ -17,7 +23,8 @@ def run_models(offset=0, step=1):
     print(" > Run synthesizer with all the models.")
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
-    model_names = [name for name in manager.list_models() if name in MODELS_WITH_SEP_TESTS]
+    model_names = [name for name in manager.list_models() if name not in MODELS_WITH_SEP_TESTS]
+    print("Model names:", model_names)
     for model_name in model_names[offset::step]:
         print(f"\n > Run - {model_name}")
         model_path, _, _ = manager.download_model(model_name)
@@ -67,23 +74,55 @@ def run_models(offset=0, step=1):
 
 
 def test_xtts():
+    """XTTS is too big to run on GitHub Actions. We need to test it locally."""
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
-    run_cli(
-        "yes | "
-        f"tts --model_name  tts_models/multilingual/multi-dataset/xtts_v1 "
-        f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
-        f'--speaker_wav "{speaker_wav}" --language_idx "en"'
-    )
+    use_gpu = torch.cuda.is_available()
+    if use_gpu:
+        run_cli(
+            "yes | "
+            f"tts --model_name  tts_models/multilingual/multi-dataset/xtts_v1 "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
+            f'--speaker_wav "{speaker_wav}" --language_idx "en"'
+        )
+    else:
+        run_cli(
+            "yes | "
+            f"tts --model_name  tts_models/multilingual/multi-dataset/xtts_v1 "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
+            f'--speaker_wav "{speaker_wav}" --language_idx "en"'
+        )
+
+
+def test_tortoise():
+    output_path = os.path.join(get_tests_output_path(), "output.wav")
+    use_gpu = torch.cuda.is_available()
+    if use_gpu:
+        run_cli(
+            f" tts --model_name  tts_models/en/multi-dataset/tortoise-v2 "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
+        )
+    else:
+        run_cli(
+            f" tts --model_name  tts_models/en/multi-dataset/tortoise-v2 "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+        )
 
 
 def test_bark():
     """Bark is too big to run on github actions. We need to test it locally"""
     output_path = os.path.join(get_tests_output_path(), "output.wav")
-    run_cli(
-        f" tts --model_name  tts_models/multilingual/multi-dataset/bark "
-        f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
-    )
+    use_gpu = torch.cuda.is_available()
+    if use_gpu:
+        run_cli(
+            f" tts --model_name  tts_models/multilingual/multi-dataset/bark "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
+        )
+    else:
+        run_cli(
+            f" tts --model_name  tts_models/multilingual/multi-dataset/bark "
+            f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+        )
 
 
 def test_voice_conversion():

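The GPU/CPU branches added to test_xtts, test_tortoise and test_bark above differ only in the --use_cuda True flag. A minimal sketch of a shared helper (hypothetical, not part of the patch) that builds the command once and appends the flag when CUDA is available:

import torch

from tests import run_cli


def run_tts_cli(model_name, text, out_path, extra_args=""):
    # Build the base command shared by the CPU and GPU paths.
    cmd = (
        f'tts --model_name {model_name} --text "{text}" '
        f'--out_path "{out_path}" --progress_bar False {extra_args}'
    )
    # Only request CUDA when it is actually available.
    if torch.cuda.is_available():
        cmd += " --use_cuda True"
    run_cli(cmd)

Each test would then reduce to a single run_tts_cli(...) call with its model name and any extra arguments.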
From 155c5fc0bde77c6ace1e6e6982da2b13f19ddd2f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?= <erogol@hotmail.com>
Date: Fri, 29 Sep 2023 23:44:09 +0200
Subject: [PATCH 26/37] v0.17.6

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index 8b5334dc..5543a76e 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.17.5
+0.17.6

From f133b9d2d7ddb154769fba43e421ec71e83ff4c5 Mon Sep 17 00:00:00 2001
From: Anupam Maurya <anupammaurya6767@gmail.com>
Date: Mon, 2 Oct 2023 16:21:55 +0530
Subject: [PATCH 27/37] Upgrade and Optimize TTS Code in
 extractttsspectrogram.ipynb (#3012)

---
 notebooks/ExtractTTSpectrogram.ipynb | 181 +++++++++++++++------------
 1 file changed, 103 insertions(+), 78 deletions(-)

diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb
index a257b6bf..9acc9929 100644
--- a/notebooks/ExtractTTSpectrogram.ipynb
+++ b/notebooks/ExtractTTSpectrogram.ipynb
@@ -13,15 +13,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2\n",
     "import os\n",
     "import sys\n",
     "import torch\n",
     "import importlib\n",
     "import numpy as np\n",
-    "from tqdm import tqdm as tqdm\n",
+    "from tqdm import tqdm\n",
     "from torch.utils.data import DataLoader\n",
+    "import soundfile as sf\n",
+    "import pickle\n",
     "from TTS.tts.datasets.dataset import TTSDataset\n",
     "from TTS.tts.layers.losses import L1LossMasked\n",
     "from TTS.utils.audio import AudioProcessor\n",
@@ -33,8 +33,8 @@
     "\n",
     "%matplotlib inline\n",
     "\n",
-    "import os\n",
-    "os.environ['CUDA_VISIBLE_DEVICES']='2'"
+    "# Configure CUDA visibility\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
    ]
   },
   {
@@ -43,6 +43,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Function to create directories and file names\n",
     "def set_filename(wav_path, out_path):\n",
     "    wav_file = os.path.basename(wav_path)\n",
     "    file_name = wav_file.split('.')[0]\n",
@@ -61,6 +62,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Paths and configurations\n",
     "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
     "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
     "DATASET = \"ljspeech\"\n",
@@ -73,12 +75,15 @@
     "QUANTIZE_BIT = None\n",
     "DRY_RUN = False   # if False, does not generate output files, only computes loss and visuals.\n",
     "\n",
+    "# Check CUDA availability\n",
     "use_cuda = torch.cuda.is_available()\n",
     "print(\" > CUDA enabled: \", use_cuda)\n",
     "\n",
+    "# Load the configuration\n",
     "C = load_config(CONFIG_PATH)\n",
     "C.audio['do_trim_silence'] = False  # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
-    "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
+    "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
+    "print(C['r'])"
    ]
   },
   {
@@ -87,14 +92,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "print(C['r'])\n",
-    "# if the vocabulary was passed, replace the default\n",
+    "# If the vocabulary was passed, replace the default\n",
     "if 'characters' in C and C['characters']:\n",
     "    symbols, phonemes = make_symbols(**C.characters)\n",
     "\n",
-    "# load the model\n",
+    "# Load the model\n",
     "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
-    "# TODO: multiple speaker\n",
+    "# TODO: multiple speakers\n",
     "model = setup_model(C)\n",
     "model.load_checkpoint(C, MODEL_FILE, eval=True)"
    ]
@@ -105,11 +109,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Load the preprocessor based on the dataset\n",
     "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
     "preprocessor = getattr(preprocessor, DATASET.lower())\n",
     "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
     "dataset = TTSDataset(\n",
-    "    checkpoint[\"config\"][\"r\"],\n",
+    "    C,\n",
     "    C.text_cleaner,\n",
     "    False,\n",
     "    ap,\n",
@@ -124,6 +129,24 @@
     ")\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Initialize lists for storing results\n",
+    "file_idxs = []\n",
+    "metadata = []\n",
+    "losses = []\n",
+    "postnet_losses = []\n",
+    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
+    "\n",
+    "# Create log file\n",
+    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
+    "log_file = open(log_file_path, \"w\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -137,83 +160,85 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import pickle\n",
-    "\n",
-    "file_idxs = []\n",
-    "metadata = []\n",
-    "losses = []\n",
-    "postnet_losses = []\n",
-    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
+    "# Start processing with a progress bar\n",
     "with torch.no_grad():\n",
-    "    for data in tqdm(loader):\n",
-    "        # setup input data\n",
-    "        text_input = data[0]\n",
-    "        text_lengths = data[1]\n",
-    "        linear_input = data[3]\n",
-    "        mel_input = data[4]\n",
-    "        mel_lengths = data[5]\n",
-    "        stop_targets = data[6]\n",
-    "        item_idx = data[7]\n",
+    "    for data in tqdm(loader, desc=\"Processing\"):\n",
+    "        try:\n",
+    "            # setup input data\n",
+    "            text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
     "\n",
-    "        # dispatch data to GPU\n",
-    "        if use_cuda:\n",
-    "            text_input = text_input.cuda()\n",
-    "            text_lengths = text_lengths.cuda()\n",
-    "            mel_input = mel_input.cuda()\n",
-    "            mel_lengths = mel_lengths.cuda()\n",
+    "            # dispatch data to GPU\n",
+    "            if use_cuda:\n",
+    "                text_input = text_input.cuda()\n",
+    "                text_lengths = text_lengths.cuda()\n",
+    "                mel_input = mel_input.cuda()\n",
+    "                mel_lengths = mel_lengths.cuda()\n",
     "\n",
-    "        mask = sequence_mask(text_lengths)\n",
-    "        mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
-    "        \n",
-    "        # compute loss\n",
-    "        loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
-    "        loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
-    "        losses.append(loss.item())\n",
-    "        postnet_losses.append(loss_postnet.item())\n",
+    "            mask = sequence_mask(text_lengths)\n",
+    "            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
     "\n",
-    "        # compute mel specs from linear spec if model is Tacotron\n",
-    "        if C.model == \"Tacotron\":\n",
-    "            mel_specs = []\n",
-    "            postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
-    "            for b in range(postnet_outputs.shape[0]):\n",
-    "                postnet_output = postnet_outputs[b]\n",
-    "                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
-    "            postnet_outputs = torch.stack(mel_specs)\n",
-    "        elif C.model == \"Tacotron2\":\n",
-    "            postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
-    "        alignments = alignments.detach().cpu().numpy()\n",
+    "            # compute loss\n",
+    "            loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
+    "            loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
+    "            losses.append(loss.item())\n",
+    "            postnet_losses.append(loss_postnet.item())\n",
     "\n",
-    "        if not DRY_RUN:\n",
-    "            for idx in range(text_input.shape[0]):\n",
-    "                wav_file_path = item_idx[idx]\n",
-    "                wav = ap.load_wav(wav_file_path)\n",
-    "                file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
-    "                file_idxs.append(file_name)\n",
+    "            # compute mel specs from linear spec if the model is Tacotron\n",
+    "            if C.model == \"Tacotron\":\n",
+    "                mel_specs = []\n",
+    "                postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
+    "                for b in range(postnet_outputs.shape[0]):\n",
+    "                    postnet_output = postnet_outputs[b]\n",
+    "                    mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
+    "                postnet_outputs = torch.stack(mel_specs)\n",
+    "            elif C.model == \"Tacotron2\":\n",
+    "                postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
+    "            alignments = alignments.detach().cpu().numpy()\n",
     "\n",
-    "                # quantize and save wav\n",
-    "                if QUANTIZED_WAV:\n",
-    "                    wavq = ap.quantize(wav)\n",
-    "                    np.save(wavq_path, wavq)\n",
+    "            if not DRY_RUN:\n",
+    "                for idx in range(text_input.shape[0]):\n",
+    "                    wav_file_path = item_idx[idx]\n",
+    "                    wav = ap.load_wav(wav_file_path)\n",
+    "                    file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
+    "                    file_idxs.append(file_name)\n",
     "\n",
-    "                # save TTS mel\n",
-    "                mel = postnet_outputs[idx]\n",
-    "                mel_length = mel_lengths[idx]\n",
-    "                mel = mel[:mel_length, :].T\n",
-    "                np.save(mel_path, mel)\n",
+    "                    # quantize and save wav\n",
+    "                    if QUANTIZED_WAV:\n",
+    "                        wavq = ap.quantize(wav)\n",
+    "                        np.save(wavq_path, wavq)\n",
     "\n",
-    "                metadata.append([wav_file_path, mel_path])\n",
+    "                    # save TTS mel\n",
+    "                    mel = postnet_outputs[idx]\n",
+    "                    mel_length = mel_lengths[idx]\n",
+    "                    mel = mel[:mel_length, :].T\n",
+    "                    np.save(mel_path, mel)\n",
     "\n",
-    "    # for wavernn\n",
-    "    if not DRY_RUN:\n",
-    "        pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\"))      \n",
-    "    \n",
-    "    # for pwgan\n",
-    "    with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "        for data in metadata:\n",
-    "            f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
+    "                    metadata.append([wav_file_path, mel_path])\n",
     "\n",
-    "    print(np.mean(losses))\n",
-    "    print(np.mean(postnet_losses))"
+    "        except Exception as e:\n",
+    "            log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
+    "\n",
+    "    # Calculate and log mean losses\n",
+    "    mean_loss = np.mean(losses)\n",
+    "    mean_postnet_loss = np.mean(postnet_losses)\n",
+    "    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
+    "    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
+    "\n",
+    "# Close the log file\n",
+    "log_file.close()\n",
+    "\n",
+    "# For wavernn\n",
+    "if not DRY_RUN:\n",
+    "    pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
+    "\n",
+    "# For pwgan\n",
+    "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
+    "    for data in metadata:\n",
+    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
+    "\n",
+    "# Print mean losses\n",
+    "print(f\"Mean Loss: {mean_loss}\")\n",
+    "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
    ]
   },
   {

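The refactored notebook above writes one "wav_path|mel_path.npy" pair per line into metadata.txt. A short sketch (hypothetical; the file path is a placeholder) of how a downstream vocoder recipe could read those pairs back:

import numpy as np

with open("metadata.txt") as f:  # placeholder path; the notebook writes it under OUT_PATH
    pairs = [line.strip().split("|") for line in f if line.strip()]

for wav_path, mel_path in pairs:
    mel = np.load(mel_path)  # saved above as [num_mels, mel_length]
    print(wav_path, mel.shape)
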
From 21501362106224a3918af5b03a8ac92cdfc45e11 Mon Sep 17 00:00:00 2001
From: OPERATOR <78286476+OPPEYRADY@users.noreply.github.com>
Date: Mon, 2 Oct 2023 06:53:36 -0400
Subject: [PATCH 28/37] None is not able to be read for "XTTS", fixes crash if
 its set to None. (#3009)

---
 TTS/api.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/TTS/api.py b/TTS/api.py
index 1eb0b510..e1d167a9 100644
--- a/TTS/api.py
+++ b/TTS/api.py
@@ -17,7 +17,7 @@ class TTS(nn.Module):
 
     def __init__(
         self,
-        model_name: str = None,
+        model_name: str = "",
         model_path: str = None,
         config_path: str = None,
         vocoder_path: str = None,
@@ -105,13 +105,14 @@ class TTS(nn.Module):
 
     @property
     def is_multi_lingual(self):
-        # TODO: fix this
-        if "xtts" in self.model_name:
+        # Not sure what sets this to None, but applied a fix to prevent crashing.
+        if isinstance(self.model_name, str) and "xtts" in self.model_name:
             return True
         if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
             return self.synthesizer.tts_model.language_manager.num_languages > 1
         return False
 
+
     @property
     def speakers(self):
         if not self.is_multi_speaker:

From e5e0cbffc9d0a9b6ea47b3e158bb218a54499633 Mon Sep 17 00:00:00 2001
From: Julian Weber <julian.weber@hotmail.fr>
Date: Fri, 6 Oct 2023 18:34:06 +0200
Subject: [PATCH 29/37] =?UTF-8?q?Streaming=20inference=20for=20XTTS=20?=
 =?UTF-8?q?=F0=9F=9A=80=20(#3035)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 TTS/.models.json                        |    8 +-
 TTS/tts/layers/xtts/gpt.py              |   24 +-
 TTS/tts/layers/xtts/hifigan_decoder.py  |  742 ++++++++++++++++
 TTS/tts/layers/xtts/stream_generator.py | 1057 +++++++++++++++++++++++
 TTS/tts/layers/xtts/tokenizer.py        |    3 +
 TTS/tts/models/xtts.py                  |  373 +++++---
 docs/source/models/xtts.md              |   86 +-
 tests/zoo_tests/test_models.py          |   28 +
 8 files changed, 2192 insertions(+), 129 deletions(-)
 create mode 100644 TTS/tts/layers/xtts/hifigan_decoder.py
 create mode 100644 TTS/tts/layers/xtts/stream_generator.py

diff --git a/TTS/.models.json b/TTS/.models.json
index a893f708..ba7b5f62 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -5,12 +5,12 @@
                 "xtts_v1": {
                     "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
                     "hf_url": [
-                        "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
-                        "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
-                        "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
+                        "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
+                        "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json",
+                        "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json"
                     ],
                     "default_vocoder": null,
-                    "commit": "e9a1953e",
+                    "commit": "e5140314",
                     "license": "CPML",
                     "contact": "info@coqui.ai",
                     "tos_required": true
diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 2a821a5d..88ce100c 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -172,7 +172,7 @@ class GPT(nn.Module):
             "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
         }
 
-    def init_gpt_for_inference(self, kv_cache=True):
+    def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
         seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
         gpt_config = GPT2Config(
             vocab_size=self.max_mel_tokens,
@@ -195,6 +195,17 @@ class GPT(nn.Module):
         )
         self.gpt.wte = self.mel_embedding
 
+        if use_deepspeed:
+            import deepspeed
+            self.ds_engine = deepspeed.init_inference(
+                model=self.gpt_inference.half(),  # Transformers models
+                mp_size=1,  # Number of GPUs
+                dtype=torch.float32,  # desired data type of output
+                replace_method="auto",  # Let DS automatically identify the layer to replace
+                replace_with_kernel_inject=True,  # replace the model with the kernel injector
+            )
+            self.gpt_inference = self.ds_engine.module.eval()
+
     def set_inputs_and_targets(self, input, start_token, stop_token):
         inp = F.pad(input, (1, 0), value=start_token)
         tar = F.pad(input, (0, 1), value=stop_token)
@@ -543,3 +554,14 @@ class GPT(nn.Module):
         if "return_dict_in_generate" in hf_generate_kwargs:
             return gen.sequences[:, gpt_inputs.shape[1] :], gen
         return gen[:, gpt_inputs.shape[1] :]
+
+    def get_generator(self, fake_inputs, **hf_generate_kwargs):
+        return self.gpt_inference.generate_stream(
+            fake_inputs,
+            bos_token_id=self.start_audio_token,
+            pad_token_id=self.stop_audio_token,
+            eos_token_id=self.stop_audio_token,
+            max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
+            do_stream=True,
+            **hf_generate_kwargs,
+        )
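
A minimal usage sketch for the two additions above (hypothetical, not part of the patch): gpt stands for an already-constructed GPT instance and gpt_inputs for its prepared input ids; the DeepSpeed path assumes the deepspeed package is installed and a CUDA device is available for the half-precision cast.

# Enable the optional DeepSpeed kernel injection for the inference copy of the model.
gpt.init_gpt_for_inference(kv_cache=True, use_deepspeed=True)
gpt.eval()

# Stream audio tokens instead of waiting for the full sequence; the generator is
# assumed to yield tokens incrementally (do_stream=True is set inside get_generator).
streamer = gpt.get_generator(gpt_inputs, do_sample=True, top_p=0.85)
for audio_tokens in streamer:
    pass  # hand each chunk to the decoder as it arrives
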
diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py
new file mode 100644
index 00000000..6439b455
--- /dev/null
+++ b/TTS/tts/layers/xtts/hifigan_decoder.py
@@ -0,0 +1,742 @@
+import torch
+from torch import nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+import torchaudio
+
+from TTS.utils.io import load_fsspec
+
+
+LRELU_SLOPE = 0.1
+
+
+def get_padding(k, d):
+    return int((k * d - d) / 2)
+
+
+class ResBlock1(torch.nn.Module):
+    """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
+
+    Network::
+
+        x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
+        |--------------------------------------------------------------------------------------------------|
+
+
+    Args:
+        channels (int): number of hidden channels for the convolutional layers.
+        kernel_size (int): size of the convolution filter in each layer.
+        dilations (list): list of dilation value for each conv layer in a block.
+    """
+
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super().__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): input tensor.
+        Returns:
+            Tensor: output tensor.
+        Shapes:
+            x: [B, C, T]
+        """
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    """Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
+
+    Network::
+
+        x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
+        |---------------------------------------------------|
+
+
+    Args:
+        channels (int): number of hidden channels for the convolutional layers.
+        kernel_size (int): size of the convolution filter in each layer.
+        dilations (list): list of dilation value for each conv layer in a block.
+    """
+
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super().__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class HifiganGenerator(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        resblock_type,
+        resblock_dilation_sizes,
+        resblock_kernel_sizes,
+        upsample_kernel_sizes,
+        upsample_initial_channel,
+        upsample_factors,
+        inference_padding=5,
+        cond_channels=0,
+        conv_pre_weight_norm=True,
+        conv_post_weight_norm=True,
+        conv_post_bias=True,
+        cond_in_each_up_layer=False,
+    ):
+        r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
+
+        Network:
+            x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
+                                                 ..          -> zI ---|
+                                              resblockN_kNx1 -> zN ---'
+
+        Args:
+            in_channels (int): number of input tensor channels.
+            out_channels (int): number of output tensor channels.
+            resblock_type (str): type of the `ResBlock`. '1' or '2'.
+            resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
+            resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
+            upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
+            upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
+                for each consecutive upsampling layer.
+            upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
+            inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
+        """
+        super().__init__()
+        self.inference_padding = inference_padding
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_factors)
+        self.cond_in_each_up_layer = cond_in_each_up_layer
+
+        # initial upsampling layers
+        self.conv_pre = weight_norm(
+            Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+        )
+        resblock = ResBlock1 if resblock_type == "1" else ResBlock2
+        # upsampling layers
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+        # MRF blocks
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for _, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+        # post convolution layer
+        self.conv_post = weight_norm(
+            Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
+        )
+        if cond_channels > 0:
+            self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
+
+        if not conv_pre_weight_norm:
+            remove_weight_norm(self.conv_pre)
+
+        if not conv_post_weight_norm:
+            remove_weight_norm(self.conv_post)
+
+        if self.cond_in_each_up_layer:
+            self.conds = nn.ModuleList()
+            for i in range(len(self.ups)):
+                ch = upsample_initial_channel // (2 ** (i + 1))
+                self.conds.append(nn.Conv1d(cond_channels, ch, 1))
+
+    def forward(self, x, g=None):
+        """
+        Args:
+            x (Tensor): feature input tensor.
+            g (Tensor): global conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            x: [B, C, T]
+            Tensor: [B, 1, T]
+        """
+        o = self.conv_pre(x)
+        if hasattr(self, "cond_layer"):
+            o = o + self.cond_layer(g)
+        for i in range(self.num_upsamples):
+            o = F.leaky_relu(o, LRELU_SLOPE)
+            o = self.ups[i](o)
+
+            if self.cond_in_each_up_layer:
+                o = o + self.conds[i](g)
+
+            z_sum = None
+            for j in range(self.num_kernels):
+                if z_sum is None:
+                    z_sum = self.resblocks[i * self.num_kernels + j](o)
+                else:
+                    z_sum += self.resblocks[i * self.num_kernels + j](o)
+            o = z_sum / self.num_kernels
+        o = F.leaky_relu(o)
+        o = self.conv_post(o)
+        o = torch.tanh(o)
+        return o
+
+    @torch.no_grad()
+    def inference(self, c):
+        """
+        Args:
+            x (Tensor): conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            x: [B, C, T]
+            Tensor: [B, 1, T]
+        """
+        c = c.to(self.conv_pre.weight.device)
+        c = torch.nn.functional.pad(
+            c, (self.inference_padding, self.inference_padding), "replicate"
+        )
+        return self.forward(c)
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+    def load_checkpoint(
+        self, config, checkpoint_path, eval=False, cache=False
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            assert not self.training
+            self.remove_weight_norm()
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=8):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+
+class SEBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
+        super(SEBasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.se = SELayer(planes, reduction)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.relu(out)
+        out = self.bn1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+def set_init_dict(model_dict, checkpoint_state, c):
+    # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
+    for k, v in checkpoint_state.items():
+        if k not in model_dict:
+            print(" | > Layer missing in the model definition: {}".format(k))
+    # 1. filter out unnecessary keys
+    pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
+    # 2. filter out different size layers
+    pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
+    # 3. skip reinit layers
+    if c.has("reinit_layers") and c.reinit_layers is not None:
+        for reinit_layer_name in c.reinit_layers:
+            pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
+    # 4. overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+    print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
+    return model_dict
+
+
+class PreEmphasis(nn.Module):
+    def __init__(self, coefficient=0.97):
+        super().__init__()
+        self.coefficient = coefficient
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
+
+    def forward(self, x):
+        assert len(x.size()) == 2
+
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
+        return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
+
+
+
+class ResNetSpeakerEncoder(nn.Module):
+    """This is copied from 🐸TTS to remove it from the dependencies.
+    """
+
+    # pylint: disable=W0102
+    def __init__(
+        self,
+        input_dim=64,
+        proj_dim=512,
+        layers=[3, 4, 6, 3],
+        num_filters=[32, 64, 128, 256],
+        encoder_type="ASP",
+        log_input=False,
+        use_torch_spec=False,
+        audio_config=None,
+    ):
+        super(ResNetSpeakerEncoder, self).__init__()
+
+        self.encoder_type = encoder_type
+        self.input_dim = input_dim
+        self.log_input = log_input
+        self.use_torch_spec = use_torch_spec
+        self.audio_config = audio_config
+        self.proj_dim = proj_dim
+
+        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+        self.bn1 = nn.BatchNorm2d(num_filters[0])
+
+        self.inplanes = num_filters[0]
+        self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
+        self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
+        self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
+        self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
+
+        self.instancenorm = nn.InstanceNorm1d(input_dim)
+
+        if self.use_torch_spec:
+            self.torch_spec = torch.nn.Sequential(
+                PreEmphasis(audio_config["preemphasis"]),
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
+
+        else:
+            self.torch_spec = None
+
+        outmap_size = int(self.input_dim / 8)
+
+        self.attention = nn.Sequential(
+            nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.BatchNorm1d(128),
+            nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
+            nn.Softmax(dim=2),
+        )
+
+        if self.encoder_type == "SAP":
+            out_dim = num_filters[3] * outmap_size
+        elif self.encoder_type == "ASP":
+            out_dim = num_filters[3] * outmap_size * 2
+        else:
+            raise ValueError("Undefined encoder")
+
+        self.fc = nn.Linear(out_dim, proj_dim)
+
+        self._init_layers()
+
+    def _init_layers(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def create_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    # pylint: disable=R0201
+    def new_parameter(self, *size):
+        out = nn.Parameter(torch.FloatTensor(*size))
+        nn.init.xavier_normal_(out)
+        return out
+
+    def forward(self, x, l2_norm=False):
+        """Forward pass of the model.
+
+        Args:
+            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
+                to compute the spectrogram on-the-fly.
+            l2_norm (bool): Whether to L2-normalize the outputs.
+
+        Shapes:
+            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
+        """
+        x.squeeze_(1)
+        # if use_torch_spec is set, compute the spectrogram here; otherwise use the mel spec computed by the AP
+        if self.use_torch_spec:
+            x = self.torch_spec(x)
+
+        if self.log_input:
+            x = (x + 1e-6).log()
+        x = self.instancenorm(x).unsqueeze(1)
+
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.bn1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = x.reshape(x.size()[0], -1, x.size()[-1])
+
+        w = self.attention(x)
+
+        if self.encoder_type == "SAP":
+            x = torch.sum(x * w, dim=2)
+        elif self.encoder_type == "ASP":
+            mu = torch.sum(x * w, dim=2)
+            sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
+            x = torch.cat((mu, sg), 1)
+
+        x = x.view(x.size()[0], -1)
+        x = self.fc(x)
+
+        if l2_norm:
+            x = torch.nn.functional.normalize(x, p=2, dim=1)
+        return x
+
+    def load_checkpoint(
+        self,
+        checkpoint_path: str,
+        eval: bool = False,
+        use_cuda: bool = False,
+        criterion=None,
+        cache=False,
+    ):
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
+        try:
+            self.load_state_dict(state["model"])
+            print(" > Model fully restored. ")
+        except (KeyError, RuntimeError) as error:
+            # If eval raise the error
+            if eval:
+                raise error
+
+            print(" > Partial model initialization.")
+            model_dict = self.state_dict()
+            model_dict = set_init_dict(model_dict, state["model"])
+            self.load_state_dict(model_dict)
+            del model_dict
+
+        # load the criterion for restore_path
+        if criterion is not None and "criterion" in state:
+            try:
+                criterion.load_state_dict(state["criterion"])
+            except (KeyError, RuntimeError) as error:
+                print(" > Criterion load ignored because of:", error)
+
+        if use_cuda:
+            self.cuda()
+            if criterion is not None:
+                criterion = criterion.cuda()
+
+        if eval:
+            self.eval()
+            assert not self.training
+
+        if not eval:
+            return criterion, state["step"]
+        return criterion
+
+class HifiDecoder(torch.nn.Module):
+    def __init__(
+        self,
+        input_sample_rate=22050,
+        output_sample_rate=24000,
+        output_hop_length=256,
+        ar_mel_length_compression=1024,
+        decoder_input_dim=1024,
+        resblock_type_decoder="1",
+        resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        resblock_kernel_sizes_decoder=[3, 7, 11],
+        upsample_rates_decoder=[8, 8, 2, 2],
+        upsample_initial_channel_decoder=512,
+        upsample_kernel_sizes_decoder=[16, 16, 4, 4],
+        d_vector_dim=512,
+        cond_d_vector_in_each_upsampling_layer=True,
+        speaker_encoder_audio_config={
+            "fft_size": 512,
+            "win_length": 400,
+            "hop_length": 160,
+            "sample_rate": 16000,
+            "preemphasis": 0.97,
+            "num_mels": 64,
+        },
+    ):
+        super().__init__()
+        self.input_sample_rate = input_sample_rate
+        self.output_sample_rate = output_sample_rate
+        self.output_hop_length = output_hop_length
+        self.ar_mel_length_compression = ar_mel_length_compression
+        self.speaker_encoder_audio_config = speaker_encoder_audio_config
+        self.waveform_decoder = HifiganGenerator(
+            decoder_input_dim,
+            1,
+            resblock_type_decoder,
+            resblock_dilation_sizes_decoder,
+            resblock_kernel_sizes_decoder,
+            upsample_kernel_sizes_decoder,
+            upsample_initial_channel_decoder,
+            upsample_rates_decoder,
+            inference_padding=0,
+            cond_channels=d_vector_dim,
+            conv_pre_weight_norm=False,
+            conv_post_weight_norm=False,
+            conv_post_bias=False,
+            cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
+        )
+        self.speaker_encoder = ResNetSpeakerEncoder(
+            input_dim=64,
+            proj_dim=512,
+            log_input=True,
+            use_torch_spec=True,
+            audio_config=speaker_encoder_audio_config,
+        )
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, latents, g=None):
+        """
+        Args:
+            x (Tensor): feature input tensor (GPT latent).
+            g (Tensor): global conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            x: [B, C, T]
+            Tensor: [B, 1, T]
+        """
+
+        z = torch.nn.functional.interpolate(
+            latents.transpose(1, 2),
+            scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
+            mode="linear",
+        ).squeeze(1)
+        # upsample to the right sr
+        if self.output_sample_rate != self.input_sample_rate:
+            z = torch.nn.functional.interpolate(
+                z,
+                scale_factor=[self.output_sample_rate / self.input_sample_rate],
+                mode="linear",
+            ).squeeze(0)
+        o = self.waveform_decoder(z, g=g)
+        return o
+
+    @torch.no_grad()
+    def inference(self, c, g):
+        """
+        Args:
+            x (Tensor): feature input tensor (GPT latent).
+            g (Tensor): global conditioning input tensor.
+
+        Returns:
+            Tensor: output waveform.
+
+        Shapes:
+            x: [B, C, T]
+            Tensor: [B, 1, T]
+        """
+        return self.forward(c, g=g)
+
+    def load_checkpoint(
+        self, checkpoint_path, eval=False
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        # remove unused keys
+        state = state["model"]
+        states_keys = list(state.keys())
+        for key in states_keys:
+            if "waveform_decoder." not in key and "speaker_encoder." not in key:
+                del state[key]
+
+        self.load_state_dict(state)
+        if eval:
+            self.eval()
+            assert not self.training
+            self.waveform_decoder.remove_weight_norm()
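
A minimal usage sketch for the generator defined above (hypothetical, not part of the patch): the resblock and upsampling settings mirror the HifiDecoder defaults, while in_channels=80 and the input shape are arbitrary values chosen for illustration.

import torch

from TTS.tts.layers.xtts.hifigan_decoder import HifiganGenerator

gen = HifiganGenerator(
    in_channels=80,
    out_channels=1,
    resblock_type="1",
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],
)
feats = torch.randn(1, 80, 50)  # [B, C, T] dummy feature frames
with torch.no_grad():
    wav = gen(feats)  # upsampled by 8*8*2*2 = 256x -> [1, 1, 12800]
print(wav.shape)
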
diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py
new file mode 100644
index 00000000..8bdd2291
--- /dev/null
+++ b/TTS/tts/layers/xtts/stream_generator.py
@@ -0,0 +1,1057 @@
+# Adapted from: https://github.com/LowinLi/transformers-stream-generator
+
+from transformers import (
+    GenerationConfig,
+    GenerationMixin,
+    LogitsProcessorList,
+    StoppingCriteriaList,
+    DisjunctiveConstraint,
+    BeamSearchScorer,
+    PhrasalConstraint,
+    ConstrainedBeamSearchScorer,
+    PreTrainedModel,
+)
+import numpy as np
+import random
+import warnings
+import inspect
+from transformers.generation.utils import GenerateOutput, SampleOutput, logger
+import torch
+from typing import Callable, List, Optional, Union
+from torch import nn
+import torch.distributed as dist
+import copy
+
+
+def setup_seed(seed):
+    if seed == -1:
+        return
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+class StreamGenerationConfig(GenerationConfig):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.do_stream = kwargs.pop("do_stream", False)
+
+
+class NewGenerationMixin(GenerationMixin):
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[StreamGenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[
+            Callable[[int, torch.Tensor], List[int]]
+        ] = None,
+        synced_gpus: Optional[bool] = False,
+        seed=0,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        r"""
+
+        Generates sequences of token ids for models with a language modeling head.
+
+        <Tip warning={true}>
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For an overview of generation strategies and code examples, check out the [following
+        guide](./generation_strategies).
+
+        </Tip>
+
+        Parameters:
+            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
+                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                Custom logits processors that complement the default logits processors built from arguments and
+                generation config. If a logit processor is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                Custom stopping criteria that complement the default stopping criteria built from arguments and a
+                generation config. If a stopping criteria is passed that is already created with the arguments or a
+                generation config an error is thrown. This feature is intended for advanced users.
+            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+                If provided, this function constrains the beam search to allowed tokens only at each step. If not
+                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+                Retrieval](https://arxiv.org/abs/2010.00904).
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+        Return:
+            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
+
+                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+                [`~utils.ModelOutput`] types are:
+
+                    - [`~generation.GreedySearchDecoderOnlyOutput`],
+                    - [`~generation.SampleDecoderOnlyOutput`],
+                    - [`~generation.BeamSearchDecoderOnlyOutput`],
+                    - [`~generation.BeamSampleDecoderOnlyOutput`]
+
+                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+                [`~utils.ModelOutput`] types are:
+
+                    - [`~generation.GreedySearchEncoderDecoderOutput`],
+                    - [`~generation.SampleEncoderDecoderOutput`],
+                    - [`~generation.BeamSearchEncoderDecoderOutput`],
+                    - [`~generation.BeamSampleEncoderDecoderOutput`]
+        """
+        #setup_seed(seed)
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+        self._validate_model_class()
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation -- update the generation config
+            # model attribute accordingly, if it was created from the model config
+            if self.generation_config._from_model_config:
+                new_generation_config = StreamGenerationConfig.from_model_config(
+                    self.config
+                )
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use a generation configuration file (see"
+                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(
+            **kwargs
+        )  # All unused kwargs must be model kwargs
+        # self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
+
+        if (
+            generation_config.pad_token_id is None
+            and generation_config.eos_token_id is not None
+        ):
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+            )
+            generation_config.pad_token_id = eos_token_id
+
+        # 3. Define model inputs
+        # inputs_tensor has to be defined
+        # model_input_name is defined if model-specific keyword input is passed
+        # otherwise model_input_name is None
+        # all model-specific keyword inputs are removed from `model_kwargs`
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        batch_size = inputs_tensor.shape[0]
+
+        # 4. Define other model kwargs
+        model_kwargs["output_attentions"] = generation_config.output_attentions
+        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        model_kwargs["use_cache"] = generation_config.use_cache
+
+        accepts_attention_mask = "attention_mask" in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
+        requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+        if (
+            model_kwargs.get("attention_mask", None) is None
+            and requires_attention_mask
+            and accepts_attention_mask
+        ):
+            model_kwargs[
+                "attention_mask"
+            ] = self._prepare_attention_mask_for_generation(
+                inputs_tensor,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
+            )
+
+        # decoder-only models should use left-padding for generation
+        if not self.config.is_encoder_decoder:
+            if (
+                generation_config.pad_token_id is not None
+                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+                > 0
+            ):
+                logger.warning(
+                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
+                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
+                )
+
+        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+            # if model is encoder decoder encoder_outputs are created
+            # and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name
+            )
+
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
+        if self.config.is_encoder_decoder:
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
+                bos_token_id=generation_config.bos_token_id,
+                model_kwargs=model_kwargs,
+                device=inputs_tensor.device,
+            )
+        else:
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor
+
+        # 6. Prepare `max_length` depending on other stopping criteria.
+        input_ids_seq_length = input_ids.shape[-1]
+        has_default_max_length = (
+            kwargs.get("max_length") is None
+            and generation_config.max_length is not None
+        )
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
+                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif has_default_max_length and generation_config.max_new_tokens is not None:
+            generation_config.max_length = (
+                generation_config.max_new_tokens + input_ids_seq_length
+            )
+        elif (
+            not has_default_max_length and generation_config.max_new_tokens is not None
+        ):
+            raise ValueError(
+                "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
+                " limit to the generated output length. Remove one of those arguments. Please refer to the"
+                " documentation for more information. "
+                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+            )
+
+        if (
+            generation_config.min_length is not None
+            and generation_config.min_length > generation_config.max_length
+        ):
+            raise ValueError(
+                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+                f" the maximum length ({generation_config.max_length})"
+            )
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = (
+                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            )
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # 7. determine generation mode
+        is_constraint_gen_mode = (
+            generation_config.constraints is not None
+            or generation_config.force_words_ids is not None
+        )
+
+        is_contrastive_search_gen_mode = (
+            generation_config.top_k is not None
+            and generation_config.top_k > 1
+            and generation_config.do_sample is False
+            and generation_config.penalty_alpha is not None
+            and generation_config.penalty_alpha > 0
+        )
+
+        is_greedy_gen_mode = (
+            (generation_config.num_beams == 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_sample is False
+            and not is_constraint_gen_mode
+            and not is_contrastive_search_gen_mode
+        )
+        is_sample_gen_mode = (
+            (generation_config.num_beams == 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_sample is True
+            and generation_config.do_stream is False
+            and not is_constraint_gen_mode
+            and not is_contrastive_search_gen_mode
+        )
+        is_sample_gen_stream_mode = (
+            (generation_config.num_beams == 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_stream is True
+            and not is_constraint_gen_mode
+            and not is_contrastive_search_gen_mode
+        )
+        is_beam_gen_mode = (
+            (generation_config.num_beams > 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_sample is False
+            and not is_constraint_gen_mode
+            and not is_contrastive_search_gen_mode
+        )
+        is_beam_sample_gen_mode = (
+            (generation_config.num_beams > 1)
+            and (generation_config.num_beam_groups == 1)
+            and generation_config.do_sample is True
+            and not is_constraint_gen_mode
+            and not is_contrastive_search_gen_mode
+        )
+        is_group_beam_gen_mode = (
+            (generation_config.num_beams > 1)
+            and (generation_config.num_beam_groups > 1)
+            and not is_constraint_gen_mode
+            and not is_contrastive_search_gen_mode
+        )
+
+        if generation_config.num_beam_groups > generation_config.num_beams:
+            raise ValueError(
+                "`num_beam_groups` has to be smaller or equal to `num_beams`"
+            )
+        if is_group_beam_gen_mode and generation_config.do_sample is True:
+            raise ValueError(
+                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
+            )
+
+        if self.device.type != input_ids.device.type:
+            warnings.warn(
+                "You are calling .generate() with the `input_ids` being on a device type different"
+                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
+                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
+                " Please make sure that you have put `input_ids` to the"
+                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
+                " running `.generate()`.",
+                UserWarning,
+            )
+        # 8. prepare distribution pre_processing samplers
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=inputs_tensor,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
+
+        # 9. prepare stopping criteria
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        # 10. go into different generation modes
+        if is_greedy_gen_mode:
+            if generation_config.num_return_sequences > 1:
+                raise ValueError(
+                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+                    " greedy search."
+                )
+
+            # 11. run greedy search
+            return self.greedy_search(
+                input_ids,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+
+        elif is_contrastive_search_gen_mode:
+            if generation_config.num_return_sequences > 1:
+                raise ValueError(
+                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+                    " contrastive search."
+                )
+
+            return self.contrastive_search(
+                input_ids,
+                top_k=generation_config.top_k,
+                penalty_alpha=generation_config.penalty_alpha,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+
+        elif is_sample_gen_mode:
+            # 11. prepare logits warper
+            logits_warper = self._get_logits_warper(generation_config)
+
+            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_return_sequences,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+
+            # 13. run sample
+            return self.sample(
+                input_ids,
+                logits_processor=logits_processor,
+                logits_warper=logits_warper,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+        elif is_sample_gen_stream_mode:
+            # 11. prepare logits warper
+            logits_warper = self._get_logits_warper(generation_config)
+
+            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_return_sequences,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+
+            # 13. run sample
+            return self.sample_stream(
+                input_ids,
+                logits_processor=logits_processor,
+                logits_warper=logits_warper,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+        elif is_beam_gen_mode:
+            if generation_config.num_return_sequences > generation_config.num_beams:
+                raise ValueError(
+                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
+                )
+
+            if stopping_criteria.max_length is None:
+                raise ValueError(
+                    "`max_length` needs to be a stopping_criteria for now."
+                )
+
+            # 11. prepare beam search scorer
+            beam_scorer = BeamSearchScorer(
+                batch_size=batch_size,
+                num_beams=generation_config.num_beams,
+                device=inputs_tensor.device,
+                length_penalty=generation_config.length_penalty,
+                do_early_stopping=generation_config.early_stopping,
+                num_beam_hyps_to_keep=generation_config.num_return_sequences,
+            )
+            # 12. interleave input_ids with `num_beams` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+            # 13. run beam search
+            return self.beam_search(
+                input_ids,
+                beam_scorer,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+
+        elif is_beam_sample_gen_mode:
+            # 11. prepare logits warper
+            logits_warper = self._get_logits_warper(generation_config)
+
+            if stopping_criteria.max_length is None:
+                raise ValueError(
+                    "`max_length` needs to be a stopping_criteria for now."
+                )
+            # 12. prepare beam search scorer
+            beam_scorer = BeamSearchScorer(
+                batch_size=batch_size * generation_config.num_return_sequences,
+                num_beams=generation_config.num_beams,
+                device=inputs_tensor.device,
+                length_penalty=generation_config.length_penalty,
+                do_early_stopping=generation_config.early_stopping,
+            )
+
+            # 13. interleave input_ids with `num_beams` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams
+                * generation_config.num_return_sequences,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+
+            # 14. run beam sample
+            return self.beam_sample(
+                input_ids,
+                beam_scorer,
+                logits_processor=logits_processor,
+                logits_warper=logits_warper,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+
+        elif is_group_beam_gen_mode:
+            if generation_config.num_return_sequences > generation_config.num_beams:
+                raise ValueError(
+                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
+                )
+
+            if generation_config.num_beams % generation_config.num_beam_groups != 0:
+                raise ValueError(
+                    "`num_beams` should be divisible by `num_beam_groups` for group beam search."
+                )
+
+            if stopping_criteria.max_length is None:
+                raise ValueError(
+                    "`max_length` needs to be a stopping_criteria for now."
+                )
+
+            has_default_typical_p = (
+                kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
+            )
+            if not has_default_typical_p:
+                raise ValueError(
+                    "Decoder argument `typical_p` is not supported with beam groups."
+                )
+
+            # 11. prepare beam search scorer
+            beam_scorer = BeamSearchScorer(
+                batch_size=batch_size,
+                num_beams=generation_config.num_beams,
+                max_length=stopping_criteria.max_length,
+                device=inputs_tensor.device,
+                length_penalty=generation_config.length_penalty,
+                do_early_stopping=generation_config.early_stopping,
+                num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                num_beam_groups=generation_config.num_beam_groups,
+            )
+            # 12. interleave input_ids with `num_beams` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+            # 13. run beam search
+            return self.group_beam_search(
+                input_ids,
+                beam_scorer,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+
+        elif is_constraint_gen_mode:
+            if generation_config.num_return_sequences > generation_config.num_beams:
+                raise ValueError(
+                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
+                )
+
+            if stopping_criteria.max_length is None:
+                raise ValueError(
+                    "`max_length` needs to be a stopping_criteria for now."
+                )
+
+            if generation_config.num_beams <= 1:
+                raise ValueError(
+                    "`num_beams` needs to be greater than 1 for constrained generation."
+                )
+
+            if generation_config.do_sample:
+                raise ValueError(
+                    "`do_sample` needs to be false for constrained generation."
+                )
+
+            if (
+                generation_config.num_beam_groups is not None
+                and generation_config.num_beam_groups > 1
+            ):
+                raise ValueError(
+                    "`num_beam_groups` not supported yet for constrained generation."
+                )
+
+            final_constraints = []
+            if generation_config.constraints is not None:
+                final_constraints = generation_config.constraints
+
+            if generation_config.force_words_ids is not None:
+
+                def typeerror():
+                    raise ValueError(
+                        "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+                        f"of positive integers, but is {generation_config.force_words_ids}."
+                    )
+
+                if (
+                    not isinstance(generation_config.force_words_ids, list)
+                    or len(generation_config.force_words_ids) == 0
+                ):
+                    typeerror()
+
+                for word_ids in generation_config.force_words_ids:
+                    if isinstance(word_ids[0], list):
+                        if not isinstance(word_ids, list) or len(word_ids) == 0:
+                            typeerror()
+                        if any(
+                            not isinstance(token_ids, list) for token_ids in word_ids
+                        ):
+                            typeerror()
+                        if any(
+                            any(
+                                (not isinstance(token_id, int) or token_id < 0)
+                                for token_id in token_ids
+                            )
+                            for token_ids in word_ids
+                        ):
+                            typeerror()
+
+                        constraint = DisjunctiveConstraint(word_ids)
+                    else:
+                        if not isinstance(word_ids, list) or len(word_ids) == 0:
+                            typeerror()
+                        if any(
+                            (not isinstance(token_id, int) or token_id < 0)
+                            for token_id in word_ids
+                        ):
+                            typeerror()
+
+                        constraint = PhrasalConstraint(word_ids)
+                    final_constraints.append(constraint)
+
+            # 11. prepare beam search scorer
+            constrained_beam_scorer = ConstrainedBeamSearchScorer(
+                constraints=final_constraints,
+                batch_size=batch_size,
+                num_beams=generation_config.num_beams,
+                device=inputs_tensor.device,
+                length_penalty=generation_config.length_penalty,
+                do_early_stopping=generation_config.early_stopping,
+                num_beam_hyps_to_keep=generation_config.num_return_sequences,
+            )
+            # 12. interleave input_ids with `num_beams` additional sequences per batch
+            input_ids, model_kwargs = self._expand_inputs_for_generation(
+                input_ids=input_ids,
+                expand_size=generation_config.num_beams,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                **model_kwargs,
+            )
+            # 13. run beam search
+            return self.constrained_beam_search(
+                input_ids,
+                constrained_beam_scorer=constrained_beam_scorer,
+                logits_processor=logits_processor,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                synced_gpus=synced_gpus,
+                **model_kwargs,
+            )
+
+    @torch.no_grad()
+    def sample_stream(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_warper: Optional[LogitsProcessorList] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_scores: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        synced_gpus: Optional[bool] = False,
+        **model_kwargs,
+    ) -> Union[SampleOutput, torch.LongTensor]:
+        r"""
+        Streams sequences of token ids for models with a language modeling head using **multinomial sampling**. Unlike
+        `sample`, this is a generator that yields a `(token_ids, hidden_state)` pair at every decoding step. It can be
+        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+        <Tip warning={true}>
+
+        In most cases, you do not need to call `sample_stream` directly. Use `generate()` with `do_stream=True`
+        (exposed as `generate_stream` by `init_stream_support()`) instead.
+        For an overview of generation strategies and code examples, check the [following
+        guide](./generation_strategies).
+
+        </Tip>
+
+        Parameters:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            logits_processor (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+                used to modify the prediction scores of the language modeling head applied at each generation step.
+            stopping_criteria (`StoppingCriteriaList`, *optional*):
+                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+                used to tell if the generation loop should stop.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
+            max_length (`int`, *optional*, defaults to 20):
+                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
+                tokens. The maximum length of the sequence to be generated.
+            pad_token_id (`int`, *optional*):
+                The id of the *padding* token.
+            eos_token_id (`Union[int, List[int]]`, *optional*):
+                The id of the *end-of-sequence* token.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more details.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more details.
+            output_scores (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+            synced_gpus (`bool`, *optional*, defaults to `False`):
+                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+                an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+        Return:
+            [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
+            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+            [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+            `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
+            `model.config.is_encoder_decoder=True`.
+
+        Examples (shown with the closely related, non-streaming `sample` API for reference; `sample_stream` itself
+        yields `(token_ids, hidden_state)` pairs step by step):
+
+        ```python
+        >>> from transformers import (
+        ...     AutoTokenizer,
+        ...     AutoModelForCausalLM,
+        ...     LogitsProcessorList,
+        ...     MinLengthLogitsProcessor,
+        ...     TopKLogitsWarper,
+        ...     TemperatureLogitsWarper,
+        ...     StoppingCriteriaList,
+        ...     MaxLengthCriteria,
+        ... )
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
+        >>> model.config.pad_token_id = model.config.eos_token_id
+        >>> model.generation_config.pad_token_id = model.config.eos_token_id
+
+        >>> input_prompt = "Today is a beautiful day, and"
+        >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+
+        >>> # instantiate logits processors
+        >>> logits_processor = LogitsProcessorList(
+        ...     [
+        ...         MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
+        ...     ]
+        ... )
+        >>> # instantiate logits warpers
+        >>> logits_warper = LogitsProcessorList(
+        ...     [
+        ...         TopKLogitsWarper(50),
+        ...         TemperatureLogitsWarper(0.7),
+        ...     ]
+        ... )
+
+        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+
+        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
+        >>> outputs = model.sample(
+        ...     input_ids,
+        ...     logits_processor=logits_processor,
+        ...     logits_warper=logits_warper,
+        ...     stopping_criteria=stopping_criteria,
+        ... )
+
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
+        ```"""
+        # init values
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use"
+                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(
+                stopping_criteria, max_length
+            )
+        logits_warper = (
+            logits_warper if logits_warper is not None else LogitsProcessorList()
+        )
+        pad_token_id = (
+            pad_token_id
+            if pad_token_id is not None
+            else self.generation_config.pad_token_id
+        )
+        eos_token_id = (
+            eos_token_id
+            if eos_token_id is not None
+            else self.generation_config.eos_token_id
+        )
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        output_scores = (
+            output_scores
+            if output_scores is not None
+            else self.generation_config.output_scores
+        )
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.generation_config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.generation_config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate
+            if return_dict_in_generate is not None
+            else self.generation_config.return_dict_in_generate
+        )
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        cross_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
+
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+        this_peer_finished = False  # used by synced_gpus only
+        # auto-regressive generation
+        while True:
+            if synced_gpus:
+                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+                # The following logic allows an early break if all peers finished generating their sequence
+                this_peer_finished_flag = torch.tensor(
+                    0.0 if this_peer_finished else 1.0
+                ).to(input_ids.device)
+                # send 0.0 if we finished, 1.0 otherwise
+                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+                # did all peers finish? the reduced sum will be 0.0 then
+                if this_peer_finished_flag.item() == 0.0:
+                    break
+
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+
+            if synced_gpus and this_peer_finished:
+                continue  # don't waste resources running the code we don't need
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
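+            # Stream out the sampled token ids together with the last hidden state
+            # passed through the model's final norm. This assumes the wrapped model
+            # exposes a `final_norm` module (true for the XTTS GPT2 inference model);
+            # in XTTS this latent is what the HiFi-GAN decoder consumes.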
+            yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    (sum(next_tokens != i for i in eos_token_id)).long()
+                )
+
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                if not synced_gpus:
+                    break
+                else:
+                    this_peer_finished = True
+
+
+def init_stream_support():
+    """Overload PreTrainedModel for streaming."""
+    PreTrainedModel.generate_stream = NewGenerationMixin.generate
+    PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
+
+
+if __name__ == "__main__":
+    from transformers import PreTrainedModel
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    PreTrainedModel.generate = NewGenerationMixin.generate
+    PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
+    model = AutoModelForCausalLM.from_pretrained(
+        "bigscience/bloom-560m", torch_dtype=torch.float16
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+    model = model.to("cuda:0")
+    model = model.eval()
+    prompt_text = "hello? \n"
+    input_ids = tokenizer(
+        prompt_text, return_tensors="pt", add_special_tokens=False
+    ).input_ids
+    input_ids = input_ids.to("cuda:0")
+
+    with torch.no_grad():
+        result = model.generate(
+            input_ids,
+            max_new_tokens=200,
+            do_sample=True,
+            top_k=30,
+            top_p=0.85,
+            temperature=0.35,
+            repetition_penalty=1.2,
+            early_stopping=True,
+            seed=0,
+        )
+        print(tokenizer.decode(result[0], skip_special_tokens=True))  # decode expects a 1D sequence of ids
+        generator = model.generate(
+            input_ids,
+            max_new_tokens=200,
+            do_sample=True,
+            top_k=30,
+            top_p=0.85,
+            temperature=0.35,
+            repetition_penalty=1.2,
+            early_stopping=True,
+            seed=0,
+            do_stream=True,
+        )
+        stream_result = ""
+        for token, _ in generator:  # sample_stream yields (token_ids, hidden_state) pairs
+            chunk = tokenizer.decode(token, skip_special_tokens=True)
+            stream_result += chunk
+        print(stream_result)
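The `__main__` block above exercises the streamer with a stock Hugging Face model; in XTTS the useful part is that every step yields both a sampled token and the GPT latent the HiFi-GAN decoder consumes. A sketch of how a caller might buffer those latents into audio chunks, assuming `init_stream_support()` has been called (as `TTS/tts/models/xtts.py` does below) and using hypothetical stand-ins `gpt_inference`, `decode_latents`, and `chunk_size` that are not defined by this patch:

```python
import torch


def stream_wav_chunks(gpt_inference, input_ids, decode_latents, chunk_size=20):
    # With do_stream=True the patched-in generate()/generate_stream() returns a
    # generator of (token_ids, latent) pairs, one decoding step at a time.
    latents = []
    for _token, latent in gpt_inference.generate_stream(
        input_ids,
        do_sample=True,
        do_stream=True,
        output_hidden_states=True,  # the yielded latent comes from the last hidden state
        max_new_tokens=400,
    ):
        latents.append(latent)
        if len(latents) >= chunk_size:
            # Decode the buffered GPT latents into a waveform chunk, then reset.
            yield decode_latents(torch.stack(latents, dim=1))
            latents = []
    if latents:
        yield decode_latents(torch.stack(latents, dim=1))
```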
diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py
index 8dd81fac..a2795289 100644
--- a/TTS/tts/layers/xtts/tokenizer.py
+++ b/TTS/tts/layers/xtts/tokenizer.py
@@ -224,7 +224,10 @@ class VoiceBpeTokenizer:
             txt = " ".join([result["kana"] for result in results])
             txt = basic_cleaners(txt)
         elif lang == "en":
+            if txt[:4] == "[en]":
+                txt = txt[4:]
             txt = english_cleaners(txt)
+            txt = "[en]" + txt
         elif lang == "ar":
             txt = arabic_cleaners(txt)
         elif lang == "zh-cn":
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index a23a0f5f..2b480744 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu
 from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
+from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
 
+init_stream_support()
 
 def load_audio(audiopath, sr=22050):
     """
@@ -195,13 +198,12 @@ class XttsArgs(Coqpit):
     Args:
         gpt_batch_size (int): The size of the auto-regressive batch.
         enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
-        lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
         kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
         gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
         clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
         decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
         num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
+        use_hifigan (bool, optional): Whether to use the HiFi-GAN decoder (True) or the diffusion + UnivNet decoder (False). Defaults to True.
 
         For GPT model:
         ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -231,12 +233,12 @@ class XttsArgs(Coqpit):
 
     gpt_batch_size: int = 1
     enable_redaction: bool = False
-    lazy_load: bool = True
     kv_cache: bool = True
     gpt_checkpoint: str = None
     clvp_checkpoint: str = None
     decoder_checkpoint: str = None
     num_chars: int = 255
+    use_hifigan: bool = True
 
     # XTTS GPT Encoder params
     tokenizer_file: str = ""
@@ -266,6 +268,15 @@ class XttsArgs(Coqpit):
     diff_layer_drop: int = 0
     diff_unconditioned_percentage: int = 0
 
+    # HifiGAN Decoder params
+    input_sample_rate: int = 22050
+    output_sample_rate: int = 24000
+    output_hop_length: int = 256
+    ar_mel_length_compression: int = 1024
+    decoder_input_dim: int = 1024
+    d_vector_dim: int = 512
+    cond_d_vector_in_each_upsampling_layer: bool = True
+
     # constants
     duration_const: int = 102400
 
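The decoder choice is now driven entirely by `XttsArgs`. A hedged sketch of a config that selects the HiFi-GAN path; the field values simply mirror the dataclass defaults added in this hunk:

```python
from TTS.tts.models.xtts import XttsArgs

args = XttsArgs(
    use_hifigan=True,  # HiFi-GAN decoder; False falls back to diffusion + UnivNet
    input_sample_rate=22050,
    output_sample_rate=24000,
    output_hop_length=256,
    ar_mel_length_compression=1024,
    decoder_input_dim=1024,
    d_vector_dim=512,
    cond_d_vector_in_each_upsampling_layer=True,
)
```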
@@ -285,7 +296,6 @@ class Xtts(BaseTTS):
 
     def __init__(self, config: Coqpit):
         super().__init__(config, ap=None, tokenizer=None)
-        self.lazy_load = self.args.lazy_load
         self.mel_stats_path = None
         self.config = config
         self.gpt_checkpoint = self.args.gpt_checkpoint
@@ -295,7 +305,6 @@ class Xtts(BaseTTS):
 
         self.tokenizer = VoiceBpeTokenizer()
         self.gpt = None
-        self.diffusion_decoder = None
         self.init_models()
         self.register_buffer("mel_stats", torch.ones(80))
 
@@ -322,40 +331,39 @@ class Xtts(BaseTTS):
                 stop_audio_token=self.args.gpt_stop_audio_token,
             )
 
-        self.diffusion_decoder = DiffusionTts(
-            model_channels=self.args.diff_model_channels,
-            num_layers=self.args.diff_num_layers,
-            in_channels=self.args.diff_in_channels,
-            out_channels=self.args.diff_out_channels,
-            in_latent_channels=self.args.diff_in_latent_channels,
-            in_tokens=self.args.diff_in_tokens,
-            dropout=self.args.diff_dropout,
-            use_fp16=self.args.diff_use_fp16,
-            num_heads=self.args.diff_num_heads,
-            layer_drop=self.args.diff_layer_drop,
-            unconditioned_percentage=self.args.diff_unconditioned_percentage,
-        )
 
-        self.vocoder = UnivNetGenerator()
+        if self.args.use_hifigan:
+            self.hifigan_decoder = HifiDecoder(
+                input_sample_rate=self.args.input_sample_rate,
+                output_sample_rate=self.args.output_sample_rate,
+                output_hop_length=self.args.output_hop_length,
+                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                decoder_input_dim=self.args.decoder_input_dim,
+                d_vector_dim=self.args.d_vector_dim,
+                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+            )
+
+        else:
+            self.diffusion_decoder = DiffusionTts(
+                model_channels=self.args.diff_model_channels,
+                num_layers=self.args.diff_num_layers,
+                in_channels=self.args.diff_in_channels,
+                out_channels=self.args.diff_out_channels,
+                in_latent_channels=self.args.diff_in_latent_channels,
+                in_tokens=self.args.diff_in_tokens,
+                dropout=self.args.diff_dropout,
+                use_fp16=self.args.diff_use_fp16,
+                num_heads=self.args.diff_num_heads,
+                layer_drop=self.args.diff_layer_drop,
+                unconditioned_percentage=self.args.diff_unconditioned_percentage,
+            )
+            self.vocoder = UnivNetGenerator()
 
     @property
     def device(self):
         return next(self.parameters()).device
 
-    @contextmanager
-    def lazy_load_model(self, model):
-        """Context to load a model on demand.
-
-        Args:
-            model (nn.Module): The model to be loaded.
-        """
-        if self.lazy_load:
-            yield model
-        else:
-            m = model.to(self.device)
-            yield m
-            m = model.cpu()
-
+    @torch.inference_mode()
     def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
 
@@ -370,6 +378,7 @@ class Xtts(BaseTTS):
         cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
         return cond_latent.transpose(1, 2)
 
+    @torch.inference_mode()
     def get_diffusion_cond_latents(
         self,
         audio_path,
@@ -389,20 +398,33 @@ class Xtts(BaseTTS):
             )
             diffusion_conds.append(cond_mel)
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        with self.lazy_load_model(self.diffusion_decoder) as diffusion:
-            diffusion_latent = diffusion.get_conditioning(diffusion_conds)
+        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
         return diffusion_latent
 
+    @torch.inference_mode()
+    def get_speaker_embedding(
+        self,
+        audio_path
+    ):
+        audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
+        speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
+            audio.to(self.device), l2_norm=True
+        ).unsqueeze(-1).to(self.device)
+        return speaker_embedding
+
     def get_conditioning_latents(
         self,
         audio_path,
         gpt_cond_len=3,
-    ):
+    ):
+        speaker_embedding = None
+        diffusion_cond_latents = None
+        if self.args.use_hifigan:
+            speaker_embedding = self.get_speaker_embedding(audio_path)
+        else:
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
         gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
-        diffusion_cond_latents = self.get_diffusion_cond_latents(
-            audio_path,
-        )
-        return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
+        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
 
     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
@@ -447,10 +469,10 @@ class Xtts(BaseTTS):
             "decoder_sampler": config.decoder_sampler,
         }
         settings.update(kwargs)  # allow overriding of preset settings with kwargs
-        return self.inference(text, ref_audio_path, language, **settings)
+        return self.full_inference(text, ref_audio_path, language, **settings)
 
-    @torch.no_grad()
-    def inference(
+    @torch.inference_mode()
+    def full_inference(
         self,
         text,
         ref_audio_path,
@@ -525,6 +547,54 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
+        (
+            gpt_cond_latent,
+            diffusion_conditioning,
+            speaker_embedding
+        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+        return self.inference(
+            text,
+            language,
+            gpt_cond_latent,
+            speaker_embedding,
+            diffusion_conditioning,
+            temperature=temperature,
+            length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=do_sample,
+            decoder_iterations=decoder_iterations,
+            cond_free=cond_free,
+            cond_free_k=cond_free_k,
+            diffusion_temperature=diffusion_temperature,
+            decoder_sampler=decoder_sampler,
+            **hf_generate_kwargs,
+        )
+    
+    @torch.inference_mode()
+    def inference(
+        self,
+        text,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        diffusion_conditioning,
+        # GPT inference
+        temperature=0.65,
+        length_penalty=1,
+        repetition_penalty=2.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        # Decoder inference
+        decoder_iterations=100,
+        cond_free=True,
+        cond_free_k=2,
+        diffusion_temperature=1.0,
+        decoder_sampler="ddim",
+        **hf_generate_kwargs,
+    ):
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
 
@@ -532,74 +602,147 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
 
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
-
-        diffuser = load_discrete_vocoder_diffuser(
-            desired_diffusion_steps=decoder_iterations,
-            cond_free=cond_free,
-            cond_free_k=cond_free_k,
-            sampler=decoder_sampler,
-        )
+        if not self.args.use_hifigan:
+            diffuser = load_discrete_vocoder_diffuser(
+                desired_diffusion_steps=decoder_iterations,
+                cond_free=cond_free,
+                cond_free_k=cond_free_k,
+                sampler=decoder_sampler,
+            )
 
         with torch.no_grad():
-            self.gpt = self.gpt.to(self.device)
-            with self.lazy_load_model(self.gpt) as gpt:
-                gpt_codes = gpt.generate(
-                    cond_latents=gpt_cond_latent,
-                    text_inputs=text_tokens,
-                    input_tokens=None,
-                    do_sample=do_sample,
-                    top_p=top_p,
-                    top_k=top_k,
-                    temperature=temperature,
-                    num_return_sequences=self.gpt_batch_size,
-                    length_penalty=length_penalty,
-                    repetition_penalty=repetition_penalty,
-                    output_attentions=False,
-                    **hf_generate_kwargs,
-                )
-
-            with self.lazy_load_model(self.gpt) as gpt:
-                expected_output_len = torch.tensor(
-                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
-                )
-                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
-                gpt_latents = gpt(
-                    text_tokens,
-                    text_len,
-                    gpt_codes,
-                    expected_output_len,
-                    cond_latents=gpt_cond_latent,
-                    return_attentions=False,
-                    return_latent=True,
-                )
-                silence_token = 83
-                ctokens = 0
-                for k in range(gpt_codes.shape[-1]):
-                    if gpt_codes[0, k] == silence_token:
-                        ctokens += 1
-                    else:
-                        ctokens = 0
-                    if ctokens > 8:
-                        gpt_latents = gpt_latents[:, :k]
-                        break
-
-            with self.lazy_load_model(self.diffusion_decoder) as diffusion:
+            gpt_codes = self.gpt.generate(
+                cond_latents=gpt_cond_latent,
+                text_inputs=text_tokens,
+                input_tokens=None,
+                do_sample=do_sample,
+                top_p=top_p,
+                top_k=top_k,
+                temperature=temperature,
+                num_return_sequences=self.gpt_batch_size,
+                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty,
+                output_attentions=False,
+                **hf_generate_kwargs,
+            )
+            expected_output_len = torch.tensor(
+                [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
+            )
+            text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
+            gpt_latents = self.gpt(
+                text_tokens,
+                text_len,
+                gpt_codes,
+                expected_output_len,
+                cond_latents=gpt_cond_latent,
+                return_attentions=False,
+                return_latent=True,
+            )
+            silence_token = 83
+            ctokens = 0
+            for k in range(gpt_codes.shape[-1]):
+                if gpt_codes[0, k] == silence_token:
+                    ctokens += 1
+                else:
+                    ctokens = 0
+                if ctokens > 8:
+                    gpt_latents = gpt_latents[:, :k]
+                    break
+            
+            if self.args.use_hifigan:
+                wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
+            else:
                 mel = do_spectrogram_diffusion(
-                    diffusion,
+                    self.diffusion_decoder,
                     diffuser,
                     gpt_latents,
                     diffusion_conditioning,
                     temperature=diffusion_temperature,
                 )
-            with self.lazy_load_model(self.vocoder) as vocoder:
-                wav = vocoder.inference(mel)
+                wav = self.vocoder.inference(mel)
 
         return {"wav": wav.cpu().numpy().squeeze()}
 
+    def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
+        """Handle chunk formatting in streaming mode"""
+        # Drop the trailing overlap; it is held back for crossfading with the next chunk.
+        wav_chunk = wav_gen[:-overlap_len]
+        if wav_gen_prev is not None:
+            # Start where the previous generation left off (minus the reserved overlap)
+            # so already-emitted audio is not repeated.
+            wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
+        if wav_overlap is not None:
+            # Crossfade: fade in the head of the new chunk while fading out the tail of the previous one.
+            crossfade_wav = wav_chunk[:overlap_len]
+            crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
+            wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
+            wav_chunk[:overlap_len] += crossfade_wav
+        # Keep the tail and the full waveform for the next call.
+        wav_overlap = wav_gen[-overlap_len:]
+        wav_gen_prev = wav_gen
+        return wav_chunk, wav_gen_prev, wav_overlap
+
+    @torch.inference_mode()
+    def inference_stream(
+        self,
+        text,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        # Streaming
+        stream_chunk_size=20,
+        overlap_wav_len=1024,
+        # GPT inference
+        temperature=0.65,
+        length_penalty=1,
+        repetition_penalty=2.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        # Decoder inference
+        **hf_generate_kwargs,
+    ):
+        assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires `use_hifigan` to be set to true in config.model_args; the diffusion pipeline is too slow for streaming."
+        text = f"[{language}]{text.strip().lower()}"
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
+
+        fake_inputs = self.gpt.compute_embeddings(
+            gpt_cond_latent.to(self.device),
+            text_tokens,
+        )
+        gpt_generator = self.gpt.get_generator(
+            fake_inputs=fake_inputs,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            do_sample=do_sample,
+            num_beams=1,
+            num_return_sequences=1,
+            length_penalty=float(length_penalty),
+            repetition_penalty=float(repetition_penalty),
+            output_attentions=False,
+            output_hidden_states=True,
+            **hf_generate_kwargs,
+        )
+
+        last_tokens = []
+        all_latents = []
+        wav_gen_prev = None
+        wav_overlap = None
+        is_end = False
+
+        while not is_end:
+            try:
+                x, latent = next(gpt_generator)
+                last_tokens += [x]
+                all_latents += [latent]
+            except StopIteration:
+                is_end = True
+
+            if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
+                gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+                    wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+                )
+                last_tokens = []
+                yield wav_chunk
+
     def forward(self):
         raise NotImplementedError("XTTS Training is not implemented")
 
@@ -616,7 +759,14 @@ class Xtts(BaseTTS):
         super().eval()
 
     def load_checkpoint(
-        self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
+        self,
+        config,
+        checkpoint_dir=None,
+        checkpoint_path=None, 
+        vocab_path=None,
+        eval=True,
+        strict=True,
+        use_deepspeed=False,
     ):
         """
         Loads a checkpoint from disk and initializes the model's state and tokenizer.
@@ -626,7 +776,7 @@ class Xtts(BaseTTS):
             checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
             checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
             vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
-            eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
+            eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
             strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
 
         Returns:
@@ -636,19 +786,26 @@ class Xtts(BaseTTS):
         model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
         vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
 
-        if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
-            self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
+        if os.path.exists(vocab_path):
+            self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
 
         self.init_models()
         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
-        self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
+
+        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+        for key in list(checkpoint.keys()):
+            if key.split(".")[0] in ignore_keys:
+                del checkpoint[key]
+        self.load_state_dict(checkpoint, strict=strict)
 
         if eval:
-            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
+            if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
+            if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
+            if hasattr(self, "vocoder"): self.vocoder.eval()
+            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()
-            self.diffusion_decoder.eval()
-            self.vocoder.eval()
 
     def train_step(self):
         raise NotImplementedError("XTTS Training is not implemented")
diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md
index 85a3afba..ff6bcf97 100644
--- a/docs/source/models/xtts.md
+++ b/docs/source/models/xtts.md
@@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml)
 Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
 You can also mail us at info@coqui.ai.
 
-Using 🐸TTS API:
+### Inference
+#### 🐸TTS API
 
 ```python
 from TTS.api import TTS
@@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
                 file_path="output.wav",
                 speaker_wav="/path/to/target/speaker.wav",
                 language="en")
-
-# generate speech by cloning a voice using custom settings
-tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-                file_path="output.wav",
-                speaker_wav="/path/to/target/speaker.wav",
-                language="en",
-                decoder_iterations=30)
 ```
 
-Using 🐸TTS Command line:
+#### 🐸TTS Command line
 
 ```console
  tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
@@ -58,25 +52,85 @@ Using 🐸TTS Command line:
      --use_cuda true
 ```
 
-Using model directly:
+#### Direct model usage
+
+To run with `use_deepspeed=True` and benefit from the speedup, install DeepSpeed first.
+
+```console
+pip install deepspeed==0.8.3
+```
 
 ```python
+import os
+import torch
+import torchaudio
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 
+print("Loading model...")
 config = XttsConfig()
 config.load_json("/path/to/xtts/config.json")
 model = Xtts.init_from_config(config)
-model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
+model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
+model.cuda()
+    
+print("Computing speaker latents...")
+gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+
+print("Inference...")
+out = model.inference(
+    "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+    "en",
+    gpt_cond_latent,
+    speaker_embedding,
+    diffusion_conditioning,
+    temperature=0.7, # Add custom parameters here
+)
+torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
+```
+
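+When `use_hifigan` is disabled in `config.model_args`, `inference` runs the diffusion decoder followed by the UnivNet vocoder instead of HiFi-GAN, and the diffusion-specific arguments (`decoder_iterations`, `cond_free`, `cond_free_k`, `diffusion_temperature`, `decoder_sampler`) take effect. A minimal sketch under that assumption (parameter values are illustrative, not tuned):
+
+```python
+# Sketch: diffusion + UnivNet path, assuming use_hifigan is False in config.model_args.
+# On this path get_conditioning_latents returns diffusion_conditioning and a None speaker_embedding.
+gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+out = model.inference(
+    "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+    "en",
+    gpt_cond_latent,
+    speaker_embedding,        # None here; kept to match the common signature
+    diffusion_conditioning,
+    decoder_iterations=30,    # fewer diffusion steps -> faster, lower fidelity
+    diffusion_temperature=1.0,
+    decoder_sampler="ddim",
+)
+torchaudio.save("xtts_diffusion.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
+```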
+
+#### Streaming inference
+
+Here the goal is to stream the audio as it is being generated, which is useful for real-time applications.
+Streaming inference is typically slower end-to-end than regular inference, but it lets you get the first chunk of audio sooner.
+
+
+```python
+import os
+import time
+import torch
+import torchaudio
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+
+print("Loading model...")
+config = XttsConfig()
+config.load_json("/path/to/xtts/config.json")
+model = Xtts.init_from_config(config)
+model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
 model.cuda()
 
-outputs = model.synthesize(
+print("Computing speaker latents...")
+gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
+
+print("Inference...")
+t0 = time.time()
+chunks = model.inference_stream(
     "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
-    config,
-    speaker_wav="/data/TTS-public/_refclips/3.wav",
-    gpt_cond_len=3,
-    language="en",
+    "en",
+    gpt_cond_latent,
+    speaker_embedding
 )
+    
+wav_chunks = []
+for i, chunk in enumerate(chunks):
+    if i == 0:
+        print(f"Time to first chunk: {time.time() - t0}")
+    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
+    wav_chunks.append(chunk)
+wav = torch.cat(wav_chunks, dim=0)
+torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
 ```
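+
+The chunking behaviour can be tuned through `stream_chunk_size` (GPT tokens decoded per emitted chunk) and `overlap_wav_len` (number of samples crossfaded between consecutive chunks). A sketch with the defaults written out (values are illustrative, not tuned):
+
+```python
+chunks = model.inference_stream(
+    "Text to stream.",
+    "en",
+    gpt_cond_latent,
+    speaker_embedding,
+    stream_chunk_size=20,   # smaller values lower per-chunk latency but call the decoder more often
+    overlap_wav_len=1024,   # samples used for the crossfade between chunks
+)
+```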
 
 
diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py
index 9c628276..dc16d793 100644
--- a/tests/zoo_tests/test_models.py
+++ b/tests/zoo_tests/test_models.py
@@ -93,6 +93,34 @@ def test_xtts():
             f'--speaker_wav "{speaker_wav}" --language_idx "en"'
         )
 
+def test_xtts_streaming():
+    """Testing the new inference_stream method"""
+    from TTS.tts.configs.xtts_config import XttsConfig
+    from TTS.tts.models.xtts import Xtts
+    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+    model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
+    config = XttsConfig()
+    config.load_json(os.path.join(model_path, "config.json"))
+    model = Xtts.init_from_config(config)
+    model.load_checkpoint(config, checkpoint_dir=model_path)
+    model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
+    print("Computing speaker latents...")
+    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+
+    print("Inference...")
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding
+    )
+    wav_chunks = []
+    for i, chunk in enumerate(chunks):
+        if i == 0:
+            assert chunk.shape[-1] > 5000
+        wav_chunks.append(chunk)
+    assert len(wav_chunks) > 1
 
 def test_tortoise():
     output_path = os.path.join(get_tests_output_path(), "output.wav")

From 0520697b5fdb01639ebf6a3314c91d2a8de96f22 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?= <erogol@hotmail.com>
Date: Fri, 6 Oct 2023 18:35:26 +0200
Subject: [PATCH 30/37] v0.17.7

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index 5543a76e..79394953 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.17.6
+0.17.7

From 4a6103fec9503024ed7d0bada7286e70c45a774a Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Fri, 6 Oct 2023 17:16:30 -0300
Subject: [PATCH 31/37] Redownload XTTS when the local and remote configs do not
 match

---
 TTS/utils/manage.py | 68 +++++++++++++++++++++++++++++++--------------
 1 file changed, 47 insertions(+), 21 deletions(-)

diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index b5c698f3..dbd9d7c0 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -6,6 +6,7 @@ from pathlib import Path
 from shutil import copyfile, rmtree
 from typing import Dict, List, Tuple
 
+import fsspec
 import requests
 from tqdm import tqdm
 
@@ -320,6 +321,31 @@ class ModelManager(object):
             return False
         return True
 
+    def check_if_files_size(self, model_name):
+        pass
+
+    def create_dir_and_download_model(self, model_name, model_item, output_path):
+        os.makedirs(output_path, exist_ok=True)
+        # handle TOS
+        if not self.tos_agreed(model_item, output_path):
+            if not self.ask_tos(output_path):
+                os.rmdir(output_path)
+                raise Exception(" [!] You must agree to the terms of service to use this model.")
+        print(f" > Downloading model to {output_path}")
+        try:
+            if "fairseq" in model_name:
+                self.download_fairseq_model(model_name, output_path)
+            elif "github_rls_url" in model_item:
+                self._download_github_model(model_item, output_path)
+            elif "hf_url" in model_item:
+                self._download_hf_model(model_item, output_path)
+
+        except requests.RequestException as e:
+            print(f" > Failed to download the model file to {output_path}")
+            rmtree(output_path)
+            raise e
+        self.print_model_license(model_item=model_item)
+
     def download_model(self, model_name):
         """Download model files given the full model name.
         Model name is in the format
@@ -338,28 +364,28 @@ class ModelManager(object):
         # set the model specific output path
         output_path = os.path.join(self.output_prefix, model_full_name)
         if os.path.exists(output_path):
-            print(f" > {model_name} is already downloaded.")
-        else:
-            os.makedirs(output_path, exist_ok=True)
-            # handle TOS
-            if not self.tos_agreed(model_item, output_path):
-                if not self.ask_tos(output_path):
-                    os.rmdir(output_path)
-                    raise Exception(" [!] You must agree to the terms of service to use this model.")
-            print(f" > Downloading model to {output_path}")
-            try:
-                if "fairseq" in model_name:
-                    self.download_fairseq_model(model_name, output_path)
-                elif "github_rls_url" in model_item:
-                    self._download_github_model(model_item, output_path)
-                elif "hf_url" in model_item:
-                    self._download_hf_model(model_item, output_path)
+            # if the configs are different, redownload it
+            # ToDo: we need a better way to handle it
+            if "xtts_v1" in model_name:
+                with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
+                    config_local = json.load(f)
+                remote_url = None
+                for url in model_item["hf_url"]:
+                    if "config.json" in url:
+                        remote_url = url
+                        break
+
+                with fsspec.open(remote_url, "r", encoding="utf-8") as f:
+                    config_remote = json.load(f)
+
+                if not config_local == config_remote:
+                    print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
+                    self.create_dir_and_download_model(model_name, model_item, output_path)
+            else:
+                print(f" > {model_name} is already downloaded.")
+        else:
+            self.create_dir_and_download_model(model_name, model_item, output_path)
 
-            except requests.RequestException as e:
-                print(f" > Failed to download the model file to {output_path}")
-                rmtree(output_path)
-                raise e
-            self.print_model_license(model_item=model_item)
         # find downloaded files
         output_model_path = output_path
         output_config_path = None

From ee1ef1c51e183c9236d4b598afda0e8c828a8ec9 Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Fri, 6 Oct 2023 17:21:22 -0300
Subject: [PATCH 32/37] Remove unused method

---
 TTS/utils/manage.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index dbd9d7c0..288cc118 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -321,9 +321,6 @@ class ModelManager(object):
             return False
         return True
 
-    def check_if_files_size(self, model_name):
-        pass
-
     def create_dir_and_download_model(self, model_name, model_item, output_path):
         os.makedirs(output_path, exist_ok=True)
         # handle TOS

From 529ea3f67f584cdc77c9d65a45da1e12043f970b Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Fri, 6 Oct 2023 17:26:40 -0300
Subject: [PATCH 33/37] Print a message when it is already downloaded

---
 TTS/utils/manage.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index 288cc118..fedab1d3 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -378,6 +378,8 @@ class ModelManager(object):
                 if not config_local == config_remote:
                     print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
                     self.create_dir_and_download_model(model_name, model_item, output_path)
+                else:
+                    print(f" > {model_name} is already downloaded.")        
             else:
                 print(f" > {model_name} is already downloaded.")
         else:

From 99650044a432c619f4ae63258d62bea053a58fe3 Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Fri, 6 Oct 2023 17:37:05 -0300
Subject: [PATCH 34/37] Try-except to prevent an error when the user doesn't have a
 connection

---
 TTS/utils/manage.py | 38 ++++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index fedab1d3..0420fc0d 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -343,6 +343,24 @@ class ModelManager(object):
             raise e
         self.print_model_license(model_item=model_item)
 
+    def check_if_configs_are_equal(self, model_name, model_item, output_path):
+        with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
+            config_local = json.load(f)
+        remote_url = None
+        for url in model_item["hf_url"]:
+            if "config.json" in url:
+                remote_url = url
+                break
+
+        with fsspec.open(remote_url, "r", encoding="utf-8") as f:
+            config_remote = json.load(f)
+
+        if not config_local == config_remote:
+            print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
+            self.create_dir_and_download_model(model_name, model_item, output_path)
+        else:
+            print(f" > {model_name} is already downloaded.")
+
     def download_model(self, model_name):
         """Download model files given the full model name.
         Model name is in the format
@@ -364,22 +382,10 @@ class ModelManager(object):
             # if the configs are different, redownload it
             # ToDo: we need a better way to handle it
             if "xtts_v1" in model_name:
-                with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
-                    config_local = json.load(f)
-                remote_url = None
-                for url in model_item["hf_url"]:
-                    if "config.json" in url:
-                        remote_url = url
-                        break
-
-                with fsspec.open(remote_url, "r", encoding="utf-8") as f:
-                    config_remote = json.load(f)
-
-                if not config_local == config_remote:
-                    print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
-                    self.create_dir_and_download_model(model_name, model_item, output_path)
-                else:
-                    print(f" > {model_name} is already downloaded.")        
+                try:
+                    self.check_if_configs_are_equal(model_name, model_item, output_path)
+                except:
+                    pass   
             else:
                 print(f" > {model_name} is already downloaded.")
         else:

From 2852404bdffbb8bc970444a855273647ea2e78c4 Mon Sep 17 00:00:00 2001
From: Edresson Casanova <edresson1@gmail.com>
Date: Fri, 6 Oct 2023 17:42:46 -0300
Subject: [PATCH 35/37] Fix style

---
 TTS/utils/manage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index 0420fc0d..955eeb9b 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -385,7 +385,7 @@ class ModelManager(object):
                 try:
                     self.check_if_configs_are_equal(model_name, model_item, output_path)
                 except:
-                    pass   
+                    pass
             else:
                 print(f" > {model_name} is already downloaded.")
         else:

From 3bb51b1276b1f203a65df6e9fcedcfa2f0b23dd9 Mon Sep 17 00:00:00 2001
From: ggoknar <ggoknar@coqui.ai>
Date: Sat, 7 Oct 2023 01:13:02 +0300
Subject: [PATCH 36/37] 0.17.8

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index 79394953..7df1fd55 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-0.17.7
+v0.17.8

From 99635193f508092c746febb087dc6634fa5f59d8 Mon Sep 17 00:00:00 2001
From: ggoknar <ggoknar@coqui.ai>
Date: Sat, 7 Oct 2023 01:14:05 +0300
Subject: [PATCH 37/37] v0.17.8

---
 TTS/VERSION | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/TTS/VERSION b/TTS/VERSION
index 7df1fd55..8bb22944 100644
--- a/TTS/VERSION
+++ b/TTS/VERSION
@@ -1 +1 @@
-v0.17.8
+0.17.8