Merge pull request #3057 from coqui-ai/dev

v0.17.8
Eren Gölge 2023-10-11 15:06:20 +02:00 committed by GitHub
commit df2422eb72
26 changed files with 2765 additions and 2251 deletions


@@ -0,0 +1,52 @@
name: zoo-tests-tortoise
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.9, "3.10", "3.11"]
experimental: [false]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: set ENV
run: export TRAINER_TELEMETRY=0
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y git make gcc
sudo apt-get install espeak espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Replace scarf urls
run: |
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_tortoise
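For a quick local check of the same path this workflow exercises, the Tortoise zoo model can be loaded through the Python API. A minimal sketch, assuming 🐸TTS is installed and that `tts_models/en/multi-dataset/tortoise-v2` is the registered model name (the text and output path are illustrative; the CI job runs the nose2 suite instead):
```
# Local smoke test mirroring the Tortoise zoo test run by the workflow above.
from TTS.api import TTS

tts = TTS("tts_models/en/multi-dataset/tortoise-v2")
tts.tts_to_file(text="This is a Tortoise smoke test.", file_path="tortoise_smoke_test.wav")
```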

README.md

@@ -1,6 +1,7 @@
## 🐸Coqui.ai News
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with uncontrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://tts.readthedocs.io/en/dev/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://tts.readthedocs.io/en/dev/models/bark.html)
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 🐸TTS now supports 🐢Tortoise with faster inference. [Docs](https://tts.readthedocs.io/en/dev/models/tortoise.html)
- 📣 **Coqui Studio API** has landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
@@ -111,7 +112,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
- Delightful TTS: [paper](https://arxiv.org/abs/2110.12612)
### End-to-End Models
- ⓍTTS: [blog]()
- ⓍTTS: [blog](https://coqui.ai/blog/tts/open_xtts)
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
- 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418)
- 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts)
@@ -293,99 +294,123 @@ api.tts_with_vc_to_file(
```
### Command-line `tts`
<!-- begin-tts-readme -->
Synthesize speech on command line.
You can either use your trained model or choose a model from the provided list.
If you don't specify any models, then it uses LJSpeech based English model.
#### Single Speaker Models
- List provided models:
```
$ tts --list_models
```
- Get model info (for both tts_models and vocoder_models):
- Query by type/name:
The model_info_by_name uses the name as it appears in the --list_models output.
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
For example:
```
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
```
- Query by type/idx:
The model_query_idx uses the corresponding idx from --list_models.
```
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
```
For example:
```
$ tts --model_info_by_idx tts_models/3
```
- Query info for model info by full name:
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
- Run TTS with default models:
```
$ tts --text "Text for TTS" --out_path output/path/speech.wav
```
- Run a TTS model with its default vocoder model:
```
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
```
For example:
```
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
```
- Run with specific TTS and vocoder models from the list:
```
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
```
For example:
```
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
```
- Run your own TTS model (Using Griffin-Lim Vocoder):
```
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
```
- Run your own TTS and Vocoder models:
```
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
```
#### Multi-speaker Models
- List the available speakers and choose a <speaker_id> among them:
```
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
```
- Run the multi-speaker TTS model with the target speaker ID:
```
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
```
- Run your own multi-speaker TTS model:
```
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
```
### Voice Conversion Models
```
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
```
<!-- end-tts-readme -->
## Directory Structure
```


@@ -5,12 +5,12 @@
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/config.json",
"https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/vocab.json"
],
"default_vocoder": null,
"commit": "e9a1953e",
"commit": "e5140314",
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": true
@@ -728,6 +728,18 @@
"license": "Apache 2.0"
}
}
},
"be": {
"common-voice": {
"glow-tts":{
"description": "Belarusian GlowTTS model created by @alex73 (Github).",
"github_rls_url":"https://coqui.gateway.scarf.sh/v0.16.6/tts_models--be--common-voice--glow-tts.zip",
"default_vocoder": "vocoder_models/be/common-voice/hifigan",
"commit": "c0aabb85",
"license": "CC-BY-SA 4.0",
"contact": "alex73mail@gmail.com"
}
}
}
},
"vocoder_models": {
@@ -879,6 +891,17 @@
"commit": null
}
}
},
"be": {
"common-voice": {
"hifigan": {
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.16.6/vocoder_models--be--common-voice--hifigan.zip",
"description": "Belarusian HiFiGAN model created by @alex73 (Github).",
"author": "@alex73",
"license": "CC-BY-SA 4.0",
"commit": "c0aabb85"
}
}
}
},
"voice_conversion_models": {


@@ -1 +1 @@
0.17.2
0.17.8


@@ -17,7 +17,7 @@ class TTS(nn.Module):
def __init__(
self,
model_name: str = None,
model_name: str = "",
model_path: str = None,
config_path: str = None,
vocoder_path: str = None,
@@ -105,13 +105,14 @@ class TTS(nn.Module):
@property
def is_multi_lingual(self):
# TODO: fix this
if "xtts" in self.model_name:
# Not sure what sets this to None, but applied a fix to prevent crashing.
if isinstance(self.model_name, str) and "xtts" in self.model_name:
return True
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
return self.synthesizer.tts_model.language_manager.num_languages > 1
return False
@property
def speakers(self):
if not self.is_multi_speaker:
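The `isinstance` guard above matters mainly when a model is loaded from explicit paths rather than by name, in which case `model_name` may be empty or `None`. A minimal sketch of that situation (the paths are placeholders):
```
# With path-based loading, is_multi_lingual now returns False instead of raising
# when model_name is not a usable string.
from TTS.api import TTS

tts = TTS(model_path="path/to/model.pth", config_path="path/to/config.json")
print(tts.is_multi_lingual)
```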


@@ -8,9 +8,120 @@ from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
description = """
Synthesize speech on command line.
You can either use your trained model or choose a model from the provided list.
If you don't specify any models, then it uses LJSpeech based English model.
#### Single Speaker Models
- List provided models:
```
$ tts --list_models
```
- Get model info (for both tts_models and vocoder_models):
- Query by type/name:
The model_info_by_name uses the name as it appears in the --list_models output.
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
For example:
```
$ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
$ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
```
- Query by type/idx:
The model_query_idx uses the corresponding idx from --list_models.
```
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
```
For example:
```
$ tts --model_info_by_idx tts_models/3
```
- Query info for model info by full name:
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
- Run TTS with default models:
```
$ tts --text "Text for TTS" --out_path output/path/speech.wav
```
- Run a TTS model with its default vocoder model:
```
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
```
For example:
```
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
```
- Run with specific TTS and vocoder models from the list:
```
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
```
For example:
```
$ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
```
- Run your own TTS model (Using Griffin-Lim Vocoder):
```
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
```
- Run your own TTS and Vocoder models:
```
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
```
#### Multi-speaker Models
- List the available speakers and choose a <speaker_id> among them:
```
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
```
- Run the multi-speaker TTS model with the target speaker ID:
```
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
```
- Run your own multi-speaker TTS model:
```
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
```
### Voice Conversion Models
```
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
```
"""
def str2bool(v):
@@ -24,92 +135,6 @@ def str2bool(v):
def main():
description = """Synthesize speech on command line.
You can either use your trained model or choose a model from the provided list.
If you don't specify any models, then it uses LJSpeech based English model.
## Example Runs
### Single Speaker Models
- List provided models:
```
$ tts --list_models
```
- Query info for model info by idx:
```
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
```
- Query info for model info by full name:
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
- Run TTS with default models:
```
$ tts --text "Text for TTS"
```
- Run a TTS model with its default vocoder model:
```
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>
```
- Run with specific TTS and vocoder models from the list:
```
$ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --output_path
```
- Run your own TTS model (Using Griffin-Lim Vocoder):
```
$ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
```
- Run your own TTS and Vocoder models:
```
$ tts --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth --out_path output/path/speech.wav
--vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
```
### Multi-speaker Models
- List the available speakers and choose as <speaker_id> among them:
```
$ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
```
- Run the multi-speaker TTS model with the target speaker ID:
```
$ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
```
- Run your own multi-speaker TTS model:
```
$ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
```
### Voice Conversion Models
```
$ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
```
"""
# We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
# documentation in sync more easily.
parser = argparse.ArgumentParser(
description=description.replace(" ```\n", ""),
formatter_class=RawTextHelpFormatter,
@@ -310,6 +335,11 @@ If you don't specify any models, then it uses LJSpeech based English model.
if not any(check_args):
parser.parse_args(["-h"])
# Late-import to make things load faster
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
# load model manager
path = Path(__file__).parent / "../.models.json"
manager = ModelManager(path, progress_bar=args.progress_bar)
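The late-import block keeps `tts --help` and argument validation fast, because the torch-backed modules are only imported once synthesis is actually requested. A standalone sketch of the same pattern, using a toy CLI rather than the real `tts` entry point (model name and flags are illustrative):
```
import argparse


def main():
    parser = argparse.ArgumentParser(description="late-import demo")
    parser.add_argument("--text", default=None)
    parser.add_argument("--out_path", default="speech.wav")
    args = parser.parse_args()
    if args.text is None:
        parser.print_help()
        return
    # Heavy imports are deferred until we know synthesis is requested.
    from TTS.api import TTS

    TTS("tts_models/en/ljspeech/glow-tts").tts_to_file(text=args.text, file_path=args.out_path)


if __name__ == "__main__":
    main()
```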


@@ -1085,31 +1085,6 @@ class GaussianDiffusion:
}
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
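The `get_named_beta_schedule` helper removed here (the same removal appears in a second diffusion module below) defined the linear and cosine beta schedules; its linear branch is easy to check numerically. A sketch assuming the usual `num_diffusion_timesteps = 1000`, so `scale = 1`:
```
import numpy as np

T = 1000
scale = 1000 / T
betas = np.linspace(scale * 0.0001, scale * 0.02, T, dtype=np.float64)
print(betas[0], betas[-1])  # 0.0001 0.02
```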


@@ -5,9 +5,13 @@ from tokenizers import Tokenizer
from TTS.tts.utils.text.cleaners import english_cleaners
DEFAULT_VOCAB_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
)
class VoiceBpeTokenizer:
def __init__(self, vocab_file=None, vocab_str=None):
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None):
self.tokenizer = None
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
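With the new default, the tokenizer can be constructed without arguments and falls back to the vocab file bundled under the Tortoise assets. A minimal sketch, assuming the module path `TTS.tts.layers.tortoise.tokenizer` and the class's existing `encode()` method:
```
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer  # assumed import path

tok = VoiceBpeTokenizer()        # no explicit vocab_file needed anymore
ids = tok.encode("hello world")  # cleaned text is BPE-encoded to token ids
print(ids)
```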


@@ -1170,31 +1170,6 @@ class GaussianDiffusion:
}
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
they are committed to maintain backwards compatibility.
"""
if schedule_name == "linear":
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
)
else:
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.


@@ -172,7 +172,7 @@ class GPT(nn.Module):
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
def init_gpt_for_inference(self, kv_cache=True):
def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
@@ -195,6 +195,17 @@
)
self.gpt.wte = self.mel_embedding
if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # number of GPUs
dtype=torch.float32, # desired data type of output
replace_method="auto", # let DeepSpeed automatically identify the layers to replace
replace_with_kernel_inject=True, # replace the model with the kernel injector
)
self.gpt_inference = self.ds_engine.module.eval()
def set_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
@@ -543,3 +554,14 @@ class GPT(nn.Module):
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, gpt_inputs.shape[1] :], gen
return gen[:, gpt_inputs.shape[1] :]
def get_generator(self, fake_inputs, **hf_generate_kwargs):
return self.gpt_inference.generate_stream(
fake_inputs,
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
do_stream=True,
**hf_generate_kwargs,
)
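The DeepSpeed branch in `init_gpt_for_inference` wraps the Hugging Face inference model in a `deepspeed.init_inference` engine and then uses the engine's module for generation. A standalone sketch of the same pattern on a plain GPT-2, not the XTTS GPT (requires `deepspeed`, `transformers`, and a CUDA GPU):
```
import deepspeed
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2").half().cuda()
ds_engine = deepspeed.init_inference(
    model=model,
    mp_size=1,                        # single-GPU tensor parallelism
    dtype=torch.float16,
    replace_with_kernel_inject=True,  # let DeepSpeed inject fused inference kernels
)
model = ds_engine.module.eval()       # generate with the wrapped module
```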


@@ -1,658 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class GPT2InferenceModel(GPT2PreTrainedModel):
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
super().__init__(config)
self.transformer = gpt
self.pos_embedding = pos_emb
self.embeddings = embeddings
self.final_norm = norm
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
def store_prefix_emb(self, prefix_emb):
self.cached_prefix_emb = prefix_emb
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) # usually None
if not self.kv_cache:
past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_prefix_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
# Create embedding
prefix_len = self.cached_prefix_emb.shape[1]
if input_ids.shape[1] != 1:
gen_inputs = input_ids[:, prefix_len:]
gen_emb = self.embeddings(gen_inputs)
gen_emb = gen_emb + self.pos_embedding(gen_emb)
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
prefix_emb = self.cached_prefix_emb.repeat_interleave(
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
)
else:
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
emb = torch.cat([prefix_emb, gen_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
super().__init__()
self.emb = nn.Embedding(seq_len, model_channels)
nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
self.relative = relative
def forward(self, x):
seq_len = x.shape[1]
if self.relative:
start = torch.randint(seq_len, (1,), device=x.device).item()
positions = torch.arange(start, start + seq_len, device=x.device)
else:
positions = torch.arange(seq_len, device=x.device)
return self.emb(positions)
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
"""
Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
Args:
layers (int): Number of layers in the GPT-2 model.
model_channels (int): Dimension of the GPT-2 model.
heads (int): Number of heads in the GPT-2 model.
max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
max_text_seq_len (int): Maximum sequence length for the text.
max_prompt_len (int): Maximum length of the prompt.
checkpointing (bool): Whether to use gradient checkpointing.
Returns:
gpt (GPT2Model): GPT-2 model.
mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram.
text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
"""
gpt_config = GPT2Config(
vocab_size=123,
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_embd=model_channels,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing,
)
gpt = GPT2Model(gpt_config)
del gpt.wpe
del gpt.wte
gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
audio_pos_emb = (
LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
if max_mel_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_channels)
)
text_pos_emb = (
LearnedPositionEmbeddings(max_text_seq_len, model_channels)
if max_mel_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_channels)
)
return gpt, audio_pos_emb, text_pos_emb
class XTTSGPTEncoder(nn.Module):
"""XTTS GPT Encoder model implementation.
Args:
start_text_token (int): Index of the start token in the text vocabulary.
stop_text_token (int): Index of the stop token in the text vocabulary.
n_layers (int): Number of layers in the GPT-2 model.
n_model_channels (int): Dimension of the GPT-2 model.
n_heads (int): Number of heads in the GPT-2 model.
max_text_tokens (int): Maximum number of text tokens.
max_audio_tokens (int): Maximum number of audio tokens.
max_prompt_tokens (int): Maximum number of prompt tokens.
audio_len_compression (int): Compression factor for the audio length.
number_text_tokens (int): Number of text tokens.
number_audio_codes (int): Number of audio codes.
start_mel_token (int): Index of the start token in the mel code vocabulary.
stop_mel_token (int): Index of the stop token in the mel code vocabulary.
checkpointing (bool): Whether or not to use gradient checkpointing at training.
"""
_inference_flag = False
def __init__(
self,
start_text_token=261,
stop_text_token=0,
n_layers=8,
n_model_channels=512,
n_heads=8,
max_text_tokens=120,
max_audio_tokens=250,
max_prompt_tokens=70,
audio_len_compression=1024,
number_text_tokens=256,
number_audio_codes=8194,
start_mel_token=8192,
stop_mel_token=8193,
checkpointing=True,
label_smoothing=0.0,
):
super().__init__()
self.label_smoothing = label_smoothing
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_audio_codes = number_audio_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.start_prompt_token = start_mel_token
self.stop_prompt_token = stop_mel_token
self.n_layers = n_layers
self.n_heads = n_heads
self.n_model_channels = n_model_channels
self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
self.max_prompt_tokens = max_prompt_tokens
self.audio_len_compression = audio_len_compression
# embedding layers
self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
# initialize the GPT-2 model
(
self.gpt,
self.audio_pos_embedding,
self.text_pos_embedding,
) = init_gpt(
n_layers,
n_model_channels,
n_heads,
self.max_audio_tokens,
self.max_text_tokens,
self.max_prompt_tokens,
checkpointing,
)
# output layers
self.final_norm = nn.LayerNorm(n_model_channels)
self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
def get_grad_norm_parameter_groups(self):
return {
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
"gpt": list(self.gpt.parameters()),
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
self._inference_flag = True
seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
gpt_config = GPT2Config(
vocab_size=self.max_audio_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.n_model_channels,
n_layer=self.n_layers,
n_head=self.n_heads,
gradient_checkpointing=False,
use_cache=True,
)
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
self.audio_pos_embedding,
self.audio_embedding,
self.final_norm,
self.mel_head,
kv_cache=kv_cache,
)
self.gpt.wte = self.audio_embedding
def set_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
for b in range(len(audio_token_lens)):
actual_end = audio_token_lens[b]
if actual_end < audio_tokens.shape[-1]:
audio_tokens[b, actual_end:] = self.stop_mel_token
return audio_tokens
def get_logits(
self,
speech_conditioning_inputs,
first_inputs,
first_head,
second_inputs=None,
second_head=None,
prompt=None,
get_attns=False,
return_latent=False,
attn_mask_text=None,
attn_mask_mel=None,
):
if prompt is not None and speech_conditioning_inputs is not None:
offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
if second_inputs is not None:
emb = torch.cat(
[speech_conditioning_inputs, prompt, first_inputs, second_inputs],
dim=1,
)
else:
emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
elif speech_conditioning_inputs is not None:
offset = speech_conditioning_inputs.shape[1]
if second_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
elif prompt is not None:
offset = prompt.shape[1]
if second_inputs is not None:
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([prompt, first_inputs], dim=1)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
attn_mask = None
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
gpt_out = self.gpt(
inputs_embeds=emb,
return_dict=True,
output_attentions=get_attns,
attention_mask=attn_mask,
)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, offset:]
enc = self.final_norm(enc)
if return_latent:
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
first_logits = enc[:, : first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0, 2, 1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1] :]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0, 2, 1)
return first_logits, second_logits
else:
return first_logits
def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds
def get_prompts(self, prompt_codes):
prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token)
return prompt
def forward(
self,
text_inputs,
text_lengths,
audio_codes,
wav_lengths,
prompt_codes,
return_attentions=False,
return_latent=False,
):
max_text_len = text_lengths.max()
# Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values
# Like [..., 186, 45, 45, 83] where actually it should end with 186.
# We take last 3 codes to prevent abrupt ending of the audio.
# TODO: This is might need some testing.
mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
max_mel_len = mel_lengths.max()
if max_mel_len > audio_codes.shape[-1]:
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
# silence aware lengths, skip the silence tokens at the end of the mel codes.
silence = True
for idx, l in enumerate(mel_lengths):
length = l.item()
while silence:
if audio_codes[idx, length - 1] != 83:
break
length -= 1
mel_lengths[idx] = length
# Lovely assertions
assert (
max_mel_len <= audio_codes.shape[-1]
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
assert (
max_text_len <= text_inputs.shape[-1]
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
# Append stop token to text inputs
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
# Append silence token to mel codes
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
# Pad mel codes with STOP_MEL_TOKEN
audio_codes = self.set_mel_padding(audio_codes, mel_lengths)
# Compute speech conditioning input
conds = None
if speech_conditioning_input is not None:
if not return_latent:
# Compute speech conditioning input
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
else:
# already computed
conds = speech_conditioning_input.unsqueeze(1)
# Build input and target tensors
# Prepend start token to inputs and append stop token to targets
text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
# Set attn_mask
attn_mask_text = None
attn_mask_mel = None
if not return_latent:
attn_mask_text = torch.ones(
text_inputs.shape[0],
text_inputs.shape[1],
dtype=torch.bool,
device=text_inputs.device,
)
attn_mask_mel = torch.ones(
audio_codes.shape[0],
audio_codes.shape[1],
dtype=torch.bool,
device=audio_codes.device,
)
for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
for idx, l in enumerate(mel_lengths):
attn_mask_mel[idx, l + 1 :] = 0.0
# Compute text embeddings + positional embeddings
# print(" > text input latent:", text_inputs)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
# Compute mel embeddings + positional embeddings
audio_emb = self.audio_embedding(audio_codes) + self.audio_embedding(audio_codes)
# Compute prompt embeddings + positional embeddings
prompt = self.get_prompts(prompt_codes)
# prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
# dropout prompt embeddings
prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
# Get logits
sub = -4 # don't ask me why 😄
if self.training:
sub = -1
_, audio_logits = self.get_logits(
conds,
text_emb,
self.text_head,
audio_emb,
self.mel_head,
prompt=prompt_emb,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)
return audio_logits[:, :sub] # sub to prevent bla.
def compute_embeddings(
self,
speech_conditioning_latent,
text_inputs,
input_tokens=None,
prompt_codes=None,
pad_input_text=False,
):
"""Compute all the embeddings needed for inference."""
if pad_input_text and text_inputs.shape[1] < 250:
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
else:
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
print(" > Text inputs:", text_inputs)
if prompt_codes is not None:
prompt_codes = self.get_prompts(prompt_codes)
# prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
print(" > Prompt inputs:", prompt_codes)
print(" > Prompt inputs shape:", prompt_codes.shape)
emb = torch.cat([prompt_emb, emb], dim=1)
if speech_conditioning_latent is not None:
conds = speech_conditioning_latent.unsqueeze(1)
emb = torch.cat([conds, emb], dim=1)
self.inference_model.store_prefix_emb(emb)
fake_inputs = torch.full(
(
emb.shape[0],
emb.shape[1] + 1, # +1 for the start_mel_token
),
fill_value=1,
dtype=torch.long,
device=text_inputs.device,
)
fake_inputs[:, -1] = self.start_mel_token
if input_tokens is not None:
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
return fake_inputs
def inference(
self,
text_inputs,
input_tokens=None,
prompt_codes=None,
pad_input_text=False,
**hf_generate_kwargs,
):
if pad_input_text and text_inputs.shape[1] < 250:
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
else:
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
if prompt_codes is not None:
prompt_codes = self.get_prompts(prompt_codes)
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
emb = torch.cat([prompt_emb, emb], dim=1)
self.inference_model.store_prefix_emb(emb)
fake_inputs = torch.full(
(
emb.shape[0],
emb.shape[1] + 1, # +1 for the start_mel_token
),
fill_value=1,
dtype=torch.long,
device=text_inputs.device,
)
fake_inputs[:, -1] = self.start_mel_token
if input_tokens is not None:
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
gen = self.inference_model.generate(
fake_inputs,
bos_token_id=self.start_mel_token,
pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token,
max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, fake_inputs.shape[1] :], gen
return gen[:, fake_inputs.shape[1] :]

File diff suppressed because it is too large.


@@ -0,0 +1,742 @@
import torch
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
def get_padding(k, d):
return int((k * d - d) / 2)
class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|--------------------------------------------------------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
def forward(self, x):
"""
Args:
x (Tensor): input tensor.
Returns:
Tensor: output tensor.
Shapes:
x: [B, C, T]
"""
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
"""Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
Network::
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|---------------------------------------------------|
Args:
channels (int): number of hidden channels for the convolutional layers.
kernel_size (int): size of the convolution filter in each layer.
dilations (list): list of dilation value for each conv layer in a block.
"""
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class HifiganGenerator(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
resblock_type,
resblock_dilation_sizes,
resblock_kernel_sizes,
upsample_kernel_sizes,
upsample_initial_channel,
upsample_factors,
inference_padding=5,
cond_channels=0,
conv_pre_weight_norm=True,
conv_post_weight_norm=True,
conv_post_bias=True,
cond_in_each_up_layer=False,
):
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
Network:
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
.. -> zI ---|
resblockN_kNx1 -> zN ---'
Args:
in_channels (int): number of input tensor channels.
out_channels (int): number of output tensor channels.
resblock_type (str): type of the `ResBlock`. '1' or '2'.
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
for each consecutive upsampling layer.
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
"""
super().__init__()
self.inference_padding = inference_padding
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_factors)
self.cond_in_each_up_layer = cond_in_each_up_layer
# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
# MRF blocks
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
if not conv_pre_weight_norm:
remove_weight_norm(self.conv_pre)
if not conv_post_weight_norm:
remove_weight_norm(self.conv_post)
if self.cond_in_each_up_layer:
self.conds = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
self.conds.append(nn.Conv1d(cond_channels, ch, 1))
def forward(self, x, g=None):
"""
Args:
x (Tensor): feature input tensor.
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
o = self.conv_pre(x)
if hasattr(self, "cond_layer"):
o = o + self.cond_layer(g)
for i in range(self.num_upsamples):
o = F.leaky_relu(o, LRELU_SLOPE)
o = self.ups[i](o)
if self.cond_in_each_up_layer:
o = o + self.conds[i](g)
z_sum = None
for j in range(self.num_kernels):
if z_sum is None:
z_sum = self.resblocks[i * self.num_kernels + j](o)
else:
z_sum += self.resblocks[i * self.num_kernels + j](o)
o = z_sum / self.num_kernels
o = F.leaky_relu(o)
o = self.conv_post(o)
o = torch.tanh(o)
return o
@torch.no_grad()
def inference(self, c):
"""
Args:
x (Tensor): conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
return self.forward(c)
def remove_weight_norm(self):
print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
class SEBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
super(SEBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.se = SELayer(planes, reduction)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.bn1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
def set_init_dict(model_dict, checkpoint_state, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
for k, v in checkpoint_state.items():
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers
if c.has("reinit_layers") and c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
return model_dict
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
# pylint: disable=W0102
def __init__(
self,
input_dim=64,
proj_dim=512,
layers=[3, 4, 6, 3],
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
use_torch_spec=False,
audio_config=None,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
self.log_input = log_input
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.proj_dim = proj_dim
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0])
self.inplanes = num_filters[0]
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(128),
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(dim=2),
)
if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2
else:
raise ValueError("Undefined encoder")
self.fc = nn.Linear(out_dim, proj_dim)
self._init_layers()
def _init_layers(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def create_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
# pylint: disable=R0201
def new_parameter(self, *size):
out = nn.Parameter(torch.FloatTensor(*size))
nn.init.xavier_normal_(out)
return out
def forward(self, x, l2_norm=False):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
x.squeeze_(1)
# if use_torch_spec is set, compute the spectrogram on the fly; otherwise use the mel spec computed by the AP
if self.use_torch_spec:
x = self.torch_spec(x)
if self.log_input:
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
x = self.conv1(x)
x = self.relu(x)
x = self.bn1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.reshape(x.size()[0], -1, x.size()[-1])
w = self.attention(x)
if self.encoder_type == "SAP":
x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)
x = self.fc(x)
if l2_norm:
x = torch.nn.functional.normalize(x, p=2, dim=1)
return x
def load_checkpoint(
self,
checkpoint_path: str,
eval: bool = False,
use_cuda: bool = False,
criterion=None,
cache=False,
):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try:
self.load_state_dict(state["model"])
print(" > Model fully restored. ")
except (KeyError, RuntimeError) as error:
# If eval raise the error
if eval:
raise error
print(" > Partial model initialization.")
model_dict = self.state_dict()
model_dict = set_init_dict(model_dict, state["model"])
self.load_state_dict(model_dict)
del model_dict
# load the criterion for restore_path
if criterion is not None and "criterion" in state:
try:
criterion.load_state_dict(state["criterion"])
except (KeyError, RuntimeError) as error:
print(" > Criterion load ignored because of:", error)
if use_cuda:
self.cuda()
if criterion is not None:
criterion = criterion.cuda()
if eval:
self.eval()
assert not self.training
if not eval:
return criterion, state["step"]
return criterion
class HifiDecoder(torch.nn.Module):
def __init__(
self,
input_sample_rate=22050,
output_sample_rate=24000,
output_hop_length=256,
ar_mel_length_compression=1024,
decoder_input_dim=1024,
resblock_type_decoder="1",
resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
resblock_kernel_sizes_decoder=[3, 7, 11],
upsample_rates_decoder=[8, 8, 2, 2],
upsample_initial_channel_decoder=512,
upsample_kernel_sizes_decoder=[16, 16, 4, 4],
d_vector_dim=512,
cond_d_vector_in_each_upsampling_layer=True,
speaker_encoder_audio_config={
"fft_size": 512,
"win_length": 400,
"hop_length": 160,
"sample_rate": 16000,
"preemphasis": 0.97,
"num_mels": 64,
},
):
super().__init__()
self.input_sample_rate = input_sample_rate
self.output_sample_rate = output_sample_rate
self.output_hop_length = output_hop_length
self.ar_mel_length_compression = ar_mel_length_compression
self.speaker_encoder_audio_config = speaker_encoder_audio_config
self.waveform_decoder = HifiganGenerator(
decoder_input_dim,
1,
resblock_type_decoder,
resblock_dilation_sizes_decoder,
resblock_kernel_sizes_decoder,
upsample_kernel_sizes_decoder,
upsample_initial_channel_decoder,
upsample_rates_decoder,
inference_padding=0,
cond_channels=d_vector_dim,
conv_pre_weight_norm=False,
conv_post_weight_norm=False,
conv_post_bias=False,
cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
)
self.speaker_encoder = ResNetSpeakerEncoder(
input_dim=64,
proj_dim=512,
log_input=True,
use_torch_spec=True,
audio_config=speaker_encoder_audio_config,
)
@property
def device(self):
return next(self.parameters()).device
def forward(self, latents, g=None):
"""
Args:
x (Tensor): feature input tensor (GPT latent).
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
z = torch.nn.functional.interpolate(
latents.transpose(1, 2),
scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
mode="linear",
).squeeze(1)
        # resample to the target output sample rate if it differs from the input rate
if self.output_sample_rate != self.input_sample_rate:
z = torch.nn.functional.interpolate(
z,
scale_factor=[self.output_sample_rate / self.input_sample_rate],
mode="linear",
).squeeze(0)
o = self.waveform_decoder(z, g=g)
return o
@torch.no_grad()
def inference(self, c, g):
"""
Args:
x (Tensor): feature input tensor (GPT latent).
g (Tensor): global conditioning input tensor.
Returns:
Tensor: output waveform.
Shapes:
x: [B, C, T]
Tensor: [B, 1, T]
"""
return self.forward(c, g=g)
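    # Minimal usage sketch (hypothetical tensors; `g` is the speaker embedding produced by
    # `self.speaker_encoder(..., l2_norm=True).unsqueeze(-1)` as done in the XTTS model):
    #   decoder = HifiDecoder()
    #   latents = torch.randn(1, 20, 1024)   # [B, T, decoder_input_dim] GPT latents
    #   g = torch.randn(1, 512, 1)           # [B, d_vector_dim, 1] speaker conditioning
    #   wav = decoder.inference(latents, g)  # -> waveform at `output_sample_rate` (24 kHz by default)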
def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]
states_keys = list(state.keys())
for key in states_keys:
if "waveform_decoder." not in key and "speaker_encoder." not in key:
del state[key]
self.load_state_dict(state)
if eval:
self.eval()
assert not self.training
self.waveform_decoder.remove_weight_norm()

File diff suppressed because it is too large

View File

@ -171,17 +171,6 @@ def multilingual_cleaners(text, lang):
return text
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
text = text.replace('"', "")
return text
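# Illustrative example (ASCII folding, lowercasing, whitespace collapsing, quote stripping):
#   english_cleaners('Café  "Déjà Vu"')  ->  'cafe deja vu'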
def remove_extraneous_punctuation(word):
    replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "\u2014": "-", "\u2013": "-", "ʼ": "'"}  # maps brackets, dash variants, and apostrophe variants
replace = re.compile(
@ -195,32 +184,6 @@ def remove_extraneous_punctuation(word):
return word
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
_whitespace_re = re.compile(r"\s+")
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def arabic_cleaners(text):
text = lowercase(text)
text = collapse_whitespace(text)
@ -261,7 +224,10 @@ class VoiceBpeTokenizer:
txt = " ".join([result["kana"] for result in results])
txt = basic_cleaners(txt)
elif lang == "en":
if txt[:4] == "[en]":
txt = txt[4:]
txt = english_cleaners(txt)
txt = "[en]" + txt
elif lang == "ar":
txt = arabic_cleaners(txt)
elif lang == "zh-cn":

View File

@ -726,8 +726,8 @@ class DelightfulTTS(BaseTTSE2E):
def pitch_std(self):
return self.acoustic_model.pitch_std
@pitch_mean.setter
def pitch_std(self, value): # pylint: disable=function-redefined
@pitch_std.setter
def pitch_std(self, value):
self.acoustic_model.pitch_std = value
@property
@ -1518,10 +1518,6 @@ class DelightfulTTS(BaseTTSE2E):
scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler_D, scheduler_G]
def on_train_step_start(self, trainer):
"""Schedule binary loss weight."""
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
def on_epoch_end(self, trainer): # pylint: disable=unused-argument
# stop updating mean and var
# TODO: do the same for F0
@ -1578,6 +1574,7 @@ class DelightfulTTS(BaseTTSE2E):
Args:
trainer (Trainer): Trainer object.
"""
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
self.train_disc = ( # pylint: disable=attribute-defined-outside-init
trainer.total_steps_done >= self.config.steps_to_start_discriminator
)

View File

@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
init_stream_support()
def load_audio(audiopath, sr=22050):
"""
@ -195,13 +198,12 @@ class XttsArgs(Coqpit):
Args:
gpt_batch_size (int): The size of the auto-regressive batch.
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
For GPT model:
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@ -231,12 +233,12 @@ class XttsArgs(Coqpit):
gpt_batch_size: int = 1
enable_redaction: bool = False
lazy_load: bool = True
kv_cache: bool = True
gpt_checkpoint: str = None
clvp_checkpoint: str = None
decoder_checkpoint: str = None
num_chars: int = 255
use_hifigan: bool = True
# XTTS GPT Encoder params
tokenizer_file: str = ""
@ -266,6 +268,15 @@ class XttsArgs(Coqpit):
diff_layer_drop: int = 0
diff_unconditioned_percentage: int = 0
# HifiGAN Decoder params
input_sample_rate: int = 22050
output_sample_rate: int = 24000
output_hop_length: int = 256
ar_mel_length_compression: int = 1024
decoder_input_dim: int = 1024
d_vector_dim: int = 512
cond_d_vector_in_each_upsampling_layer: bool = True
# constants
duration_const: int = 102400
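    # Example override (illustrative values only): select the HiFi-GAN decoder path and
    # target sample rate through the model args, e.g.
    #   model_args = XttsArgs(use_hifigan=True, kv_cache=True, output_sample_rate=24000)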
@ -285,7 +296,6 @@ class Xtts(BaseTTS):
def __init__(self, config: Coqpit):
super().__init__(config, ap=None, tokenizer=None)
self.lazy_load = self.args.lazy_load
self.mel_stats_path = None
self.config = config
self.gpt_checkpoint = self.args.gpt_checkpoint
@ -295,7 +305,6 @@ class Xtts(BaseTTS):
self.tokenizer = VoiceBpeTokenizer()
self.gpt = None
self.diffusion_decoder = None
self.init_models()
self.register_buffer("mel_stats", torch.ones(80))
@ -322,40 +331,39 @@ class Xtts(BaseTTS):
stop_audio_token=self.args.gpt_stop_audio_token,
)
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
in_channels=self.args.diff_in_channels,
out_channels=self.args.diff_out_channels,
in_latent_channels=self.args.diff_in_latent_channels,
in_tokens=self.args.diff_in_tokens,
dropout=self.args.diff_dropout,
use_fp16=self.args.diff_use_fp16,
num_heads=self.args.diff_num_heads,
layer_drop=self.args.diff_layer_drop,
unconditioned_percentage=self.args.diff_unconditioned_percentage,
)
self.vocoder = UnivNetGenerator()
if self.args.use_hifigan:
self.hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
output_sample_rate=self.args.output_sample_rate,
output_hop_length=self.args.output_hop_length,
ar_mel_length_compression=self.args.ar_mel_length_compression,
decoder_input_dim=self.args.decoder_input_dim,
d_vector_dim=self.args.d_vector_dim,
cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
)
else:
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
in_channels=self.args.diff_in_channels,
out_channels=self.args.diff_out_channels,
in_latent_channels=self.args.diff_in_latent_channels,
in_tokens=self.args.diff_in_tokens,
dropout=self.args.diff_dropout,
use_fp16=self.args.diff_use_fp16,
num_heads=self.args.diff_num_heads,
layer_drop=self.args.diff_layer_drop,
unconditioned_percentage=self.args.diff_unconditioned_percentage,
)
self.vocoder = UnivNetGenerator()
@property
def device(self):
return next(self.parameters()).device
@contextmanager
def lazy_load_model(self, model):
"""Context to load a model on demand.
Args:
model (nn.Module): The model to be loaded.
"""
if self.lazy_load:
yield model
else:
m = model.to(self.device)
yield m
m = model.cpu()
@torch.inference_mode()
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.
@ -370,6 +378,7 @@ class Xtts(BaseTTS):
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
return cond_latent.transpose(1, 2)
@torch.inference_mode()
def get_diffusion_cond_latents(
self,
audio_path,
@ -389,20 +398,33 @@ class Xtts(BaseTTS):
)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
return diffusion_latent
@torch.inference_mode()
def get_speaker_embedding(
self,
audio_path
):
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
audio.to(self.device), l2_norm=True
).unsqueeze(-1).to(self.device)
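        # shape [B, 512, 1]: L2-normalized speaker d-vector reshaped as the conditioning input expected by the HiFi-GAN decoder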
return speaker_embedding
def get_conditioning_latents(
self,
audio_path,
gpt_cond_len=3,
):
speaker_embedding = None
diffusion_cond_latents = None
if self.args.use_hifigan:
speaker_embedding = self.get_speaker_embedding(audio_path)
else:
diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
diffusion_cond_latents = self.get_diffusion_cond_latents(
audio_path,
)
return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
def synthesize(self, text, config, speaker_wav, language, **kwargs):
"""Synthesize speech with the given input text.
@ -447,10 +469,10 @@ class Xtts(BaseTTS):
"decoder_sampler": config.decoder_sampler,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.inference(text, ref_audio_path, language, **settings)
return self.full_inference(text, ref_audio_path, language, **settings)
@torch.no_grad()
def inference(
@torch.inference_mode()
def full_inference(
self,
text,
ref_audio_path,
@ -525,6 +547,54 @@ class Xtts(BaseTTS):
            Generated audio clip(s) as a torch tensor. Shape (1, S) if k == 1, else (k, 1, S), where S is the sample length.
Sample rate is 24kHz.
"""
(
gpt_cond_latent,
diffusion_conditioning,
speaker_embedding
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
return self.inference(
text,
language,
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
temperature=temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
decoder_iterations=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
diffusion_temperature=diffusion_temperature,
decoder_sampler=decoder_sampler,
**hf_generate_kwargs,
)
@torch.inference_mode()
def inference(
self,
text,
language,
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
# GPT inference
temperature=0.65,
length_penalty=1,
repetition_penalty=2.0,
top_k=50,
top_p=0.85,
do_sample=True,
# Decoder inference
decoder_iterations=100,
cond_free=True,
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
**hf_generate_kwargs,
):
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@ -532,74 +602,147 @@ class Xtts(BaseTTS):
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
(
gpt_cond_latent,
diffusion_conditioning,
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=decoder_sampler,
)
if not self.args.use_hifigan:
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=decoder_sampler,
)
with torch.no_grad():
self.gpt = self.gpt.to(self.device)
with self.lazy_load_model(self.gpt) as gpt:
gpt_codes = gpt.generate(
cond_latents=gpt_cond_latent,
text_inputs=text_tokens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
gpt_codes = self.gpt.generate(
cond_latents=gpt_cond_latent,
text_inputs=text_tokens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = self.gpt(
text_tokens,
text_len,
gpt_codes,
expected_output_len,
cond_latents=gpt_cond_latent,
return_attentions=False,
return_latent=True,
)
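            # trim trailing silence from the latents: stop at the first run of more than
            # eight consecutive silence tokens (token id 83) in the generated codes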
silence_token = 83
ctokens = 0
for k in range(gpt_codes.shape[-1]):
if gpt_codes[0, k] == silence_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
with self.lazy_load_model(self.gpt) as gpt:
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = gpt(
text_tokens,
text_len,
gpt_codes,
expected_output_len,
cond_latents=gpt_cond_latent,
return_attentions=False,
return_latent=True,
)
silence_token = 83
ctokens = 0
for k in range(gpt_codes.shape[-1]):
if gpt_codes[0, k] == silence_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
if self.args.use_hifigan:
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
else:
mel = do_spectrogram_diffusion(
diffusion,
self.diffusion_decoder,
diffuser,
gpt_latents,
diffusion_conditioning,
temperature=diffusion_temperature,
)
with self.lazy_load_model(self.vocoder) as vocoder:
wav = vocoder.inference(mel)
wav = self.vocoder.inference(mel)
return {"wav": wav.cpu().numpy().squeeze()}
def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
"""Handle chunk formatting in streaming mode"""
wav_chunk = wav_gen[:-overlap_len]
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
if wav_overlap is not None:
crossfade_wav = wav_chunk[:overlap_len]
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] += crossfade_wav
wav_overlap = wav_gen[-overlap_len:]
wav_gen_prev = wav_gen
return wav_chunk, wav_gen_prev, wav_overlap
@torch.inference_mode()
def inference_stream(
self,
text,
language,
gpt_cond_latent,
speaker_embedding,
# Streaming
stream_chunk_size=20,
overlap_wav_len=1024,
# GPT inference
temperature=0.65,
length_penalty=1,
repetition_penalty=2.0,
top_k=50,
top_p=0.85,
do_sample=True,
# Decoder inference
**hf_generate_kwargs,
):
        assert hasattr(self, "hifigan_decoder"), "inference_stream() requires use_hifigan=True in config.model_args; the diffusion decoder is too slow to stream."
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
fake_inputs = self.gpt.compute_embeddings(
gpt_cond_latent.to(self.device),
text_tokens,
)
gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)
last_tokens = []
all_latents = []
wav_gen_prev = None
wav_overlap = None
is_end = False
while not is_end:
try:
x, latent = next(gpt_generator)
last_tokens += [x]
all_latents += [latent]
except StopIteration:
is_end = True
if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
last_tokens = []
yield wav_chunk
def forward(self):
raise NotImplementedError("XTTS Training is not implemented")
@ -616,7 +759,14 @@ class Xtts(BaseTTS):
super().eval()
def load_checkpoint(
self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
vocab_path=None,
eval=True,
strict=True,
use_deepspeed=False,
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
@ -626,7 +776,7 @@ class Xtts(BaseTTS):
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
Returns:
@ -636,19 +786,26 @@ class Xtts(BaseTTS):
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
if os.path.exists(vocab_path):
self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
self.init_models()
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
self.load_state_dict(checkpoint, strict=strict)
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
if hasattr(self, "vocoder"): self.vocoder.eval()
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
self.gpt.eval()
self.diffusion_decoder.eval()
self.vocoder.eval()
def train_step(self):
raise NotImplementedError("XTTS Training is not implemented")

View File

@ -6,6 +6,7 @@ from pathlib import Path
from shutil import copyfile, rmtree
from typing import Dict, List, Tuple
import fsspec
import requests
from tqdm import tqdm
@ -315,11 +316,51 @@ class ModelManager(object):
"""Check if the user has agreed to the terms of service"""
if "tos_required" in model_item and model_item["tos_required"]:
tos_path = os.path.join(model_full_path, "tos_agreed.txt")
if os.path.exists(tos_path):
if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
return True
return False
return True
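    # Note: exporting COQUI_TOS_AGREED=1 in the environment skips the interactive
    # terms-of-service prompt, as checked above.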
def create_dir_and_download_model(self, model_name, model_item, output_path):
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
def check_if_configs_are_equal(self, model_name, model_item, output_path):
with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
config_local = json.load(f)
remote_url = None
for url in model_item["hf_url"]:
if "config.json" in url:
remote_url = url
break
with fsspec.open(remote_url, "r", encoding="utf-8") as f:
config_remote = json.load(f)
        if config_local != config_remote:
            print(f" > {model_name} is already downloaded but the remote version has changed. Redownloading it...")
self.create_dir_and_download_model(model_name, model_item, output_path)
else:
print(f" > {model_name} is already downloaded.")
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
@ -338,28 +379,18 @@ class ModelManager(object):
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.")
# if the configs are different, redownload it
# ToDo: we need a better way to handle it
if "xtts_v1" in model_name:
try:
self.check_if_configs_are_equal(model_name, model_item, output_path)
except:
pass
else:
print(f" > {model_name} is already downloaded.")
else:
os.makedirs(output_path, exist_ok=True)
# handle TOS
if not self.tos_agreed(model_item, output_path):
if not self.ask_tos(output_path):
os.rmdir(output_path)
raise Exception(" [!] You must agree to the terms of service to use this model.")
print(f" > Downloading model to {output_path}")
try:
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
elif "github_rls_url" in model_item:
self._download_github_model(model_item, output_path)
elif "hf_url" in model_item:
self._download_hf_model(model_item, output_path)
self.create_dir_and_download_model(model_name, model_item, output_path)
except requests.RequestException as e:
print(f" > Failed to download the model file to {output_path}")
rmtree(output_path)
raise e
self.print_model_license(model_item=model_item)
# find downloaded files
output_model_path = output_path
output_config_path = None
@ -498,13 +529,16 @@ class ModelManager(object):
print(f" > Error: Bad zip file - {file_url}")
raise zipfile.BadZipFile # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in z.namelist()[1:]:
for file_path in z.namelist():
src_path = os.path.join(output_folder, file_path)
dst_path = os.path.join(output_folder, os.path.basename(file_path))
if src_path != dst_path:
copyfile(src_path, dst_path)
# remove the extracted folder
rmtree(os.path.join(output_folder, z.namelist()[0]))
if os.path.isfile(src_path):
dst_path = os.path.join(output_folder, os.path.basename(file_path))
if src_path != dst_path:
copyfile(src_path, dst_path)
# remove redundant (hidden or not) folders
for file_path in z.namelist():
if os.path.isdir(os.path.join(output_folder, file_path)):
rmtree(os.path.join(output_folder, file_path))
@staticmethod
def _download_tar_file(file_url, output_folder, progress_bar):

View File

@ -116,12 +116,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
return acts
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x

View File

@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml)
Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
You can also mail us at info@coqui.ai.
Using 🐸TTS API:
### Inference
#### 🐸TTS API
```python
from TTS.api import TTS
@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
file_path="output.wav",
speaker_wav="/path/to/target/speaker.wav",
language="en")
# generate speech by cloning a voice using custom settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav="/path/to/target/speaker.wav",
language="en",
decoder_iterations=30)
```
Using 🐸TTS Command line:
#### 🐸TTS Command line
```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
@ -58,25 +52,85 @@ Using 🐸TTS Command line:
--use_cuda true
```
Using model directly:
#### model directly
If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
```console
pip install deepspeed==0.8.3
```
```python
import os
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
print("Loading model...")
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
model.cuda()
outputs = model.synthesize(
print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
print("Inference...")
out = model.inference(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
config,
speaker_wav="/data/TTS-public/_refclips/3.wav",
gpt_cond_len=3,
language="en",
"en",
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
temperature=0.7, # Add custom parameters here
)
torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```
#### streaming inference
Here the goal is to stream the audio as it is being generated, which is useful for real-time applications.
Streaming inference is typically slower overall than regular inference, but it returns the first chunk of audio much sooner.
```python
import os
import time
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
print("Loading model...")
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
print("Inference...")
t0 = time.time()
chunks = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
)
wav_chunks = []
for i, chunk in enumerate(chunks):
    if i == 0:
        print(f"Time to first chunk: {time.time() - t0}")
    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
    wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```

View File

@ -13,15 +13,15 @@
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import sys\n",
"import torch\n",
"import importlib\n",
"import numpy as np\n",
"from tqdm import tqdm as tqdm\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"import soundfile as sf\n",
"import pickle\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
@ -33,8 +33,8 @@
"\n",
"%matplotlib inline\n",
"\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
"# Configure CUDA visibility\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
]
},
{
@ -43,6 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Function to create directories and file names\n",
"def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n",
" file_name = wav_file.split('.')[0]\n",
@ -61,6 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Paths and configurations\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"DATASET = \"ljspeech\"\n",
@ -73,12 +75,15 @@
"QUANTIZE_BIT = None\n",
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
"\n",
"# Check CUDA availability\n",
"use_cuda = torch.cuda.is_available()\n",
"print(\" > CUDA enabled: \", use_cuda)\n",
"\n",
"# Load the configuration\n",
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
"print(C['r'])"
]
},
{
@ -87,14 +92,13 @@
"metadata": {},
"outputs": [],
"source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n",
"# If the vocabulary was passed, replace the default\n",
"if 'characters' in C and C['characters']:\n",
" symbols, phonemes = make_symbols(**C.characters)\n",
"\n",
"# load the model\n",
"# Load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speaker\n",
"# TODO: multiple speakers\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
]
@ -105,11 +109,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Load the preprocessor based on the dataset\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"dataset = TTSDataset(\n",
" checkpoint[\"config\"][\"r\"],\n",
" C,\n",
" C.text_cleaner,\n",
" False,\n",
" ap,\n",
@ -124,6 +129,24 @@
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Create log file\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"log_file = open(log_file_path, \"w\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -137,83 +160,85 @@
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"# Start processing with a progress bar\n",
"with torch.no_grad():\n",
" for data in tqdm(loader):\n",
" # setup input data\n",
" text_input = data[0]\n",
" text_lengths = data[1]\n",
" linear_input = data[3]\n",
" mel_input = data[4]\n",
" mel_lengths = data[5]\n",
" stop_targets = data[6]\n",
" item_idx = data[7]\n",
" for data in tqdm(loader, desc=\"Processing\"):\n",
" try:\n",
" # setup input data\n",
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
"\n",
" # dispatch data to GPU\n",
" if use_cuda:\n",
" text_input = text_input.cuda()\n",
" text_lengths = text_lengths.cuda()\n",
" mel_input = mel_input.cuda()\n",
" mel_lengths = mel_lengths.cuda()\n",
" # dispatch data to GPU\n",
" if use_cuda:\n",
" text_input = text_input.cuda()\n",
" text_lengths = text_lengths.cuda()\n",
" mel_input = mel_input.cuda()\n",
" mel_lengths = mel_lengths.cuda()\n",
"\n",
" mask = sequence_mask(text_lengths)\n",
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
" \n",
" # compute loss\n",
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
" losses.append(loss.item())\n",
" postnet_losses.append(loss_postnet.item())\n",
" mask = sequence_mask(text_lengths)\n",
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
"\n",
" # compute mel specs from linear spec if model is Tacotron\n",
" if C.model == \"Tacotron\":\n",
" mel_specs = []\n",
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
" for b in range(postnet_outputs.shape[0]):\n",
" postnet_output = postnet_outputs[b]\n",
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
" postnet_outputs = torch.stack(mel_specs)\n",
" elif C.model == \"Tacotron2\":\n",
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
" alignments = alignments.detach().cpu().numpy()\n",
" # compute loss\n",
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
" losses.append(loss.item())\n",
" postnet_losses.append(loss_postnet.item())\n",
"\n",
" if not DRY_RUN:\n",
" for idx in range(text_input.shape[0]):\n",
" wav_file_path = item_idx[idx]\n",
" wav = ap.load_wav(wav_file_path)\n",
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_idxs.append(file_name)\n",
" # compute mel specs from linear spec if the model is Tacotron\n",
" if C.model == \"Tacotron\":\n",
" mel_specs = []\n",
" postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
" for b in range(postnet_outputs.shape[0]):\n",
" postnet_output = postnet_outputs[b]\n",
" mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
" postnet_outputs = torch.stack(mel_specs)\n",
" elif C.model == \"Tacotron2\":\n",
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
" alignments = alignments.detach().cpu().numpy()\n",
"\n",
" # quantize and save wav\n",
" if QUANTIZED_WAV:\n",
" wavq = ap.quantize(wav)\n",
" np.save(wavq_path, wavq)\n",
" if not DRY_RUN:\n",
" for idx in range(text_input.shape[0]):\n",
" wav_file_path = item_idx[idx]\n",
" wav = ap.load_wav(wav_file_path)\n",
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_idxs.append(file_name)\n",
"\n",
" # save TTS mel\n",
" mel = postnet_outputs[idx]\n",
" mel_length = mel_lengths[idx]\n",
" mel = mel[:mel_length, :].T\n",
" np.save(mel_path, mel)\n",
" # quantize and save wav\n",
" if QUANTIZED_WAV:\n",
" wavq = ap.quantize(wav)\n",
" np.save(wavq_path, wavq)\n",
"\n",
" metadata.append([wav_file_path, mel_path])\n",
" # save TTS mel\n",
" mel = postnet_outputs[idx]\n",
" mel_length = mel_lengths[idx]\n",
" mel = mel[:mel_length, :].T\n",
" np.save(mel_path, mel)\n",
"\n",
" # for wavernn\n",
" if not DRY_RUN:\n",
" pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
" \n",
" # for pwgan\n",
" with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
" metadata.append([wav_file_path, mel_path])\n",
"\n",
" print(np.mean(losses))\n",
" print(np.mean(postnet_losses))"
" except Exception as e:\n",
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
"\n",
" # Calculate and log mean losses\n",
" mean_loss = np.mean(losses)\n",
" mean_postnet_loss = np.mean(postnet_losses)\n",
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
"\n",
"# Close the log file\n",
"log_file.close()\n",
"\n",
"# For wavernn\n",
"if not DRY_RUN:\n",
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
"\n",
"# For pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
"\n",
"# Print mean losses\n",
"print(f\"Mean Loss: {mean_loss}\")\n",
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
]
},
{

View File

@ -1,5 +1,11 @@
[build-system]
requires = ["setuptools", "wheel", "cython==0.29.30", "numpy==1.22.0", "packaging"]
requires = [
"setuptools",
"wheel",
"cython~=0.29.30",
"numpy>=1.22.0",
"packaging",
]
[flake8]
max-line-length=120
@ -7,25 +13,6 @@ max-line-length=120
[tool.black]
line-length = 120
target-version = ['py39']
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
)/
| foo.py # also separately exclude a file named foo.py in
# the root of the project
)
'''
[tool.isort]
line_length = 120

View File

@ -5,26 +5,27 @@ cython==0.29.30
scipy>=1.11.2
torch>=1.7
torchaudio
soundfile
librosa==0.10.0.*
soundfile==0.12.*
librosa==0.10.*
scikit-learn==1.3.0
numba==0.55.1;python_version<"3.9"
numba==0.57.0;python_version>="3.9"
inflect==5.6.0
tqdm
anyascii
pyyaml
fsspec>=2021.04.0
aiohttp
packaging
inflect==5.6.*
tqdm==4.64.*
anyascii==0.3.*
pyyaml==6.*
fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail
aiohttp==3.8.*
packaging==23.1
# deps for examples
flask
flask==2.*
# deps for inference
pysbd
pysbd==0.3.4
# deps for notebooks
umap-learn==0.5.1
pandas
umap-learn==0.5.*
pandas>=1.4,<2.0
# deps for training
matplotlib
matplotlib==3.7.*
# coqui stack
trainer
# config management
@ -39,14 +40,14 @@ jamo
nltk
g2pkk>=0.1.1
# deps for bangla
bangla==0.0.2
bangla
bnnumerizer
bnunicodenormalizer==0.1.1
bnunicodenormalizer
#deps for tortoise
k_diffusion
einops
transformers
einops==0.6.*
transformers==4.33.*
#deps for bark
encodec
encodec==0.1.*
# deps for XTTS
unidecode
unidecode==1.3.*

32
scripts/sync_readme.py Normal file
View File

@ -0,0 +1,32 @@
import argparse
from pathlib import Path
def replace_between_markers(content, marker: str, replacement: str) -> str:
start_marker = f"<!-- begin-{marker} -->\n\n"
end_marker = f"\n\n<!-- end-{marker} -->\n"
start_index = content.index(start_marker) + len(start_marker)
end_index = content.index(end_marker)
content = content[:start_index] + replacement + content[end_index:]
return content
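# Usage sketch: given content containing "<!-- begin-tts-readme -->" and "<!-- end-tts-readme -->"
# markers, replace_between_markers(content, "tts-readme", new_text) swaps everything between the
# markers for new_text while keeping the markers themselves.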
def sync_readme():
ap = argparse.ArgumentParser()
ap.add_argument("--check", action="store_true", default=False)
args = ap.parse_args()
readme_path = Path(__file__).parent.parent / "README.md"
orig_content = readme_path.read_text()
from TTS.bin.synthesize import description
new_content = replace_between_markers(orig_content, "tts-readme", description.strip())
if args.check:
if orig_content != new_content:
print("README.md is out of sync; please edit TTS/bin/TTS_README.md and run scripts/sync_readme.py")
exit(42)
readme_path.write_text(new_content)
print("Updated README.md")
if __name__ == "__main__":
sync_readme()

View File

@ -0,0 +1,9 @@
import subprocess
import sys
from pathlib import Path
def test_readme_up_to_date():
root = Path(__file__).parent.parent.parent
sync_readme = root / "scripts" / "sync_readme.py"
subprocess.check_call([sys.executable, str(sync_readme), "--check"], cwd=root)

View File

@ -3,13 +3,19 @@ import glob
import os
import shutil
import torch
from tests import get_tests_data_path, get_tests_output_path, run_cli
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager
MODELS_WITH_SEP_TESTS = ["bark", "xtts"]
MODELS_WITH_SEP_TESTS = [
"tts_models/multilingual/multi-dataset/bark",
"tts_models/en/multi-dataset/tortoise-v2",
"tts_models/multilingual/multi-dataset/xtts_v1",
]
def run_models(offset=0, step=1):
@ -17,7 +23,8 @@ def run_models(offset=0, step=1):
print(" > Run synthesizer with all the models.")
output_path = os.path.join(get_tests_output_path(), "output.wav")
manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
model_names = [name for name in manager.list_models() if name in MODELS_WITH_SEP_TESTS]
model_names = [name for name in manager.list_models() if name not in MODELS_WITH_SEP_TESTS]
print("Model names:", model_names)
for model_name in model_names[offset::step]:
print(f"\n > Run - {model_name}")
model_path, _, _ = manager.download_model(model_name)
@ -67,23 +74,83 @@ def run_models(offset=0, step=1):
def test_xtts():
"""XTTS is too big to run on github actions. We need to test it locally"""
output_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
use_gpu = torch.cuda.is_available()
if use_gpu:
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
)
else:
run_cli(
"yes | "
f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
f'--speaker_wav "{speaker_wav}" --language_idx "en"'
)
def test_xtts_streaming():
"""Testing the new inference_stream method"""
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path)
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print("Computing speaker latents...")
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
print("Inference...")
chunks = model.inference_stream(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
)
    wav_chunks = []
    for i, chunk in enumerate(chunks):
        if i == 0:
            assert chunk.shape[-1] > 5000
        wav_chunks.append(chunk)
    assert len(wav_chunks) > 1
def test_tortoise():
output_path = os.path.join(get_tests_output_path(), "output.wav")
use_gpu = torch.cuda.is_available()
if use_gpu:
run_cli(
f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
)
else:
run_cli(
f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
)
def test_bark():
"""Bark is too big to run on github actions. We need to test it locally"""
output_path = os.path.join(get_tests_output_path(), "output.wav")
run_cli(
f" tts --model_name tts_models/multilingual/multi-dataset/bark "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
)
use_gpu = torch.cuda.is_available()
if use_gpu:
run_cli(
f" tts --model_name tts_models/multilingual/multi-dataset/bark "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
)
else:
run_cli(
f" tts --model_name tts_models/multilingual/multi-dataset/bark "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
)
def test_voice_conversion():