diff --git a/.github/workflows/zoo_tests_tortoise.yml b/.github/workflows/zoo_tests_tortoise.yml new file mode 100644 index 00000000..31442877 --- /dev/null +++ b/.github/workflows/zoo_tests_tortoise.yml @@ -0,0 +1,52 @@ +name: zoo-tests-tortoise + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] +jobs: + check_skip: + runs-on: ubuntu-latest + if: "! contains(github.event.head_commit.message, '[ci skip]')" + steps: + - run: echo "${{ github.event.head_commit.message }}" + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.9, "3.10", "3.11"] + experimental: [false] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + architecture: x64 + cache: 'pip' + cache-dependency-path: 'requirements*' + - name: check OS + run: cat /etc/os-release + - name: set ENV + run: export TRAINER_TELEMETRY=0 + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y git make gcc + sudo apt-get install espeak espeak-ng + make system-deps + - name: Install/upgrade Python setup deps + run: python3 -m pip install --upgrade pip setuptools wheel + - name: Replace scarf urls + run: | + sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json + - name: Install TTS + run: | + python3 -m pip install .[all] + python3 setup.py egg_info + - name: Unit tests + run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_tortoise diff --git a/README.md b/README.md index 474f5499..720585db 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ ## 🐸Coqui.ai News -- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with uncontrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html) +- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html) +- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html) - 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. - 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://tts.readthedocs.io/en/dev/models/tortoise.html) - 📣 **Coqui Studio API** is landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api) @@ -111,7 +112,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea - Delightful TTS: [paper](https://arxiv.org/abs/2110.12612) ### End-to-End Models -- ⓍTTS: [blog]() +- ⓍTTS: [blog](https://coqui.ai/blog/tts/open_xtts) - VITS: [paper](https://arxiv.org/pdf/2106.06103) - 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418) - 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts) @@ -293,99 +294,123 @@ api.tts_with_vc_to_file( ``` ### Command-line `tts` + + + +Synthesize speech on command line. + +You can either use your trained model or choose a model from the provided list. + +If you don't specify any models, then it uses LJSpeech based English model. 
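
For comparison with the CLI examples that follow, the same default synthesis can also be scripted through the Python API shown earlier in this README. The snippet below is only a minimal sketch; the model name is one illustrative entry from `tts --list_models`, and any listed model should work the same way.

```python
# Minimal sketch: Python-API equivalent of the basic CLI synthesis calls below.
# The model name is an example entry taken from `tts --list_models`.
from TTS.api import TTS

tts = TTS(model_name="tts_models/en/ljspeech/glow-tts")
tts.tts_to_file(text="Text for TTS", file_path="output/path/speech.wav")
```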
+ #### Single Speaker Models - List provided models: - ``` - $ tts --list_models - ``` + ``` + $ tts --list_models + ``` + - Get model info (for both tts_models and vocoder_models): - - Query by type/name: - The model_info_by_name uses the name as it from the --list_models. - ``` - $ tts --model_info_by_name "///" - ``` - For example: - ``` - $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts - ``` - ``` - $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 - ``` - - Query by type/idx: - The model_query_idx uses the corresponding idx from --list_models. - ``` - $ tts --model_info_by_idx "/" - ``` - For example: + - Query by type/name: + The model_info_by_name uses the name as it from the --list_models. + ``` + $ tts --model_info_by_name "///" + ``` + For example: + ``` + $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts + $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 + ``` + - Query by type/idx: + The model_query_idx uses the corresponding idx from --list_models. - ``` - $ tts --model_info_by_idx tts_models/3 - ``` + ``` + $ tts --model_info_by_idx "/" + ``` + + For example: + + ``` + $ tts --model_info_by_idx tts_models/3 + ``` + + - Query info for model info by full name: + ``` + $ tts --model_info_by_name "///" + ``` - Run TTS with default models: - ``` - $ tts --text "Text for TTS" --out_path output/path/speech.wav - ``` + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav + ``` - Run a TTS model with its default vocoder model: - ``` - $ tts --text "Text for TTS" --model_name "///" --out_path output/path/speech.wav - ``` + ``` + $ tts --text "Text for TTS" --model_name "///" --out_path output/path/speech.wav + ``` + For example: - ``` - $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav - ``` + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav + ``` - Run with specific TTS and vocoder models from the list: - ``` - $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --out_path output/path/speech.wav - ``` + ``` + $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --out_path output/path/speech.wav + ``` For example: - ``` - $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav - ``` - + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav + ``` - Run your own TTS model (Using Griffin-Lim Vocoder): - ``` - $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav - ``` + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + ``` - Run your own TTS and Vocoder models: - ``` - $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav - --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json - ``` + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json + ``` #### Multi-speaker Models - List the available speakers and choose a among them: - 
``` - $ tts --model_name "//" --list_speaker_idxs - ``` + ``` + $ tts --model_name "//" --list_speaker_idxs + ``` - Run the multi-speaker TTS model with the target speaker ID: - ``` - $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx - ``` + ``` + $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx + ``` - Run your own multi-speaker TTS model: - ``` - $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx - ``` + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx + ``` + +### Voice Conversion Models + +``` +$ tts --out_path output/path/speech.wav --model_name "//" --source_wav --target_wav +``` + + ## Directory Structure ``` diff --git a/TTS/.models.json b/TTS/.models.json index 1eaaab71..ba7b5f62 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -5,12 +5,12 @@ "xtts_v1": { "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.", "hf_url": [ - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json", - "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json" + "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", + "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json", + "https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json" ], "default_vocoder": null, - "commit": "e9a1953e", + "commit": "e5140314", "license": "CPML", "contact": "info@coqui.ai", "tos_required": true @@ -728,6 +728,18 @@ "license": "Apache 2.0" } } + }, + "be": { + "common-voice": { + "glow-tts":{ + "description": "Belarusian GlowTTS model created by @alex73 (Github).", + "github_rls_url":"https://coqui.gateway.scarf.sh/v0.16.6/tts_models--be--common-voice--glow-tts.zip", + "default_vocoder": "vocoder_models/be/common-voice/hifigan", + "commit": "c0aabb85", + "license": "CC-BY-SA 4.0", + "contact": "alex73mail@gmail.com" + } + } } }, "vocoder_models": { @@ -879,6 +891,17 @@ "commit": null } } + }, + "be": { + "common-voice": { + "hifigan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.16.6/vocoder_models--be--common-voice--hifigan.zip", + "description": "Belarusian HiFiGAN model created by @alex73 (Github).", + "author": "@alex73", + "license": "CC-BY-SA 4.0", + "commit": "c0aabb85" + } + } } }, "voice_conversion_models": { diff --git a/TTS/VERSION b/TTS/VERSION index c3d16c16..8bb22944 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.17.2 +0.17.8 diff --git a/TTS/api.py b/TTS/api.py index 1eb0b510..e1d167a9 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -17,7 +17,7 @@ class TTS(nn.Module): def __init__( self, - model_name: str = None, + model_name: str = "", model_path: str = None, config_path: str = None, vocoder_path: str = None, @@ -105,13 +105,14 @@ class TTS(nn.Module): @property def is_multi_lingual(self): - # TODO: fix this - if "xtts" in self.model_name: + # Not sure what sets this to None, but applied a fix to prevent crashing. 
+ if isinstance(self.model_name, str) and "xtts" in self.model_name: return True if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: return self.synthesizer.tts_model.language_manager.num_languages > 1 return False + @property def speakers(self): if not self.is_multi_speaker: diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index e8de18b0..5ff1181f 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -8,9 +8,120 @@ from argparse import RawTextHelpFormatter # pylint: disable=redefined-outer-name, unused-argument from pathlib import Path -from TTS.api import TTS -from TTS.utils.manage import ModelManager -from TTS.utils.synthesizer import Synthesizer +description = """ +Synthesize speech on command line. + +You can either use your trained model or choose a model from the provided list. + +If you don't specify any models, then it uses LJSpeech based English model. + +#### Single Speaker Models + +- List provided models: + + ``` + $ tts --list_models + ``` + +- Get model info (for both tts_models and vocoder_models): + + - Query by type/name: + The model_info_by_name uses the name as it from the --list_models. + ``` + $ tts --model_info_by_name "///" + ``` + For example: + ``` + $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts + $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 + ``` + - Query by type/idx: + The model_query_idx uses the corresponding idx from --list_models. + + ``` + $ tts --model_info_by_idx "/" + ``` + + For example: + + ``` + $ tts --model_info_by_idx tts_models/3 + ``` + + - Query info for model info by full name: + ``` + $ tts --model_info_by_name "///" + ``` + +- Run TTS with default models: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav + ``` + +- Run a TTS model with its default vocoder model: + + ``` + $ tts --text "Text for TTS" --model_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav + ``` + +- Run with specific TTS and vocoder models from the list: + + ``` + $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav + ``` + +- Run your own TTS model (Using Griffin-Lim Vocoder): + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + ``` + +- Run your own TTS and Vocoder models: + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json + ``` + +#### Multi-speaker Models + +- List the available speakers and choose a among them: + + ``` + $ tts --model_name "//" --list_speaker_idxs + ``` + +- Run the multi-speaker TTS model with the target speaker ID: + + ``` + $ tts --text "Text for TTS." 
--out_path output/path/speech.wav --model_name "//" --speaker_idx + ``` + +- Run your own multi-speaker TTS model: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx + ``` + +### Voice Conversion Models + +``` +$ tts --out_path output/path/speech.wav --model_name "//" --source_wav --target_wav +``` +""" def str2bool(v): @@ -24,92 +135,6 @@ def str2bool(v): def main(): - description = """Synthesize speech on command line. - -You can either use your trained model or choose a model from the provided list. - -If you don't specify any models, then it uses LJSpeech based English model. - -## Example Runs - -### Single Speaker Models - -- List provided models: - - ``` - $ tts --list_models - ``` - -- Query info for model info by idx: - - ``` - $ tts --model_info_by_idx "/" - ``` - -- Query info for model info by full name: - - ``` - $ tts --model_info_by_name "///" - ``` - -- Run TTS with default models: - - ``` - $ tts --text "Text for TTS" - ``` - -- Run a TTS model with its default vocoder model: - - ``` - $ tts --text "Text for TTS" --model_name "/// - ``` - -- Run with specific TTS and vocoder models from the list: - - ``` - $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --output_path - ``` - -- Run your own TTS model (Using Griffin-Lim Vocoder): - - ``` - $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav - ``` - -- Run your own TTS and Vocoder models: - ``` - $ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth --out_path output/path/speech.wav - --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json - ``` - -### Multi-speaker Models - -- List the available speakers and choose as among them: - - ``` - $ tts --model_name "//" --list_speaker_idxs - ``` - -- Run the multi-speaker TTS model with the target speaker ID: - - ``` - $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx - ``` - -- Run your own multi-speaker TTS model: - - ``` - $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx - ``` - -### Voice Conversion Models - - ``` - $ tts --out_path output/path/speech.wav --model_name "//" --source_wav --target_wav - ``` - """ - # We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep - # documentation in sync more easily. parser = argparse.ArgumentParser( description=description.replace(" ```\n", ""), formatter_class=RawTextHelpFormatter, @@ -310,6 +335,11 @@ If you don't specify any models, then it uses LJSpeech based English model. 
if not any(check_args): parser.parse_args(["-h"]) + # Late-import to make things load faster + from TTS.api import TTS + from TTS.utils.manage import ModelManager + from TTS.utils.synthesizer import Synthesizer + # load model manager path = Path(__file__).parent / "../.models.json" manager = ModelManager(path, progress_bar=args.progress_bar) diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index eb9e90df..cb350af7 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -1085,31 +1085,6 @@ class GaussianDiffusion: } -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): - """ - Get a pre-defined beta schedule for the given name. - - The beta schedule library consists of beta schedules which remain similar - in the limit of num_diffusion_timesteps. - Beta schedules may be added, but should not be removed or changed once - they are committed to maintain backwards compatibility. - """ - if schedule_name == "linear": - # Linear schedule from Ho et al, extended to work for any number of - # diffusion steps. - scale = 1000 / num_diffusion_timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) - elif schedule_name == "cosine": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) - else: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") - - class SpacedDiffusion(GaussianDiffusion): """ A diffusion process which can skip steps in a base diffusion process. diff --git a/TTS/tts/layers/tortoise/tokenizer.py b/TTS/tts/layers/tortoise/tokenizer.py index 3969b2cc..d243d655 100644 --- a/TTS/tts/layers/tortoise/tokenizer.py +++ b/TTS/tts/layers/tortoise/tokenizer.py @@ -5,9 +5,13 @@ from tokenizers import Tokenizer from TTS.tts.utils.text.cleaners import english_cleaners +DEFAULT_VOCAB_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json" +) + class VoiceBpeTokenizer: - def __init__(self, vocab_file=None, vocab_str=None): + def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None): self.tokenizer = None if vocab_file is not None: self.tokenizer = Tokenizer.from_file(vocab_file) diff --git a/TTS/tts/layers/xtts/diffusion.py b/TTS/tts/layers/xtts/diffusion.py index a0b93add..37665bc6 100644 --- a/TTS/tts/layers/xtts/diffusion.py +++ b/TTS/tts/layers/xtts/diffusion.py @@ -1170,31 +1170,6 @@ class GaussianDiffusion: } -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): - """ - Get a pre-defined beta schedule for the given name. - - The beta schedule library consists of beta schedules which remain similar - in the limit of num_diffusion_timesteps. - Beta schedules may be added, but should not be removed or changed once - they are committed to maintain backwards compatibility. - """ - if schedule_name == "linear": - # Linear schedule from Ho et al, extended to work for any number of - # diffusion steps. 
- scale = 1000 / num_diffusion_timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) - elif schedule_name == "cosine": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) - else: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") - - class SpacedDiffusion(GaussianDiffusion): """ A diffusion process which can skip steps in a base diffusion process. diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index 2a821a5d..88ce100c 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -172,7 +172,7 @@ class GPT(nn.Module): "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), } - def init_gpt_for_inference(self, kv_cache=True): + def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False): seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 gpt_config = GPT2Config( vocab_size=self.max_mel_tokens, @@ -195,6 +195,17 @@ class GPT(nn.Module): ) self.gpt.wte = self.mel_embedding + if use_deepspeed: + import deepspeed + self.ds_engine = deepspeed.init_inference( + model=self.gpt_inference.half(), # Transformers models + mp_size=1, # Number of GPU + dtype=torch.float32, # desired data type of output + replace_method="auto", # Lets DS autmatically identify the layer to replace + replace_with_kernel_inject=True, # replace the model with the kernel injector + ) + self.gpt_inference = self.ds_engine.module.eval() + def set_inputs_and_targets(self, input, start_token, stop_token): inp = F.pad(input, (1, 0), value=start_token) tar = F.pad(input, (0, 1), value=stop_token) @@ -543,3 +554,14 @@ class GPT(nn.Module): if "return_dict_in_generate" in hf_generate_kwargs: return gen.sequences[:, gpt_inputs.shape[1] :], gen return gen[:, gpt_inputs.shape[1] :] + + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.gpt_inference.generate_stream( + fake_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, + do_stream=True, + **hf_generate_kwargs, + ) diff --git a/TTS/tts/layers/xtts/gpt_encoder_eren.py b/TTS/tts/layers/xtts/gpt_encoder_eren.py deleted file mode 100644 index b5e7158d..00000000 --- a/TTS/tts/layers/xtts/gpt_encoder_eren.py +++ /dev/null @@ -1,658 +0,0 @@ -import functools - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions - - -def null_position_embeddings(range, dim): - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -class GPT2InferenceModel(GPT2PreTrainedModel): - """Override GPT2LMHeadModel to allow for prefix conditioning.""" - - def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): - super().__init__(config) - self.transformer = gpt - self.pos_embedding = pos_emb - self.embeddings = embeddings - self.final_norm = norm - self.lm_head = nn.Sequential(norm, linear) - self.kv_cache = kv_cache - - def store_prefix_emb(self, prefix_emb): - self.cached_prefix_emb = prefix_emb - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) # 
usually None - if not self.kv_cache: - past_key_values = None - - # only last token for inputs_ids if past is defined in kwargs - if past_key_values is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values is not None: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - assert self.cached_prefix_emb is not None - assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] - - # Create embedding - prefix_len = self.cached_prefix_emb.shape[1] - if input_ids.shape[1] != 1: - gen_inputs = input_ids[:, prefix_len:] - gen_emb = self.embeddings(gen_inputs) - gen_emb = gen_emb + self.pos_embedding(gen_emb) - if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: - prefix_emb = self.cached_prefix_emb.repeat_interleave( - gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 - ) - else: - prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) - emb = torch.cat([prefix_emb, gen_emb], dim=1) - else: - emb = self.embeddings(input_ids) - emb = emb + self.pos_embedding.get_fixed_embedding( - attention_mask.shape[1] - (prefix_len + 1), attention_mask.device - ) - transformer_outputs = self.transformer( - inputs_embeds=emb, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + transformer_outputs[1:] - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - @staticmethod - def _reorder_cache(past, beam_idx): - """ - This function is used to re-order the :obj:`past_key_values` cache if - :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is - called. 
This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -class LearnedPositionEmbeddings(nn.Module): - def __init__(self, seq_len, model_channels, init_std=0.02, relative=False): - super().__init__() - self.emb = nn.Embedding(seq_len, model_channels) - nn.init.normal_(self.emb.weight, mean=0.0, std=init_std) - self.relative = relative - - def forward(self, x): - seq_len = x.shape[1] - if self.relative: - start = torch.randint(seq_len, (1,), device=x.device).item() - positions = torch.arange(start, start + seq_len, device=x.device) - else: - positions = torch.arange(seq_len, device=x.device) - return self.emb(positions) - - def get_fixed_embedding(self, ind, dev): - return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) - - -def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing): - """ - Initializes a GPT-2 model and its position embeddings for a text-to-speech system. - - Args: - layers (int): Number of layers in the GPT-2 model. - model_channels (int): Dimension of the GPT-2 model. - heads (int): Number of heads in the GPT-2 model. - max_mel_seq_len (int): Maximum sequence length for the mel spectrogram. - max_text_seq_len (int): Maximum sequence length for the text. - max_prompt_len (int): Maximum length of the prompt. - checkpointing (bool): Whether to use gradient checkpointing. - - Returns: - gpt (GPT2Model): GPT-2 model. - mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram. - text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text. - """ - gpt_config = GPT2Config( - vocab_size=123, - n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_embd=model_channels, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing, - ) - gpt = GPT2Model(gpt_config) - - del gpt.wpe - del gpt.wte - - gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels) - - audio_pos_emb = ( - LearnedPositionEmbeddings(max_mel_seq_len, model_channels) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_channels) - ) - text_pos_emb = ( - LearnedPositionEmbeddings(max_text_seq_len, model_channels) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_channels) - ) - - return gpt, audio_pos_emb, text_pos_emb - - -class XTTSGPTEncoder(nn.Module): - """XTTS GPT Encoder model implementation. - Args: - start_text_token (int): Index of the start token in the text vocabulary. - stop_text_token (int): Index of the stop token in the text vocabulary. - n_layers (int): Number of layers in the GPT-2 model. - n_model_channels (int): Dimension of the GPT-2 model. - n_heads (int): Number of heads in the GPT-2 model. - max_text_tokens (int): Maximum number of text tokens. - max_audio_tokens (int): Maximum number of audio tokens. - max_prompt_tokens (int): Maximum number of prompt tokens. - audio_len_compression (int): Compression factor for the audio length. - number_text_tokens (int): Number of text tokens. - number_audio_codes (int): Number of audio codes. - start_mel_token (int): Index of the start token in the mel code vocabulary. - stop_mel_token (int): Index of the stop token in the mel code vocabulary. 
- checkpointing (bool): Whether or not to use gradient checkpointing at training. - """ - - _inference_flag = False - - def __init__( - self, - start_text_token=261, - stop_text_token=0, - n_layers=8, - n_model_channels=512, - n_heads=8, - max_text_tokens=120, - max_audio_tokens=250, - max_prompt_tokens=70, - audio_len_compression=1024, - number_text_tokens=256, - number_audio_codes=8194, - start_mel_token=8192, - stop_mel_token=8193, - checkpointing=True, - label_smoothing=0.0, - ): - super().__init__() - - self.label_smoothing = label_smoothing - self.number_text_tokens = number_text_tokens - self.start_text_token = start_text_token - self.stop_text_token = stop_text_token - self.number_audio_codes = number_audio_codes - self.start_mel_token = start_mel_token - self.stop_mel_token = stop_mel_token - self.start_prompt_token = start_mel_token - self.stop_prompt_token = stop_mel_token - self.n_layers = n_layers - self.n_heads = n_heads - self.n_model_channels = n_model_channels - self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs - self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 - self.max_prompt_tokens = max_prompt_tokens - self.audio_len_compression = audio_len_compression - - # embedding layers - self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels) - self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels) - self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels) - self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels) - - # initialize the GPT-2 model - ( - self.gpt, - self.audio_pos_embedding, - self.text_pos_embedding, - ) = init_gpt( - n_layers, - n_model_channels, - n_heads, - self.max_audio_tokens, - self.max_text_tokens, - self.max_prompt_tokens, - checkpointing, - ) - - # output layers - self.final_norm = nn.LayerNorm(n_model_channels) - self.text_head = nn.Linear(n_model_channels, self.number_text_tokens) - self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes) - - def get_grad_norm_parameter_groups(self): - return { - "conditioning_encoder": list(self.conditioning_encoder.parameters()), - "gpt": list(self.gpt.parameters()), - "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), - } - - def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False): - self._inference_flag = True - seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens - gpt_config = GPT2Config( - vocab_size=self.max_audio_tokens, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=self.n_model_channels, - n_layer=self.n_layers, - n_head=self.n_heads, - gradient_checkpointing=False, - use_cache=True, - ) - self.inference_model = GPT2InferenceModel( - gpt_config, - self.gpt, - self.audio_pos_embedding, - self.audio_embedding, - self.final_norm, - self.mel_head, - kv_cache=kv_cache, - ) - self.gpt.wte = self.audio_embedding - - def set_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1, 0), value=start_token) - tar = F.pad(input, (0, 1), value=stop_token) - return inp, tar - - def set_audio_tokens_padding(self, audio_tokens, audio_token_lens): - # Set padding areas within MEL (currently it is coded with the MEL code for ). 
- for b in range(len(audio_token_lens)): - actual_end = audio_token_lens[b] - if actual_end < audio_tokens.shape[-1]: - audio_tokens[b, actual_end:] = self.stop_mel_token - return audio_tokens - - def get_logits( - self, - speech_conditioning_inputs, - first_inputs, - first_head, - second_inputs=None, - second_head=None, - prompt=None, - get_attns=False, - return_latent=False, - attn_mask_text=None, - attn_mask_mel=None, - ): - if prompt is not None and speech_conditioning_inputs is not None: - offset = speech_conditioning_inputs.shape[1] + prompt.shape[1] - if second_inputs is not None: - emb = torch.cat( - [speech_conditioning_inputs, prompt, first_inputs, second_inputs], - dim=1, - ) - else: - emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1) - elif speech_conditioning_inputs is not None: - offset = speech_conditioning_inputs.shape[1] - if second_inputs is not None: - emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) - else: - emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) - elif prompt is not None: - offset = prompt.shape[1] - if second_inputs is not None: - emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) - else: - emb = torch.cat([prompt, first_inputs], dim=1) - - # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): - attn_mask = None - if attn_mask_text is not None: - attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) - if prompt is not None: - attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) - attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1) - - gpt_out = self.gpt( - inputs_embeds=emb, - return_dict=True, - output_attentions=get_attns, - attention_mask=attn_mask, - ) - - if get_attns: - return gpt_out.attentions - - enc = gpt_out.last_hidden_state[:, offset:] - enc = self.final_norm(enc) - - if return_latent: - return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] - - first_logits = enc[:, : first_inputs.shape[1]] - first_logits = first_head(first_logits) - first_logits = first_logits.permute(0, 2, 1) - if second_inputs is not None: - second_logits = enc[:, -second_inputs.shape[1] :] - second_logits = second_head(second_logits) - second_logits = second_logits.permute(0, 2, 1) - return first_logits, second_logits - else: - return first_logits - - def get_conditioning(self, speech_conditioning_input): - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - conds = conds.mean(dim=1) - return conds - - def get_prompts(self, prompt_codes): - prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token) - prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token) - return prompt - - def forward( - self, - text_inputs, - text_lengths, - audio_codes, - wav_lengths, - prompt_codes, - return_attentions=False, - return_latent=False, - ): - max_text_len = text_lengths.max() - - # Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values - # Like [..., 186, 45, 45, 83] where actually it should end with 186. - # We take last 3 codes to prevent abrupt ending of the audio. - # TODO: This is might need some testing. 
- mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3 - - # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. - max_mel_len = mel_lengths.max() - - if max_mel_len > audio_codes.shape[-1]: - audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1])) - - # silence aware lengths, skip the silence tokens at the end of the mel codes. - silence = True - for idx, l in enumerate(mel_lengths): - length = l.item() - while silence: - if audio_codes[idx, length - 1] != 83: - break - length -= 1 - mel_lengths[idx] = length - - # Lovely assertions - assert ( - max_mel_len <= audio_codes.shape[-1] - ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})" - assert ( - max_text_len <= text_inputs.shape[-1] - ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" - - # Append stop token to text inputs - text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) - - # Append silence token to mel codes - audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token) - - # Pad mel codes with STOP_MEL_TOKEN - audio_codes = self.set_mel_padding(audio_codes, mel_lengths) - - # Compute speech conditioning input - conds = None - if speech_conditioning_input is not None: - if not return_latent: - # Compute speech conditioning input - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) - else: - # already computed - conds = speech_conditioning_input.unsqueeze(1) - - # Build input and target tensors - # Prepend start token to inputs and append stop token to targets - text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) - audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token) - - # Set attn_mask - attn_mask_text = None - attn_mask_mel = None - if not return_latent: - attn_mask_text = torch.ones( - text_inputs.shape[0], - text_inputs.shape[1], - dtype=torch.bool, - device=text_inputs.device, - ) - attn_mask_mel = torch.ones( - audio_codes.shape[0], - audio_codes.shape[1], - dtype=torch.bool, - device=audio_codes.device, - ) - - for idx, l in enumerate(text_lengths): - attn_mask_text[idx, l + 1 :] = 0.0 - - for idx, l in enumerate(mel_lengths): - attn_mask_mel[idx, l + 1 :] = 0.0 - - # Compute text embeddings + positional embeddings - # print(" > text input latent:", text_inputs) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - - # Compute mel embeddings + positional embeddings - audio_emb = self.audio_embedding(audio_codes) + self.audio_embedding(audio_codes) - - # Compute prompt embeddings + positional embeddings - prompt = self.get_prompts(prompt_codes) - - # prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach() - prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt) - - # dropout prompt embeddings - prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training) - - # Get logits - sub = -4 # don't ask me why 😄 - if self.training: - sub = -1 - _, audio_logits = 
self.get_logits( - conds, - text_emb, - self.text_head, - audio_emb, - self.mel_head, - prompt=prompt_emb, - get_attns=return_attentions, - return_latent=return_latent, - attn_mask_text=attn_mask_text, - attn_mask_mel=attn_mask_mel, - ) - return audio_logits[:, :sub] # sub to prevent bla. - - def compute_embeddings( - self, - speech_conditioning_latent, - text_inputs, - input_tokens=None, - prompt_codes=None, - pad_input_text=False, - ): - """Compute all the embeddings needed for inference.""" - if pad_input_text and text_inputs.shape[1] < 250: - text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) - else: - text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) - - emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - - print(" > Text inputs:", text_inputs) - if prompt_codes is not None: - prompt_codes = self.get_prompts(prompt_codes) - # prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes) - prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes) - - print(" > Prompt inputs:", prompt_codes) - print(" > Prompt inputs shape:", prompt_codes.shape) - emb = torch.cat([prompt_emb, emb], dim=1) - - if speech_conditioning_latent is not None: - conds = speech_conditioning_latent.unsqueeze(1) - emb = torch.cat([conds, emb], dim=1) - - self.inference_model.store_prefix_emb(emb) - - fake_inputs = torch.full( - ( - emb.shape[0], - emb.shape[1] + 1, # +1 for the start_mel_token - ), - fill_value=1, - dtype=torch.long, - device=text_inputs.device, - ) - fake_inputs[:, -1] = self.start_mel_token - - if input_tokens is not None: - fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) - return fake_inputs - - def inference( - self, - text_inputs, - input_tokens=None, - prompt_codes=None, - pad_input_text=False, - **hf_generate_kwargs, - ): - if pad_input_text and text_inputs.shape[1] < 250: - text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) - else: - text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) - - emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - - if prompt_codes is not None: - prompt_codes = self.get_prompts(prompt_codes) - prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes) - emb = torch.cat([prompt_emb, emb], dim=1) - - self.inference_model.store_prefix_emb(emb) - - fake_inputs = torch.full( - ( - emb.shape[0], - emb.shape[1] + 1, # +1 for the start_mel_token - ), - fill_value=1, - dtype=torch.long, - device=text_inputs.device, - ) - fake_inputs[:, -1] = self.start_mel_token - - if input_tokens is not None: - fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) - - gen = self.inference_model.generate( - fake_inputs, - bos_token_id=self.start_mel_token, - pad_token_id=self.stop_mel_token, - eos_token_id=self.stop_mel_token, - max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, - **hf_generate_kwargs, - ) - if "return_dict_in_generate" in hf_generate_kwargs: - return gen.sequences[:, fake_inputs.shape[1] :], gen - return gen[:, fake_inputs.shape[1] :] diff --git a/TTS/tts/layers/xtts/gpt_encoder_old.py b/TTS/tts/layers/xtts/gpt_encoder_old.py deleted file mode 100644 index 46739aa2..00000000 --- 
a/TTS/tts/layers/xtts/gpt_encoder_old.py +++ /dev/null @@ -1,1057 +0,0 @@ -import functools -import math -import random - -import torch -import torch.nn as nn -import torch.nn.functional as F - -try: - import deepspeed - from deepspeed.ops.transformer.inference import DeepSpeedTransformerInferenceKernel -except ImportError: - pass - -import dlas.codes.torch_intermediary as ml -from dlas.codes.models.arch_util import AttentionBlock -from dlas.codes.trainer.networks import register_model -from dlas.codes.utils.transformers.stream_generator import init_stream_support -from dlas.codes.utils.util import opt_get -from transformers import GPT2Config, GPT2PreTrainedModel -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions - -init_stream_support() - - -def null_position_embeddings(range, dim): - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -class ResBlock(nn.Module): - """ - Basic residual convolutional block that uses GroupNorm. - """ - - def __init__(self, chan): - super().__init__() - self.net = nn.Sequential( - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan // 8, chan), - nn.ReLU(), - nn.Conv1d(chan, chan, kernel_size=3, padding=1), - nn.GroupNorm(chan // 8, chan), - ) - - def forward(self, x): - return F.relu(self.net(x) + x) - - -class GPT2InferenceModel(GPT2PreTrainedModel): - """Override GPT2LMHeadModel to allow for prefix conditioning.""" - - def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): - super().__init__(config) - self.transformer = gpt - self.pos_embedding = pos_emb - self.embeddings = embeddings - self.final_norm = norm - self.lm_head = nn.Sequential(norm, linear) - self.kv_cache = kv_cache - - def store_prefix_emb(self, prefix_emb): - self.cached_prefix_emb = prefix_emb - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) # usually None - if not self.kv_cache: - past_key_values = None - - # only last token for inputs_ids if past is defined in kwargs - if past_key_values is not None: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values is not None: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - - def forward( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - assert self.cached_prefix_emb is not None - assert inputs_embeds is None # Not supported by this inference model. - assert labels is None # Training not supported by this inference model. 
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] - - # Create embedding - prefix_len = self.cached_prefix_emb.shape[1] - if input_ids.shape[1] != 1: - gen_inputs = input_ids[:, prefix_len:] - gen_emb = self.embeddings(gen_inputs) - gen_emb = gen_emb + self.pos_embedding(gen_emb) - if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: - prefix_emb = self.cached_prefix_emb.repeat_interleave( - gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 - ) - else: - prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) - emb = torch.cat([prefix_emb, gen_emb], dim=1) - else: - emb = self.embeddings(input_ids) - emb = emb + self.pos_embedding.get_fixed_embedding( - attention_mask.shape[1] - (prefix_len + 1), attention_mask.device - ) - transformer_outputs = self.transformer( - inputs_embeds=emb, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return (lm_logits,) + transformer_outputs[1:] - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - ) - - @staticmethod - def _reorder_cache(past, beam_idx): - """ - This function is used to re-order the :obj:`past_key_values` cache if - :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is - called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
- """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past - ) - - -class ConditioningEncoder(nn.Module): - def __init__( - self, - spec_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=4, - do_checkpointing=False, - mean=False, - ): - super().__init__() - attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - self.do_checkpointing = do_checkpointing - self.mean = mean - - def forward(self, x): - h = self.init(x) - h = self.attn(h) - if self.mean: - return h.mean(dim=2) - else: - return h[:, :, 0] - - -class LearnedPositionEmbeddings(nn.Module): - def __init__(self, seq_len, model_dim, init=0.02, relative=False): - super().__init__() - # nn.Embedding - self.emb = torch.nn.Embedding(seq_len, model_dim) - # Initializing this way is standard for GPT-2 - self.emb.weight.data.normal_(mean=0.0, std=init) - self.relative = relative - self.seq_len = seq_len - - def forward(self, x): - sl = x.shape[1] - if self.relative: - start = random.randint(sl, self.seq_len) - sl - return self.emb(torch.arange(start, start + sl, device=x.device)) - else: - return self.emb(torch.arange(0, sl, device=x.device)) - - def get_fixed_embedding(self, ind, dev): - return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) - - -def build_hf_gpt_transformer( - layers, - model_dim, - heads, - max_mel_seq_len, - max_text_seq_len, - max_prompt_len, - checkpointing, -): - """ - GPT-2 implemented by the HuggingFace library. - """ - from transformers import GPT2Config, GPT2Model - - gpt_config = GPT2Config( - vocab_size=256, # Unused. - n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing, - ) - gpt = GPT2Model(gpt_config) - # Override the built in positional embeddings - del gpt.wpe - gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) - # Built-in token embeddings are unused. 
- del gpt.wte - - # def _attn(self, query, key, value, attention_mask=None, head_mask=None): - # attn_output = torch.nn.functional.scaled_dot_product_attention( - # query, key, value, dropout_p=self.attn_dropout.p, is_causal=True - # ) - # return attn_output, None - - # for i in range(len(gpt.h)): - # gpt.h[i].attn._attn = types.MethodType( - # _attn, gpt.h[i].attn - # ) - - mel_pos_emb = ( - LearnedPositionEmbeddings(max_mel_seq_len, model_dim) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_dim) - ) - text_pos_emb = ( - LearnedPositionEmbeddings(max_text_seq_len, model_dim) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_dim) - ) - # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) - return gpt, mel_pos_emb, text_pos_emb, None, None - - -class MelEncoder(nn.Module): - def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): - super().__init__() - self.channels = channels - self.encoder = nn.Sequential( - nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), - nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels // 16, channels // 2), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), - nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), - nn.GroupNorm(channels // 8, channels), - nn.ReLU(), - nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), - ) - self.reduction = 4 - - def forward(self, x): - for e in self.encoder: - x = e(x) - return x.permute(0, 2, 1) - - -class UnifiedVoice(nn.Module): - def __init__( - self, - start_text_token=261, - stop_text_token=0, - layers=8, - model_dim=512, - heads=8, - max_text_tokens=120, - max_mel_tokens=250, - max_prompt_tokens=70, - max_conditioning_inputs=1, - mel_length_compression=1024, - number_text_tokens=256, - number_mel_codes=8194, - start_mel_token=8192, - stop_mel_token=8193, - train_solo_embeddings=False, - use_mel_codes_as_input=True, - checkpointing=True, - average_conditioning_embeddings=False, - freeze_everything_but_position_embeddings=False, - freeze_conditioning_encoder=False, - tortoise_compat=True, - label_smoothing=0.0, - ): - """ - Args: - layers: Number of layers in transformer stack. - model_dim: Operating dimensions of the transformer - heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 - max_text_tokens: Maximum number of text tokens that will be encountered by model. - max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. - max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). - mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. - number_text_tokens: - start_text_token: - stop_text_token: - number_mel_codes: - start_mel_token: - stop_mel_token: - train_solo_embeddings: - use_mel_codes_as_input: - checkpointing: - average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model. 
- """ - super().__init__() - - self.label_smoothing = label_smoothing - self.number_text_tokens = number_text_tokens - self.start_text_token = start_text_token - self.stop_text_token = stop_text_token - self.number_mel_codes = number_mel_codes - self.start_mel_token = start_mel_token - self.stop_mel_token = stop_mel_token - self.start_prompt_token = start_mel_token - self.stop_prompt_token = stop_mel_token - self.layers = layers - self.heads = heads - self.model_dim = model_dim - self.max_conditioning_inputs = max_conditioning_inputs - self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs - self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 - self.max_prompt_tokens = max_prompt_tokens - self.mel_length_compression = mel_length_compression - # self.conditioning_encoder = ConditioningEncoder( - # 80, model_dim, num_attn_heads=heads - # ) - self.average_conditioning_embeddings = average_conditioning_embeddings - self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b - # nn.Embedding - self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim) - if use_mel_codes_as_input: - # nn.Embedding - self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) - else: - self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) - ( - self.gpt, - self.mel_pos_embedding, - self.text_pos_embedding, - self.mel_layer_pos_embedding, - self.text_layer_pos_embedding, - ) = build_hf_gpt_transformer( - layers, - model_dim, - heads, - self.max_mel_tokens, - self.max_text_tokens, - self.max_prompt_tokens, - checkpointing, - ) - if train_solo_embeddings: - self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) - self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) - else: - self.mel_solo_embedding = 0 - self.text_solo_embedding = 0 - - self.final_norm = nn.LayerNorm(model_dim) - self.text_head = ml.Linear(model_dim, self.number_text_tokens) - self.mel_head = ml.Linear(model_dim, self.number_mel_codes) - - # Initialize the embeddings per the GPT-2 scheme - embeddings = [self.text_embedding] - if use_mel_codes_as_input: - embeddings.append(self.mel_embedding) - for module in embeddings: - module.weight.data.normal_(mean=0.0, std=0.02) - - if freeze_conditioning_encoder: - print(" > Freezing conditioning encoder.") - for p in self.conditioning_encoder.parameters(): - p.requires_grad = False - p.DO_NOT_TRAIN = True - - if freeze_everything_but_position_embeddings: - for p in self.parameters(): - p.requires_grad = False - p.DO_NOT_TRAIN = True - for m in [self.mel_pos_embedding, self.text_pos_embedding]: - for p in m.parameters(): - del p.DO_NOT_TRAIN - p.requires_grad = True - - def get_grad_norm_parameter_groups(self): - return { - "conditioning_encoder": list(self.conditioning_encoder.parameters()), - "gpt": list(self.gpt.parameters()), - "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), - } - - def post_init_gpt2_config(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False): - seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 - gpt_config = GPT2Config( - vocab_size=self.max_mel_tokens, - n_positions=seq_length, - n_ctx=seq_length, - n_embd=self.model_dim, - n_layer=self.layers, - n_head=self.heads, - gradient_checkpointing=False, - use_cache=True, - ) - 
self.inference_model = GPT2InferenceModel( - gpt_config, - self.gpt, - self.mel_pos_embedding, - self.mel_embedding, - self.final_norm, - self.mel_head, - kv_cache=kv_cache, - ) - # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) - self.gpt.wte = self.mel_embedding - - if use_deepspeed: - # init deepspeed inference engine - if use_deepspeed_f16: - self.gpt.wte = self.mel_embedding.half() - self.gpt.wpe = self.mel_pos_embedding.half() - self.ds_engine = deepspeed.init_inference( - model=self.inference_model.half(), # Transformers models - mp_size=1, # Number of GPU - dtype=torch.float16 if use_deepspeed_f16 else torch.float32, # desired data type of output - replace_method="auto", # Lets DS autmatically identify the layer to replace - replace_with_kernel_inject=True, # replace the model with the kernel injector - ) - self.inference_model = self.ds_engine.module.eval() - - def build_aligned_inputs_and_targets(self, input, start_token, stop_token): - inp = F.pad(input, (1, 0), value=start_token) - tar = F.pad(input, (0, 1), value=stop_token) - return inp, tar - - def set_mel_padding(self, mel_input_tokens, mel_lengths): - """ - Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in - that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required - preformatting to create a working TTS model. - """ - # Set padding areas within MEL (currently it is coded with the MEL code for ). - for b in range(len(mel_lengths)): - actual_end = mel_lengths[b] - if actual_end < mel_input_tokens.shape[-1]: - mel_input_tokens[b, actual_end:] = self.stop_mel_token - return mel_input_tokens - - def get_logits( - self, - speech_conditioning_inputs, - first_inputs, - first_head, - second_inputs=None, - second_head=None, - prompt=None, - get_attns=False, - return_latent=False, - attn_mask_text=None, - attn_mask_mel=None, - ): - if prompt is not None and speech_conditioning_inputs is not None: - offset = speech_conditioning_inputs.shape[1] + prompt.shape[1] - if second_inputs is not None: - emb = torch.cat( - [speech_conditioning_inputs, prompt, first_inputs, second_inputs], - dim=1, - ) - else: - emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1) - elif speech_conditioning_inputs is not None: - offset = speech_conditioning_inputs.shape[1] - if second_inputs is not None: - emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) - else: - emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) - elif prompt is not None: - offset = prompt.shape[1] - if second_inputs is not None: - emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) - else: - emb = torch.cat([prompt, first_inputs], dim=1) - - # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): - attn_mask = None - if attn_mask_text is not None: - attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) - if prompt is not None: - attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) - attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1) - - gpt_out = self.gpt( - inputs_embeds=emb, - return_dict=True, - output_attentions=get_attns, - attention_mask=attn_mask, - ) - - if get_attns: - return gpt_out.attentions - - enc = gpt_out.last_hidden_state[:, offset:] - enc = self.final_norm(enc) - - if return_latent: - return 
enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] - - first_logits = enc[:, : first_inputs.shape[1]] - first_logits = first_head(first_logits) - first_logits = first_logits.permute(0, 2, 1) - if second_inputs is not None: - second_logits = enc[:, -second_inputs.shape[1] :] - second_logits = second_head(second_logits) - second_logits = second_logits.permute(0, 2, 1) - return first_logits, second_logits - else: - return first_logits - - def get_conditioning(self, speech_conditioning_input): - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - conds = conds.mean(dim=1) - return conds - - def get_prompts(self, prompt_codes): - """ - Create a prompt from the mel codes. This is used to condition the model on the mel codes. - Pad the prompt with start and stop mel tokens. - """ - prompt = prompt_codes - if self.training: - prompt_len = random.randint(1, 9) # in secs - prompt_len = prompt_len * 24 # in frames - - if prompt_codes.shape[1] < prompt_len: - prompt_len = prompt_codes.shape[-1] - start = 0 - else: - start = random.randint(0, prompt_codes.shape[-1] - prompt_len) - - prompt = prompt_codes[:, start : start + prompt_len] - - # add start and stop tokens - prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) - prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) - return prompt - - # def get_prompts(self, prompt_codes): - # """ - # Create a prompt from the mel codes. This is used to condition the model on the mel codes. - # Pad the prompt with start and stop mel tokens. - # """ - # prompt = prompt_codes - # if self.training: - # max_prompt_len = 9 * 24 - # if prompt_codes.shape[1] < max_prompt_len: - # prompt = prompt_codes - # else: - # start = random.randint(0, prompt_codes.shape[1] - max_prompt_len) - # prompt = prompt_codes[:, start : start + max_prompt_len] - - # # add start and stop tokens - # prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) - # prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) - # return prompt - - def forward( - self, - speech_conditioning_input, - text_inputs, - text_lengths, - mel_codes, - wav_lengths, - prompt_codes, - loss_weights=None, - text_first=True, - return_attentions=False, - return_latent=False, - ): - """ - Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode - (actuated by `text_first`). - - speech_conditioning_input: MEL float tensor, (b,80,s) - text_inputs: long tensor, (b,t) - text_lengths: long tensor, (b,) - mel_inputs: long tensor, (b,m) - wav_lengths: long tensor, (b,) - - If return_attentions is specified, only logits are returned. - If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. - """ - - # ❗ FIXIT - speech_conditioning_input = None - if self.max_conditioning_inputs == 0: - assert ( - speech_conditioning_input is None - ), " ❗ speech_conditioning_input is not None, but max_conditioning_inputs == 0" - - max_text_len = text_lengths.max() - # Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values - # Like [..., 186, 45, 45, 83] where actually it should end with 186. 
- # We take last 3 codes to prevent abrupt ending of the audio. - # TODO: This is might need some testing. - mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3 - - # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. - max_mel_len = mel_lengths.max() - - if max_mel_len > mel_codes.shape[-1]: - mel_codes = F.pad(mel_codes, (0, max_mel_len - mel_codes.shape[-1])) - - # mel_lengths[mel_lengths >= max_mel_len] = max_mel_len - - # silence aware lengths, skip the silence tokens at the end of the mel codes. - silence = True - for idx, l in enumerate(mel_lengths): - length = l.item() - while silence: - if mel_codes[idx, length - 1] != 83: - break - length -= 1 - mel_lengths[idx] = length - - # Lovely assertions - assert ( - max_mel_len <= mel_codes.shape[-1] - ), f" ❗ max_mel_len ({max_mel_len}) > mel_codes.shape[-1] ({mel_codes.shape[-1]})" - assert ( - max_text_len <= text_inputs.shape[-1] - ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" - - # Append stop token to text inputs - text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) - - # Append silence token to mel codes - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token) - - # Pad mel codes with STOP_MEL_TOKEN - mel_codes = self.set_mel_padding(mel_codes, mel_lengths) - - # Compute speech conditioning input - conds = None - if speech_conditioning_input is not None: - if not return_latent: - # Compute speech conditioning input - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) - else: - # already computed - conds = speech_conditioning_input.unsqueeze(1) - - # Build input and target tensors - # Prepend start token to inputs and append stop token to targets - text_inputs, text_targets = self.build_aligned_inputs_and_targets( - text_inputs, self.start_text_token, self.stop_text_token - ) - mel_codes, mel_targets = self.build_aligned_inputs_and_targets( - mel_codes, self.start_mel_token, self.stop_mel_token - ) - - # Set attn_mask - attn_mask_text = None - attn_mask_mel = None - if not return_latent: - attn_mask_text = torch.ones( - text_inputs.shape[0], - text_inputs.shape[1], - dtype=torch.bool, - device=text_inputs.device, - ) - attn_mask_mel = torch.ones( - mel_codes.shape[0], - mel_codes.shape[1], - dtype=torch.bool, - device=mel_codes.device, - ) - - for idx, l in enumerate(text_lengths): - attn_mask_text[idx, l + 1 :] = 0.0 - - for idx, l in enumerate(mel_lengths): - attn_mask_mel[idx, l + 1 :] = 0.0 - - # Compute text embeddings + positional embeddings - # print(" > text input latent:", text_inputs) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - - # Compute mel embeddings + positional embeddings - mel_emb = self.mel_embedding(mel_codes) + self.mel_pos_embedding(mel_codes) - - # Compute prompt embeddings + positional embeddings - prompt = self.get_prompts(prompt_codes) - - prompt_emb = self.mel_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach() - - # Get logits - sub = -4 # don't ask me why 😄 - if self.training: - sub = -1 - text_logits, mel_logits = 
self.get_logits( - conds, - text_emb, - self.text_head, - mel_emb, - self.mel_head, - prompt=prompt_emb, - get_attns=return_attentions, - return_latent=return_latent, - attn_mask_text=attn_mask_text, - attn_mask_mel=attn_mask_mel, - ) - if return_latent: - return mel_logits[:, :sub] # sub to prevent bla. - - if return_attentions: - return mel_logits - - # Set paddings to -1 to ignore them in loss - for idx, l in enumerate(text_lengths): - text_targets[idx, l + 1 :] = -1 - - for idx, l in enumerate(mel_lengths): - mel_targets[idx, l + 1 :] = -1 - - # check if stoptoken is in every row of mel_targets - assert (mel_targets == self.stop_mel_token).sum() >= mel_targets.shape[ - 0 - ], f" ❗ mel_targets does not contain stop token ({self.stop_mel_token}) in every row." - - # Compute losses - loss_text = F.cross_entropy( - text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing - ) - loss_mel = F.cross_entropy( - mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing - ) - - # if loss_weights is not None: - # loss_text = loss_text * loss_weights[:, None] - # loss_mel = loss_mel * loss_weights[:, None] - return loss_text.mean(), loss_mel.mean(), mel_logits - - def text_forward(self, speech_conditioning_input, text_inputs, text_lengths): - """ - Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the - model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). - """ - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) - - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) - - text_inputs, text_targets = self.build_aligned_inputs_and_targets( - text_inputs, self.start_text_token, self.stop_text_token - ) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding - text_logits = self.get_logits(conds, text_emb, self.text_head) - loss_text = F.cross_entropy(text_logits, text_targets.long()) - return loss_text.mean() - - def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None): - """ - Performs autoregressive modeling on only speech data. - """ - assert self.max_mel_tokens >= mel_codes.shape[1], f"{mel_codes.shape[1]}" - - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. 
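
To make the length bookkeeping concrete, here is a small numeric sketch (an assumption, not part of the diff) of how the mel-code budget is derived from raw waveform lengths via `mel_length_compression`, which is what the next lines do before padding with the stop token.

```python
# Sketch only: deriving the mel-code budget from waveform lengths, assuming the
# default mel_length_compression of 1024 samples per mel code.
import torch

mel_length_compression = 1024
wav_lengths = torch.tensor([64000, 49920])  # example clip lengths in samples
max_mel_len = wav_lengths.max() // mel_length_compression
print(max_mel_len.item())  # 62 mel codes for the longest clip
```
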
- max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token) - mel_codes = self.set_mel_padding(mel_codes, wav_lengths) - if raw_mels is not None: - raw_mels = raw_mels[:, :, : max_mel_len * 4] - - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - if self.average_conditioning_embeddings: - conds = conds.mean(dim=1).unsqueeze(1) - - mel_codes, mel_targets = self.build_aligned_inputs_and_targets( - mel_codes, self.start_mel_token, self.stop_mel_token - ) - if raw_mels is not None: - mel_inp = F.pad(raw_mels, (0, 4)) - else: - mel_inp = mel_codes - mel_emb = self.mel_embedding(mel_inp) - mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding - mel_logits = self.get_logits(conds, mel_emb, self.mel_head) - loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) - return loss_mel.mean() - - def get_generator(self, fake_inputs, **hf_generate_kwargs): - return self.inference_model.generate_stream( - fake_inputs, - bos_token_id=self.start_mel_token, - pad_token_id=self.stop_mel_token, - eos_token_id=self.stop_mel_token, - max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, - do_stream=True, - **hf_generate_kwargs, - ) - - def compute_embeddings( - self, - speech_conditioning_latent, - text_inputs, - input_tokens=None, - prompt_codes=None, - pad_input_text=False, - ): - if pad_input_text and text_inputs.shape[1] < 250: - text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) - else: - text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) - - emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - - print(" > Text inputs:", text_inputs) - if prompt_codes is not None: - prompt_codes = self.get_prompts(prompt_codes) - prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes) - print(" > Prompt inputs:", prompt_codes) - print(" > Prompt inputs shape:", prompt_codes.shape) - emb = torch.cat([prompt_emb, emb], dim=1) - - if speech_conditioning_latent is not None: - conds = speech_conditioning_latent.unsqueeze(1) - emb = torch.cat([conds, emb], dim=1) - - self.inference_model.store_prefix_emb(emb) - - fake_inputs = torch.full( - ( - emb.shape[0], - emb.shape[1] + 1, # +1 for the start_mel_token - ), - fill_value=1, - dtype=torch.long, - device=text_inputs.device, - ) - fake_inputs[:, -1] = self.start_mel_token - - if input_tokens is not None: - fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) - return fake_inputs - - def inference_speech( - self, - speech_conditioning_latent, - text_inputs, - input_tokens=None, - prompt_codes=None, - pad_input_text=False, - **hf_generate_kwargs, - ): - if pad_input_text and text_inputs.shape[1] < 250: - text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) - else: - text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) - text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) - - emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) - - print(" > Text inputs:", text_inputs) 
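
As a rough illustration (an assumption, not part of the diff) of the placeholder inputs built a few lines below: the prefix embeddings (optional conditioning, prompt, and text) are stored on the inference model, so the id tensor handed to HF `generate` is a dummy everywhere except for the trailing `start_mel_token`. The prefix length used here is purely illustrative.

```python
# Sketch only: shape of the placeholder ids handed to the HF inference model below.
# prefix_len stands in for the number of stored prefix embeddings; 8192 is the
# default start_mel_token from the constructor.
import torch

batch, prefix_len, start_mel_token = 2, 446, 8192
fake_inputs = torch.full((batch, prefix_len + 1), fill_value=1, dtype=torch.long)
fake_inputs[:, -1] = start_mel_token  # autoregressive mel-code generation starts here
```
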
- if prompt_codes is not None: - prompt_codes = self.get_prompts(prompt_codes) - prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes) - print(" > Prompt inputs:", prompt_codes) - print(" > Prompt inputs shape:", prompt_codes.shape) - emb = torch.cat([prompt_emb, emb], dim=1) - - if speech_conditioning_latent is not None: - conds = speech_conditioning_latent.unsqueeze(1) - emb = torch.cat([conds, emb], dim=1) - - self.inference_model.store_prefix_emb(emb) - - fake_inputs = torch.full( - ( - emb.shape[0], - emb.shape[1] + 1, # +1 for the start_mel_token - ), - fill_value=1, - dtype=torch.long, - device=text_inputs.device, - ) - fake_inputs[:, -1] = self.start_mel_token - - if input_tokens is not None: - fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) - - gen = self.inference_model.generate( - fake_inputs, - bos_token_id=self.start_mel_token, - pad_token_id=self.stop_mel_token, - eos_token_id=self.stop_mel_token, - max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, - **hf_generate_kwargs, - ) - if "return_dict_in_generate" in hf_generate_kwargs: - return gen.sequences[:, fake_inputs.shape[1] :], gen - return gen[:, fake_inputs.shape[1] :] - - # Turns the (utterly insane) output of HF.generate() into a far more sane output: - # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence - def make_hf_generate_attentions_sane(self, attentions): - layers = [[] for _ in range(len(attentions[0]))] - full_attention_size = attentions[-1][0].shape[-1] - for i, gen in enumerate(attentions): - for j, lyr in enumerate(gen): - layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1]))) - catted = [] - for lyr in layers: - catted.append(torch.cat(lyr, dim=2)) - return catted - - def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds): - """ - This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice. 
- """ - text_padding = num_conds + 2 - num_text = text.shape[-1] - num_context = num_text + text_padding - assert num_context + 1 == attentions[0][0].shape[-1] - attentions = self.make_hf_generate_attentions_sane(attentions) - results = [torch.empty_like(codes) for _ in range(len(attentions))] - for l, layer in enumerate(attentions): - dec_context = layer[:, :, num_context:, :] - # Mask out everything that isn't text (including the start token, which gets a LOT of attention) - dec_context[:, :, :, : text_padding + 1] = 0 - dec_context[:, :, :, num_context:] = 0 - for h in range(dec_context.shape[1]): - dec_context_indices = torch.argmax(dec_context[0, h], dim=-1) - print(f"layer_{l};head_{h}: " + str(dec_context_indices)) - for t, att_tok in enumerate(attentions): - combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device) - for lyr in att_tok: - token_to_text_attentions = lyr[:, :, -1, text_padding : (text_padding + num_text)].sum(dim=1) - combined_attention_weights = combined_attention_weights + token_to_text_attentions - break - most_attended_text_token = combined_attention_weights.argmax(dim=-1) - results[:, t] = most_attended_text_token - eos_token_mask = codes != self.stop_mel_token - return results * eos_token_mask - - -@register_model -def register_unified_voice_prompt(opt_net, opt): - return UnifiedVoice(**opt_get(opt_net, ["kwargs"], {})) - - -if __name__ == "__main__": - gpt = UnifiedVoice( - model_dim=256, - heads=4, - train_solo_embeddings=True, - use_mel_codes_as_input=True, - max_conditioning_inputs=4, - freeze_everything_but_position_embeddings=True, - ) - l = gpt( - torch.randn(2, 3, 80, 800), - torch.randint(high=256, size=(2, 120)), - torch.tensor([32, 120]), - torch.randint(high=8192, size=(2, 250)), - torch.tensor([250 * 256, 195 * 256]), - ) - # gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py new file mode 100644 index 00000000..6439b455 --- /dev/null +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -0,0 +1,742 @@ +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm +import torchaudio + +from TTS.utils.io import load_fsspec + + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. 
+ """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + cond_in_each_up_layer=False, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. 
+ resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + self.cond_in_each_up_layer = cond_in_each_up_layer + + # initial upsampling layers + self.conv_pre = weight_norm( + Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm( + Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias) + ) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(cond_channels, ch, 1)) + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + + if self.cond_in_each_up_layer: + o = o + self.conds[i](g) + + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. 
+ + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad( + c, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.forward(c) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. 
overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + + +class ResNetSpeakerEncoder(nn.Module): + """This is copied from 🐸TTS to remove it from the dependencies. + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return 
nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + def load_checkpoint( + self, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + print(" > Model fully restored. ") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + print(" > Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"]) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + print(" > Criterion load ignored because of:", error) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion + +class HifiDecoder(torch.nn.Module): + def __init__( + self, + input_sample_rate=22050, + output_sample_rate=24000, + output_hop_length=256, + ar_mel_length_compression=1024, + decoder_input_dim=1024, + resblock_type_decoder="1", + resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_kernel_sizes_decoder=[3, 7, 11], + upsample_rates_decoder=[8, 8, 2, 2], + upsample_initial_channel_decoder=512, + upsample_kernel_sizes_decoder=[16, 16, 4, 4], + d_vector_dim=512, + cond_d_vector_in_each_upsampling_layer=True, + speaker_encoder_audio_config={ + "fft_size": 512, + "win_length": 400, + "hop_length": 160, + "sample_rate": 16000, + "preemphasis": 0.97, + "num_mels": 64, + }, + ): + super().__init__() + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + self.output_hop_length = output_hop_length + self.ar_mel_length_compression = ar_mel_length_compression + self.speaker_encoder_audio_config = speaker_encoder_audio_config + self.waveform_decoder = HifiganGenerator( + decoder_input_dim, + 1, + resblock_type_decoder, + resblock_dilation_sizes_decoder, + 
resblock_kernel_sizes_decoder, + upsample_kernel_sizes_decoder, + upsample_initial_channel_decoder, + upsample_rates_decoder, + inference_padding=0, + cond_channels=d_vector_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer, + ) + self.speaker_encoder = ResNetSpeakerEncoder( + input_dim=64, + proj_dim=512, + log_input=True, + use_torch_spec=True, + audio_config=speaker_encoder_audio_config, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, latents, g=None): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + + z = torch.nn.functional.interpolate( + latents.transpose(1, 2), + scale_factor=[self.ar_mel_length_compression / self.output_hop_length], + mode="linear", + ).squeeze(1) + # upsample to the right sr + if self.output_sample_rate != self.input_sample_rate: + z = torch.nn.functional.interpolate( + z, + scale_factor=[self.output_sample_rate / self.input_sample_rate], + mode="linear", + ).squeeze(0) + o = self.waveform_decoder(z, g=g) + return o + + @torch.no_grad() + def inference(self, c, g): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + return self.forward(c, g=g) + + def load_checkpoint( + self, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + # remove unused keys + state = state["model"] + states_keys = list(state.keys()) + for key in states_keys: + if "waveform_decoder." not in key and "speaker_encoder." 
not in key: + del state[key] + + self.load_state_dict(state) + if eval: + self.eval() + assert not self.training + self.waveform_decoder.remove_weight_norm() diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py new file mode 100644 index 00000000..8bdd2291 --- /dev/null +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -0,0 +1,1057 @@ +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +from transformers import ( + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + StoppingCriteriaList, + DisjunctiveConstraint, + BeamSearchScorer, + PhrasalConstraint, + ConstrainedBeamSearchScorer, + PreTrainedModel, +) +import numpy as np +import random +import warnings +import inspect +from transformers.generation.utils import GenerateOutput, SampleOutput, logger +import torch +from typing import Callable, List, Optional, Union +from torch import nn +import torch.distributed as dist +import copy + + +def setup_seed(seed): + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = False, + seed=0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. 
+ logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + #setup_seed(seed) + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config( + self.config + ) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." 
+ " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set( + inspect.signature(self.forward).parameters.keys() + ) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if ( + model_kwargs.get("attention_mask", None) is None + and requires_attention_mask + and accepts_attention_mask + ): + model_kwargs[ + "attention_mask" + ] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) + > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. 
Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = ( + kwargs.get("max_length") is None + and generation_config.max_length is not None + ) + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = ( + generation_config.max_new_tokens + input_ids_seq_length + ) + elif ( + not has_default_max_length and generation_config.max_new_tokens is not None + ): + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if ( + generation_config.min_length is not None + and generation_config.min_length > generation_config.max_length + ): + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = ( + "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + ) + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. 
determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None + or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError( + "`num_beam_groups` has to be smaller or equal to `num_beams`" + ) + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. 
run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. 
interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams + * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError( + "`num_beams` should be divisible by `num_beam_groups` for group beam search." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + has_default_typical_p = ( + kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + ) + if not has_default_typical_p: + raise ValueError( + "Decoder argument `typical_p` is not supported with beam groups." + ) + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. 
run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError( + "`num_return_sequences` has to be smaller or equal to `num_beams`." + ) + + if stopping_criteria.max_length is None: + raise ValueError( + "`max_length` needs to be a stopping_criteria for now." + ) + + if generation_config.num_beams <= 1: + raise ValueError( + "`num_beams` needs to be greater than 1 for constrained generation." + ) + + if generation_config.do_sample: + raise ValueError( + "`do_sample` needs to be false for constrained generation." + ) + + if ( + generation_config.num_beam_groups is not None + and generation_config.num_beam_groups > 1 + ): + raise ValueError( + "`num_beam_groups` not supported yet for constrained generation." + ) + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + not isinstance(token_ids, list) for token_ids in word_ids + ): + typeerror() + if any( + any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in token_ids + ) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any( + (not isinstance(token_id, int) or token_id < 0) + for token_id in word_ids + ): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. 
run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... 
) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + logits_warper = ( + logits_warper if logits_warper is not None else LogitsProcessorList() + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + eos_token_id = ( + eos_token_id + if eos_token_id is not None + else self.generation_config.eos_token_id + ) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul( + (sum(next_tokens != i for i in eos_token_id)).long() + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import PreTrainedModel + from transformers import AutoTokenizer, AutoModelForCausalLM + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained( + "bigscience/bloom-560m", torch_dtype=torch.float16 + ) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? 
\n" + input_ids = tokenizer( + prompt_text, return_tensors="pt", add_special_tokens=False + ).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 0fad8133..a2795289 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -171,17 +171,6 @@ def multilingual_cleaners(text, lang): return text -def english_cleaners(text): - """Pipeline for English text, including number and abbreviation expansion.""" - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_numbers(text) - text = expand_abbreviations(text) - text = collapse_whitespace(text) - text = text.replace('"', "") - return text - - def remove_extraneous_punctuation(word): replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "—": "-", "`": "'", "ʼ": "'"} replace = re.compile( @@ -195,32 +184,6 @@ def remove_extraneous_punctuation(word): return word -def expand_numbers(text): - return normalize_numbers(text) - - -def lowercase(text): - return text.lower() - - -_whitespace_re = re.compile(r"\s+") - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, " ", text) - - -def convert_to_ascii(text): - return unidecode(text) - - -def basic_cleaners(text): - """Basic pipeline that lowercases and collapses whitespace without transliteration.""" - text = lowercase(text) - text = collapse_whitespace(text) - return text - - def arabic_cleaners(text): text = lowercase(text) text = collapse_whitespace(text) @@ -261,7 +224,10 @@ class VoiceBpeTokenizer: txt = " ".join([result["kana"] for result in results]) txt = basic_cleaners(txt) elif lang == "en": + if txt[:4] == "[en]": + txt = txt[4:] txt = english_cleaners(txt) + txt = "[en]" + txt elif lang == "ar": txt = arabic_cleaners(txt) elif lang == "zh-cn": diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index c0a00c66..b1cf886b 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -726,8 +726,8 @@ class DelightfulTTS(BaseTTSE2E): def pitch_std(self): return self.acoustic_model.pitch_std - @pitch_mean.setter - def pitch_std(self, value): # pylint: disable=function-redefined + @pitch_std.setter + def pitch_std(self, value): self.acoustic_model.pitch_std = value @property @@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E): scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) return [scheduler_D, scheduler_G] - def on_train_step_start(self, trainer): - """Schedule binary loss weight.""" - self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 - def on_epoch_end(self, trainer): # pylint: disable=unused-argument # stop updating mean and var # TODO: do the same for F0 @@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E): Args: trainer (Trainer): 
Trainer object. """ + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 self.train_disc = ( # pylint: disable=attribute-defined-outside-init trainer.total_steps_done >= self.config.steps_to_start_discriminator ) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 0836870e..2b480744 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.vocoder import UnivNetGenerator +from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder +from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec +init_stream_support() def load_audio(audiopath, sr=22050): """ @@ -195,13 +198,12 @@ class XttsArgs(Coqpit): Args: gpt_batch_size (int): The size of the auto-regressive batch. enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. - lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False. kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. - vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True. For GPT model: ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. 
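
> The new `use_hifigan` flag selects between the HiFi-GAN decoder and the diffusion + UnivNet pair wired up in `init_models`. A minimal sketch of toggling it, assuming the dataclass fields above are exposed through `config.model_args` and using placeholder paths:

```python
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")

# Assumption: the XttsArgs fields added above are reachable via `model_args`.
config.model_args.use_hifigan = False  # fall back to the diffusion + UnivNet decoder
config.model_args.output_sample_rate = 24000

model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
```
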
@@ -231,12 +233,12 @@ class XttsArgs(Coqpit): gpt_batch_size: int = 1 enable_redaction: bool = False - lazy_load: bool = True kv_cache: bool = True gpt_checkpoint: str = None clvp_checkpoint: str = None decoder_checkpoint: str = None num_chars: int = 255 + use_hifigan: bool = True # XTTS GPT Encoder params tokenizer_file: str = "" @@ -266,6 +268,15 @@ class XttsArgs(Coqpit): diff_layer_drop: int = 0 diff_unconditioned_percentage: int = 0 + # HifiGAN Decoder params + input_sample_rate: int = 22050 + output_sample_rate: int = 24000 + output_hop_length: int = 256 + ar_mel_length_compression: int = 1024 + decoder_input_dim: int = 1024 + d_vector_dim: int = 512 + cond_d_vector_in_each_upsampling_layer: bool = True + # constants duration_const: int = 102400 @@ -285,7 +296,6 @@ class Xtts(BaseTTS): def __init__(self, config: Coqpit): super().__init__(config, ap=None, tokenizer=None) - self.lazy_load = self.args.lazy_load self.mel_stats_path = None self.config = config self.gpt_checkpoint = self.args.gpt_checkpoint @@ -295,7 +305,6 @@ class Xtts(BaseTTS): self.tokenizer = VoiceBpeTokenizer() self.gpt = None - self.diffusion_decoder = None self.init_models() self.register_buffer("mel_stats", torch.ones(80)) @@ -322,40 +331,39 @@ class Xtts(BaseTTS): stop_audio_token=self.args.gpt_stop_audio_token, ) - self.diffusion_decoder = DiffusionTts( - model_channels=self.args.diff_model_channels, - num_layers=self.args.diff_num_layers, - in_channels=self.args.diff_in_channels, - out_channels=self.args.diff_out_channels, - in_latent_channels=self.args.diff_in_latent_channels, - in_tokens=self.args.diff_in_tokens, - dropout=self.args.diff_dropout, - use_fp16=self.args.diff_use_fp16, - num_heads=self.args.diff_num_heads, - layer_drop=self.args.diff_layer_drop, - unconditioned_percentage=self.args.diff_unconditioned_percentage, - ) - self.vocoder = UnivNetGenerator() + if self.args.use_hifigan: + self.hifigan_decoder = HifiDecoder( + input_sample_rate=self.args.input_sample_rate, + output_sample_rate=self.args.output_sample_rate, + output_hop_length=self.args.output_hop_length, + ar_mel_length_compression=self.args.ar_mel_length_compression, + decoder_input_dim=self.args.decoder_input_dim, + d_vector_dim=self.args.d_vector_dim, + cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, + ) + + else: + self.diffusion_decoder = DiffusionTts( + model_channels=self.args.diff_model_channels, + num_layers=self.args.diff_num_layers, + in_channels=self.args.diff_in_channels, + out_channels=self.args.diff_out_channels, + in_latent_channels=self.args.diff_in_latent_channels, + in_tokens=self.args.diff_in_tokens, + dropout=self.args.diff_dropout, + use_fp16=self.args.diff_use_fp16, + num_heads=self.args.diff_num_heads, + layer_drop=self.args.diff_layer_drop, + unconditioned_percentage=self.args.diff_unconditioned_percentage, + ) + self.vocoder = UnivNetGenerator() @property def device(self): return next(self.parameters()).device - @contextmanager - def lazy_load_model(self, model): - """Context to load a model on demand. - - Args: - model (nn.Module): The model to be loaded. - """ - if self.lazy_load: - yield model - else: - m = model.to(self.device) - yield m - m = model.cpu() - + @torch.inference_mode() def get_gpt_cond_latents(self, audio_path: str, length: int = 3): """Compute the conditioning latents for the GPT model from the given audio. 
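
> The hunks above replace the `lazy_load_model` context manager with `torch.inference_mode()` decorators on the conditioning helpers. A small standalone illustration of what that decorator does (generic PyTorch behaviour, not code from this diff):

```python
import torch

@torch.inference_mode()
def get_embedding(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded, and the result is an "inference tensor",
    # which also skips version-counter bookkeeping for a small speed/memory win.
    return x @ x.T

emb = get_embedding(torch.randn(2, 3))
print(emb.requires_grad)        # False
print(torch.is_inference(emb))  # True
```
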
@@ -370,6 +378,7 @@ class Xtts(BaseTTS): cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False) return cond_latent.transpose(1, 2) + @torch.inference_mode() def get_diffusion_cond_latents( self, audio_path, @@ -389,20 +398,33 @@ class Xtts(BaseTTS): ) diffusion_conds.append(cond_mel) diffusion_conds = torch.stack(diffusion_conds, dim=1) - with self.lazy_load_model(self.diffusion_decoder) as diffusion: - diffusion_latent = diffusion.get_conditioning(diffusion_conds) + diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds) return diffusion_latent + @torch.inference_mode() + def get_speaker_embedding( + self, + audio_path + ): + audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"]) + speaker_embedding = self.hifigan_decoder.speaker_encoder.forward( + audio.to(self.device), l2_norm=True + ).unsqueeze(-1).to(self.device) + return speaker_embedding + def get_conditioning_latents( self, audio_path, gpt_cond_len=3, - ): + ): + speaker_embedding = None + diffusion_cond_latents = None + if self.args.use_hifigan: + speaker_embedding = self.get_speaker_embedding(audio_path) + else: + diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path) gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T] - diffusion_cond_latents = self.get_diffusion_cond_latents( - audio_path, - ) - return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device) + return gpt_cond_latents, diffusion_cond_latents, speaker_embedding def synthesize(self, text, config, speaker_wav, language, **kwargs): """Synthesize speech with the given input text. @@ -447,10 +469,10 @@ class Xtts(BaseTTS): "decoder_sampler": config.decoder_sampler, } settings.update(kwargs) # allow overriding of preset settings with kwargs - return self.inference(text, ref_audio_path, language, **settings) + return self.full_inference(text, ref_audio_path, language, **settings) - @torch.no_grad() - def inference( + @torch.inference_mode() + def full_inference( self, text, ref_audio_path, @@ -525,6 +547,54 @@ class Xtts(BaseTTS): Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. 
""" + ( + gpt_cond_latent, + diffusion_conditioning, + speaker_embedding + ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len) + return self.inference( + text, + language, + gpt_cond_latent, + speaker_embedding, + diffusion_conditioning, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + decoder_iterations=decoder_iterations, + cond_free=cond_free, + cond_free_k=cond_free_k, + diffusion_temperature=diffusion_temperature, + decoder_sampler=decoder_sampler, + **hf_generate_kwargs, + ) + + @torch.inference_mode() + def inference( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + diffusion_conditioning, + # GPT inference + temperature=0.65, + length_penalty=1, + repetition_penalty=2.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Decoder inference + decoder_iterations=100, + cond_free=True, + cond_free_k=2, + diffusion_temperature=1.0, + decoder_sampler="ddim", + **hf_generate_kwargs, + ): text = f"[{language}]{text.strip().lower()}" text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) @@ -532,74 +602,147 @@ class Xtts(BaseTTS): text_tokens.shape[-1] < self.args.gpt_max_text_tokens ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - ( - gpt_cond_latent, - diffusion_conditioning, - ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len) - - diffuser = load_discrete_vocoder_diffuser( - desired_diffusion_steps=decoder_iterations, - cond_free=cond_free, - cond_free_k=cond_free_k, - sampler=decoder_sampler, - ) + if not self.args.use_hifigan: + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=decoder_iterations, + cond_free=cond_free, + cond_free_k=cond_free_k, + sampler=decoder_sampler, + ) with torch.no_grad(): - self.gpt = self.gpt.to(self.device) - with self.lazy_load_model(self.gpt) as gpt: - gpt_codes = gpt.generate( - cond_latents=gpt_cond_latent, - text_inputs=text_tokens, - input_tokens=None, - do_sample=do_sample, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=self.gpt_batch_size, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - output_attentions=False, - **hf_generate_kwargs, - ) - - with self.lazy_load_model(self.gpt) as gpt: - expected_output_len = torch.tensor( - [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device - ) - text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) - gpt_latents = gpt( - text_tokens, - text_len, - gpt_codes, - expected_output_len, - cond_latents=gpt_cond_latent, - return_attentions=False, - return_latent=True, - ) - silence_token = 83 - ctokens = 0 - for k in range(gpt_codes.shape[-1]): - if gpt_codes[0, k] == silence_token: - ctokens += 1 - else: - ctokens = 0 - if ctokens > 8: - gpt_latents = gpt_latents[:, :k] - break - - with self.lazy_load_model(self.diffusion_decoder) as diffusion: + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) + text_len = 
torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) + silence_token = 83 + ctokens = 0 + for k in range(gpt_codes.shape[-1]): + if gpt_codes[0, k] == silence_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: + gpt_latents = gpt_latents[:, :k] + break + + if self.args.use_hifigan: + wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + else: mel = do_spectrogram_diffusion( - diffusion, + self.diffusion_decoder, diffuser, gpt_latents, diffusion_conditioning, temperature=diffusion_temperature, ) - with self.lazy_load_model(self.vocoder) as vocoder: - wav = vocoder.inference(mel) + wav = self.vocoder.inference(mel) return {"wav": wav.cpu().numpy().squeeze()} + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + @torch.inference_mode() + def inference_stream( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # Streaming + stream_chunk_size=20, + overlap_wav_len=1024, + # GPT inference + temperature=0.65, + length_penalty=1, + repetition_penalty=2.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Decoder inference + **hf_generate_kwargs, + ): + assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." 
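
> For context on the `handle_chunks` helper defined above: consecutive streamed chunks share `overlap_len` samples, and that overlap is blended with complementary linear ramps so the audio hands off smoothly. A toy sketch with made-up values:

```python
import torch

overlap_len = 4
wav_overlap = torch.ones(overlap_len)            # tail of the previous chunk
new_head = torch.full((overlap_len,), 0.5)       # head of the freshly decoded chunk

fade_in = torch.linspace(0.0, 1.0, overlap_len)   # applied to the new audio
fade_out = torch.linspace(1.0, 0.0, overlap_len)  # applied to the previous tail
blended = new_head * fade_in + wav_overlap * fade_out
print(blended)  # tensor([1.0000, 0.8333, 0.6667, 0.5000]) -> smooth hand-off
```
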
+ text = f"[{language}]{text.strip().lower()}" + text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) + + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk + def forward(self): raise NotImplementedError("XTTS Training is not implemented") @@ -616,7 +759,14 @@ class Xtts(BaseTTS): super().eval() def load_checkpoint( - self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True + self, + config, + checkpoint_dir=None, + checkpoint_path=None, + vocab_path=None, + eval=True, + strict=True, + use_deepspeed=False, ): """ Loads a checkpoint from disk and initializes the model's state and tokenizer. @@ -626,7 +776,7 @@ class Xtts(BaseTTS): checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. vocab_path (str, optional): The path to the vocabulary file. Defaults to None. - eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False. + eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. 
Returns: @@ -636,19 +786,26 @@ class Xtts(BaseTTS): model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") - if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")): - self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json")) + if os.path.exists(vocab_path): + self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) self.init_models() if eval: self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) - self.load_state_dict(load_fsspec(model_path)["model"], strict=strict) + + checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] + ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"] + for key in list(checkpoint.keys()): + if key.split(".")[0] in ignore_keys: + del checkpoint[key] + self.load_state_dict(checkpoint, strict=strict) if eval: - self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval() + if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval() + if hasattr(self, "vocoder"): self.vocoder.eval() + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) self.gpt.eval() - self.diffusion_decoder.eval() - self.vocoder.eval() def train_step(self): raise NotImplementedError("XTTS Training is not implemented") diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 6e082297..955eeb9b 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -6,6 +6,7 @@ from pathlib import Path from shutil import copyfile, rmtree from typing import Dict, List, Tuple +import fsspec import requests from tqdm import tqdm @@ -315,11 +316,51 @@ class ModelManager(object): """Check if the user has agreed to the terms of service""" if "tos_required" in model_item and model_item["tos_required"]: tos_path = os.path.join(model_full_path, "tos_agreed.txt") - if os.path.exists(tos_path): + if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1": return True return False return True + def create_dir_and_download_model(self, model_name, model_item, output_path): + os.makedirs(output_path, exist_ok=True) + # handle TOS + if not self.tos_agreed(model_item, output_path): + if not self.ask_tos(output_path): + os.rmdir(output_path) + raise Exception(" [!] You must agree to the terms of service to use this model.") + print(f" > Downloading model to {output_path}") + try: + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) + elif "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + + except requests.RequestException as e: + print(f" > Failed to download the model file to {output_path}") + rmtree(output_path) + raise e + self.print_model_license(model_item=model_item) + + def check_if_configs_are_equal(self, model_name, model_item, output_path): + with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: + config_local = json.load(f) + remote_url = None + for url in model_item["hf_url"]: + if "config.json" in url: + remote_url = url + break + + with fsspec.open(remote_url, "r", encoding="utf-8") as f: + config_remote = json.load(f) + + if not config_local == config_remote: + print(f" > {model_name} is already downloaded however it has been changed. 
Redownloading it...") + self.create_dir_and_download_model(model_name, model_item, output_path) + else: + print(f" > {model_name} is already downloaded.") + def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -338,28 +379,18 @@ class ModelManager(object): # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): - print(f" > {model_name} is already downloaded.") + # if the configs are different, redownload it + # ToDo: we need a better way to handle it + if "xtts_v1" in model_name: + try: + self.check_if_configs_are_equal(model_name, model_item, output_path) + except: + pass + else: + print(f" > {model_name} is already downloaded.") else: - os.makedirs(output_path, exist_ok=True) - # handle TOS - if not self.tos_agreed(model_item, output_path): - if not self.ask_tos(output_path): - os.rmdir(output_path) - raise Exception(" [!] You must agree to the terms of service to use this model.") - print(f" > Downloading model to {output_path}") - try: - if "fairseq" in model_name: - self.download_fairseq_model(model_name, output_path) - elif "github_rls_url" in model_item: - self._download_github_model(model_item, output_path) - elif "hf_url" in model_item: - self._download_hf_model(model_item, output_path) + self.create_dir_and_download_model(model_name, model_item, output_path) - except requests.RequestException as e: - print(f" > Failed to download the model file to {output_path}") - rmtree(output_path) - raise e - self.print_model_license(model_item=model_item) # find downloaded files output_model_path = output_path output_config_path = None @@ -498,13 +529,16 @@ class ModelManager(object): print(f" > Error: Bad zip file - {file_url}") raise zipfile.BadZipFile # pylint: disable=raise-missing-from # move the files to the outer path - for file_path in z.namelist()[1:]: + for file_path in z.namelist(): src_path = os.path.join(output_folder, file_path) - dst_path = os.path.join(output_folder, os.path.basename(file_path)) - if src_path != dst_path: - copyfile(src_path, dst_path) - # remove the extracted folder - rmtree(os.path.join(output_folder, z.namelist()[0])) + if os.path.isfile(src_path): + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove redundant (hidden or not) folders + for file_path in z.namelist(): + if os.path.isdir(os.path.join(output_folder, file_path)): + rmtree(os.path.join(output_folder, file_path)) @staticmethod def _download_tar_file(file_url, output_folder, progress_bar): diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index 5684a88e..e799cc2a 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -116,12 +116,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): return acts -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - def shift_1d(x): x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] return x diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md index 85a3afba..ff6bcf97 100644 --- a/docs/source/models/xtts.md +++ b/docs/source/models/xtts.md @@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml) Come and join in our 🐸Community. 
We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai). You can also mail us at info@coqui.ai. -Using 🐸TTS API: +### Inference +#### 🐸TTS API ```python from TTS.api import TTS @@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t file_path="output.wav", speaker_wav="/path/to/target/speaker.wav", language="en") - -# generate speech by cloning a voice using custom settings -tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - file_path="output.wav", - speaker_wav="/path/to/target/speaker.wav", - language="en", - decoder_iterations=30) ``` -Using 🐸TTS Command line: +#### 🐸TTS Command line ```console tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \ @@ -58,25 +52,85 @@ Using 🐸TTS Command line: --use_cuda true ``` -Using model directly: +#### model directly + +If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first. + +```console +pip install deepspeed==0.8.3 +``` ```python +import os +import torch +import torchaudio from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts +print("Loading model...") config = XttsConfig() config.load_json("/path/to/xtts/config.json") model = Xtts.init_from_config(config) -model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True) +model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True) +model.cuda() + +print("Computing speaker latents...") +gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav") + +print("Inference...") +out = model.inference( + "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", + "en", + gpt_cond_latent, + speaker_embedding, + diffusion_conditioning, + temperature=0.7, # Add custom parameters here +) +torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000) +``` + + +#### streaming inference + +Here the goal is to stream the audio as it is being generated. This is useful for real-time applications. +Streaming inference is typically slower than regular inference, but it allows to get a first chunk of audio faster. 
+ + +```python +import os +import time +import torch +import torchaudio +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + +print("Loading model...") +config = XttsConfig() +config.load_json("/path/to/xtts/config.json") +model = Xtts.init_from_config(config) +model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True) model.cuda() -outputs = model.synthesize( +print("Computing speaker latents...") +gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav") + +print("Inference...") +t0 = time.time() +chunks = model.inference_stream( "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", - config, - speaker_wav="/data/TTS-public/_refclips/3.wav", - gpt_cond_len=3, - language="en", + "en", + gpt_cond_latent, + speaker_embedding ) + +wav_chuncks = [] +for i, chunk in enumerate(chunks): + if i == 0: + print(f"Time to first chunck: {time.time() - t0}") + print(f"Received chunk {i} of audio length {chunk.shape[-1]}") + wav_chuncks.append(chunk) +wav = torch.cat(wav_chuncks, dim=0) +torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000) ``` diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index a257b6bf..9acc9929 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -13,15 +13,15 @@ "metadata": {}, "outputs": [], "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", "import os\n", "import sys\n", "import torch\n", "import importlib\n", "import numpy as np\n", - "from tqdm import tqdm as tqdm\n", + "from tqdm import tqdm\n", "from torch.utils.data import DataLoader\n", + "import soundfile as sf\n", + "import pickle\n", "from TTS.tts.datasets.dataset import TTSDataset\n", "from TTS.tts.layers.losses import L1LossMasked\n", "from TTS.utils.audio import AudioProcessor\n", @@ -33,8 +33,8 @@ "\n", "%matplotlib inline\n", "\n", - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICES']='2'" + "# Configure CUDA visibility\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '2'" ] }, { @@ -43,6 +43,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Function to create directories and file names\n", "def set_filename(wav_path, out_path):\n", " wav_file = os.path.basename(wav_path)\n", " file_name = wav_file.split('.')[0]\n", @@ -61,6 +62,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Paths and configurations\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n", "DATASET = \"ljspeech\"\n", @@ -73,12 +75,15 @@ "QUANTIZE_BIT = None\n", "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n", "\n", + "# Check CUDA availability\n", "use_cuda = torch.cuda.is_available()\n", "print(\" > CUDA enabled: \", use_cuda)\n", "\n", + "# Load the configuration\n", "C = load_config(CONFIG_PATH)\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! 
disable to align mel specs with the wav files\n", - "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)" + "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n", + "print(C['r'])" ] }, { @@ -87,14 +92,13 @@ "metadata": {}, "outputs": [], "source": [ - "print(C['r'])\n", - "# if the vocabulary was passed, replace the default\n", + "# If the vocabulary was passed, replace the default\n", "if 'characters' in C and C['characters']:\n", " symbols, phonemes = make_symbols(**C.characters)\n", "\n", - "# load the model\n", + "# Load the model\n", "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", - "# TODO: multiple speaker\n", + "# TODO: multiple speakers\n", "model = setup_model(C)\n", "model.load_checkpoint(C, MODEL_FILE, eval=True)" ] @@ -105,11 +109,12 @@ "metadata": {}, "outputs": [], "source": [ + "# Load the preprocessor based on the dataset\n", "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n", "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n", "dataset = TTSDataset(\n", - " checkpoint[\"config\"][\"r\"],\n", + " C,\n", " C.text_cleaner,\n", " False,\n", " ap,\n", @@ -124,6 +129,24 @@ ")\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize lists for storing results\n", + "file_idxs = []\n", + "metadata = []\n", + "losses = []\n", + "postnet_losses = []\n", + "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", + "\n", + "# Create log file\n", + "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", + "log_file = open(log_file_path, \"w\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -137,83 +160,85 @@ "metadata": {}, "outputs": [], "source": [ - "import pickle\n", - "\n", - "file_idxs = []\n", - "metadata = []\n", - "losses = []\n", - "postnet_losses = []\n", - "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", + "# Start processing with a progress bar\n", "with torch.no_grad():\n", - " for data in tqdm(loader):\n", - " # setup input data\n", - " text_input = data[0]\n", - " text_lengths = data[1]\n", - " linear_input = data[3]\n", - " mel_input = data[4]\n", - " mel_lengths = data[5]\n", - " stop_targets = data[6]\n", - " item_idx = data[7]\n", + " for data in tqdm(loader, desc=\"Processing\"):\n", + " try:\n", + " # setup input data\n", + " text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n", "\n", - " # dispatch data to GPU\n", - " if use_cuda:\n", - " text_input = text_input.cuda()\n", - " text_lengths = text_lengths.cuda()\n", - " mel_input = mel_input.cuda()\n", - " mel_lengths = mel_lengths.cuda()\n", + " # dispatch data to GPU\n", + " if use_cuda:\n", + " text_input = text_input.cuda()\n", + " text_lengths = text_lengths.cuda()\n", + " mel_input = mel_input.cuda()\n", + " mel_lengths = mel_lengths.cuda()\n", "\n", - " mask = sequence_mask(text_lengths)\n", - " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", - " \n", - " # compute loss\n", - " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", - " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", - " losses.append(loss.item())\n", - " postnet_losses.append(loss_postnet.item())\n", + " mask = sequence_mask(text_lengths)\n", + " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", "\n", - " # compute mel specs from linear spec if model is 
Tacotron\n", - " if C.model == \"Tacotron\":\n", - " mel_specs = []\n", - " postnet_outputs = postnet_outputs.data.cpu().numpy()\n", - " for b in range(postnet_outputs.shape[0]):\n", - " postnet_output = postnet_outputs[b]\n", - " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n", - " postnet_outputs = torch.stack(mel_specs)\n", - " elif C.model == \"Tacotron2\":\n", - " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", - " alignments = alignments.detach().cpu().numpy()\n", + " # compute loss\n", + " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", + " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", + " losses.append(loss.item())\n", + " postnet_losses.append(loss_postnet.item())\n", "\n", - " if not DRY_RUN:\n", - " for idx in range(text_input.shape[0]):\n", - " wav_file_path = item_idx[idx]\n", - " wav = ap.load_wav(wav_file_path)\n", - " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", - " file_idxs.append(file_name)\n", + " # compute mel specs from linear spec if the model is Tacotron\n", + " if C.model == \"Tacotron\":\n", + " mel_specs = []\n", + " postnet_outputs = postnet_outputs.data.cpu().numpy()\n", + " for b in range(postnet_outputs.shape[0]):\n", + " postnet_output = postnet_outputs[b]\n", + " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n", + " postnet_outputs = torch.stack(mel_specs)\n", + " elif C.model == \"Tacotron2\":\n", + " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", + " alignments = alignments.detach().cpu().numpy()\n", "\n", - " # quantize and save wav\n", - " if QUANTIZED_WAV:\n", - " wavq = ap.quantize(wav)\n", - " np.save(wavq_path, wavq)\n", + " if not DRY_RUN:\n", + " for idx in range(text_input.shape[0]):\n", + " wav_file_path = item_idx[idx]\n", + " wav = ap.load_wav(wav_file_path)\n", + " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", + " file_idxs.append(file_name)\n", "\n", - " # save TTS mel\n", - " mel = postnet_outputs[idx]\n", - " mel_length = mel_lengths[idx]\n", - " mel = mel[:mel_length, :].T\n", - " np.save(mel_path, mel)\n", + " # quantize and save wav\n", + " if QUANTIZED_WAV:\n", + " wavq = ap.quantize(wav)\n", + " np.save(wavq_path, wavq)\n", "\n", - " metadata.append([wav_file_path, mel_path])\n", + " # save TTS mel\n", + " mel = postnet_outputs[idx]\n", + " mel_length = mel_lengths[idx]\n", + " mel = mel[:mel_length, :].T\n", + " np.save(mel_path, mel)\n", "\n", - " # for wavernn\n", - " if not DRY_RUN:\n", - " pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n", - " \n", - " # for pwgan\n", - " with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", + " metadata.append([wav_file_path, mel_path])\n", "\n", - " print(np.mean(losses))\n", - " print(np.mean(postnet_losses))" + " except Exception as e:\n", + " log_file.write(f\"Error processing data: {str(e)}\\n\")\n", + "\n", + " # Calculate and log mean losses\n", + " mean_loss = np.mean(losses)\n", + " mean_postnet_loss = np.mean(postnet_losses)\n", + " log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n", + " log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n", + "\n", + "# Close the log file\n", + "log_file.close()\n", + "\n", + "# For wavernn\n", + "if not DRY_RUN:\n", + " pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n", + 
"\n", + "# For pwgan\n", + "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", + " for data in metadata:\n", + " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", + "\n", + "# Print mean losses\n", + "print(f\"Mean Loss: {mean_loss}\")\n", + "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 8544bb20..92257530 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,11 @@ [build-system] -requires = ["setuptools", "wheel", "cython==0.29.30", "numpy==1.22.0", "packaging"] +requires = [ + "setuptools", + "wheel", + "cython~=0.29.30", + "numpy>=1.22.0", + "packaging", +] [flake8] max-line-length=120 @@ -7,25 +13,6 @@ max-line-length=120 [tool.black] line-length = 120 target-version = ['py39'] -exclude = ''' - -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | \.venv - | _build - | buck-out - | build - | dist - )/ - | foo.py # also separately exclude a file named foo.py in - # the root of the project -) -''' [tool.isort] line_length = 120 diff --git a/requirements.txt b/requirements.txt index ae22b333..2837c36e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,26 +5,27 @@ cython==0.29.30 scipy>=1.11.2 torch>=1.7 torchaudio -soundfile -librosa==0.10.0.* +soundfile==0.12.* +librosa==0.10.* +scikit-learn==1.3.0 numba==0.55.1;python_version<"3.9" numba==0.57.0;python_version>="3.9" -inflect==5.6.0 -tqdm -anyascii -pyyaml -fsspec>=2021.04.0 -aiohttp -packaging +inflect==5.6.* +tqdm==4.64.* +anyascii==0.3.* +pyyaml==6.* +fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail +aiohttp==3.8.* +packaging==23.1 # deps for examples -flask +flask==2.* # deps for inference -pysbd +pysbd==0.3.4 # deps for notebooks -umap-learn==0.5.1 -pandas +umap-learn==0.5.* +pandas>=1.4,<2.0 # deps for training -matplotlib +matplotlib==3.7.* # coqui stack trainer # config management @@ -39,14 +40,14 @@ jamo nltk g2pkk>=0.1.1 # deps for bangla -bangla==0.0.2 +bangla bnnumerizer -bnunicodenormalizer==0.1.1 +bnunicodenormalizer #deps for tortoise k_diffusion -einops -transformers +einops==0.6.* +transformers==4.33.* #deps for bark -encodec +encodec==0.1.* # deps for XTTS -unidecode +unidecode==1.3.* diff --git a/scripts/sync_readme.py b/scripts/sync_readme.py new file mode 100644 index 00000000..58428681 --- /dev/null +++ b/scripts/sync_readme.py @@ -0,0 +1,32 @@ +import argparse +from pathlib import Path + + +def replace_between_markers(content, marker: str, replacement: str) -> str: + start_marker = f"\n\n" + end_marker = f"\n\n\n" + start_index = content.index(start_marker) + len(start_marker) + end_index = content.index(end_marker) + content = content[:start_index] + replacement + content[end_index:] + return content + + +def sync_readme(): + ap = argparse.ArgumentParser() + ap.add_argument("--check", action="store_true", default=False) + args = ap.parse_args() + readme_path = Path(__file__).parent.parent / "README.md" + orig_content = readme_path.read_text() + from TTS.bin.synthesize import description + + new_content = replace_between_markers(orig_content, "tts-readme", description.strip()) + if args.check: + if orig_content != new_content: + print("README.md is out of sync; please edit TTS/bin/TTS_README.md and run scripts/sync_readme.py") + exit(42) + readme_path.write_text(new_content) + print("Updated README.md") + + +if __name__ == "__main__": + sync_readme() diff --git a/tests/aux_tests/test_readme.py b/tests/aux_tests/test_readme.py new file mode 
100644 index 00000000..32b26fc6 --- /dev/null +++ b/tests/aux_tests/test_readme.py @@ -0,0 +1,9 @@ +import subprocess +import sys +from pathlib import Path + + +def test_readme_up_to_date(): + root = Path(__file__).parent.parent.parent + sync_readme = root / "scripts" / "sync_readme.py" + subprocess.check_call([sys.executable, str(sync_readme), "--check"], cwd=root) diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index b7945d6e..dc16d793 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -3,13 +3,19 @@ import glob import os import shutil +import torch + from tests import get_tests_data_path, get_tests_output_path, run_cli from TTS.tts.utils.languages import LanguageManager from TTS.tts.utils.speakers import SpeakerManager from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.manage import ModelManager -MODELS_WITH_SEP_TESTS = ["bark", "xtts"] +MODELS_WITH_SEP_TESTS = [ + "tts_models/multilingual/multi-dataset/bark", + "tts_models/en/multi-dataset/tortoise-v2", + "tts_models/multilingual/multi-dataset/xtts_v1", +] def run_models(offset=0, step=1): @@ -17,7 +23,8 @@ def run_models(offset=0, step=1): print(" > Run synthesizer with all the models.") output_path = os.path.join(get_tests_output_path(), "output.wav") manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False) - model_names = [name for name in manager.list_models() if name in MODELS_WITH_SEP_TESTS] + model_names = [name for name in manager.list_models() if name not in MODELS_WITH_SEP_TESTS] + print("Model names:", model_names) for model_name in model_names[offset::step]: print(f"\n > Run - {model_name}") model_path, _, _ = manager.download_model(model_name) @@ -67,23 +74,83 @@ def run_models(offset=0, step=1): def test_xtts(): + """XTTS is too big to run on github actions. We need to test it locally""" output_path = os.path.join(get_tests_output_path(), "output.wav") speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") - run_cli( - "yes | " - f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 " - f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True ' - f'--speaker_wav "{speaker_wav}" --language_idx "en"' + use_gpu = torch.cuda.is_available() + if use_gpu: + run_cli( + "yes | " + f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True ' + f'--speaker_wav "{speaker_wav}" --language_idx "en"' + ) + else: + run_cli( + "yes | " + f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 " + f'--text "This is an example." 
--out_path "{output_path}" --progress_bar False ' + f'--speaker_wav "{speaker_wav}" --language_idx "en"' + ) + +def test_xtts_streaming(): + """Testing the new inference_stream method""" + from TTS.tts.configs.xtts_config import XttsConfig + from TTS.tts.models.xtts import Xtts + speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") + model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") + config = XttsConfig() + config.load_json(os.path.join(model_path, "config.json")) + model = Xtts.init_from_config(config) + model.load_checkpoint(config, checkpoint_dir=model_path) + model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) + + print("Computing speaker latents...") + gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav) + + print("Inference...") + chunks = model.inference_stream( + "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", + "en", + gpt_cond_latent, + speaker_embedding ) + wav_chuncks = [] + for i, chunk in enumerate(chunks): + if i == 0: + assert chunk.shape[-1] > 5000 + wav_chuncks.append(chunk) + assert len(wav_chuncks) > 1 + +def test_tortoise(): + output_path = os.path.join(get_tests_output_path(), "output.wav") + use_gpu = torch.cuda.is_available() + if use_gpu: + run_cli( + f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True' + ) + else: + run_cli( + f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False' + ) def test_bark(): """Bark is too big to run on github actions. We need to test it locally""" output_path = os.path.join(get_tests_output_path(), "output.wav") - run_cli( - f" tts --model_name tts_models/multilingual/multi-dataset/bark " - f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True' - ) + use_gpu = torch.cuda.is_available() + if use_gpu: + run_cli( + f" tts --model_name tts_models/multilingual/multi-dataset/bark " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True' + ) + else: + run_cli( + f" tts --model_name tts_models/multilingual/multi-dataset/bark " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False' + ) def test_voice_conversion():