Merge pull request #3105 from coqui-ai/dev

v0.19.0
Eren Gölge 2023-10-25 14:27:49 +02:00 committed by GitHub
commit 9c68992ccc
15 changed files with 14010 additions and 292 deletions

.github/workflows/xtts_tests.yml

@ -0,0 +1,53 @@
name: xtts-tests
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize, reopened]
jobs:
check_skip:
runs-on: ubuntu-latest
if: "! contains(github.event.head_commit.message, '[ci skip]')"
steps:
- run: echo "${{ github.event.head_commit.message }}"
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.9, "3.10", "3.11"]
experimental: [false]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: x64
cache: 'pip'
cache-dependency-path: 'requirements*'
- name: check OS
run: cat /etc/os-release
- name: set ENV
run: export TRAINER_TELEMETRY=0
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends git make gcc
sudo apt-get install espeak
sudo apt-get install espeak-ng
make system-deps
- name: Install/upgrade Python setup deps
run: python3 -m pip install --upgrade pip setuptools wheel
- name: Replace scarf urls
run: |
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
- name: Install TTS
run: |
python3 -m pip install .[all]
python3 setup.py egg_info
- name: Unit tests
run: make test_xtts


@ -22,6 +22,9 @@ test_tts: ## run tts tests.
test_tts2: ## run tts tests.
nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2
test_xtts:
nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests
test_aux: ## run aux tests.
nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
./run_bash_tests.sh


@ -1 +1 @@
0.18.2
0.19.0


@ -197,6 +197,7 @@ class GPT(nn.Module):
if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPU
@ -233,6 +234,7 @@ class GPT(nn.Module):
prompt=None,
get_attns=False,
return_latent=False,
attn_mask_cond=None,
attn_mask_text=None,
attn_mask_mel=None,
):
@ -248,8 +250,11 @@ class GPT(nn.Module):
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
if attn_mask_cond is not None:
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
else:
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
gpt_out = self.gpt(
inputs_embeds=emb,
@ -326,7 +331,7 @@ class GPT(nn.Module):
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
return prompt
def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
def get_style_emb(self, cond_input, return_latent=False):
"""
cond_input: (b, 80, s) or (b, 1, 80, s)
conds: (b, 1024, s)
@ -335,26 +340,7 @@ class GPT(nn.Module):
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
if sample:
_len_secs = random.randint(2, 6) # in secs
cond_seg_len = int((22050 / 1024) * _len_secs) # in frames
if cond_input.shape[-1] >= cond_seg_len:
new_conds = []
for i in range(cond_input.shape[0]):
cond_len = int(cond_lens[i] / 1024)
if cond_len < cond_seg_len:
start = 0
else:
start = random.randint(0, cond_len - cond_seg_len)
cond_vec = cond_input[i, :, start : start + cond_seg_len]
new_conds.append(cond_vec)
conds = torch.stack(new_conds, dim=0)
else:
cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
cond_frame_len = int((22050 / 1024) * cond_seg_len)
conds = cond_input[:, :, -cond_frame_len:]
conds = self.conditioning_encoder(conds)
conds = self.conditioning_encoder(cond_input)
else:
# already computed
conds = cond_input.unsqueeze(1)
@ -366,10 +352,9 @@ class GPT(nn.Module):
text_lengths,
audio_codes,
wav_lengths,
cond_lens=None,
cond_mels=None,
cond_idxs=None,
cond_latents=None,
loss_weights=None,
return_attentions=False,
return_latent=False,
):
@ -377,11 +362,12 @@ class GPT(nn.Module):
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
cond_mels: MEL float tensor, (b, 1, 80,s)
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
cond_mels: MEL float tensor, (b, 1, 80,s)
cond_idxs: cond start and end indices, (b, 2)
If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@ -393,6 +379,11 @@ class GPT(nn.Module):
max_text_len = text_lengths.max()
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
if cond_idxs is not None:
# recompute cond idxs for mel lengths
for idx, l in enumerate(code_lengths):
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()
@ -435,9 +426,16 @@ class GPT(nn.Module):
)
# Set attn_mask
attn_mask_cond = None
attn_mask_text = None
attn_mask_mel = None
if not return_latent:
attn_mask_cond = torch.ones(
cond_mels.shape[0],
cond_mels.shape[-1],
dtype=torch.bool,
device=text_inputs.device,
)
attn_mask_text = torch.ones(
text_inputs.shape[0],
text_inputs.shape[1],
@ -451,6 +449,11 @@ class GPT(nn.Module):
device=audio_codes.device,
)
if cond_idxs is not None:
for idx, r in enumerate(cond_idxs):
l = r[1] - r[0]
attn_mask_cond[idx, l:] = 0.0
for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
@ -465,7 +468,7 @@ class GPT(nn.Module):
# Compute speech conditioning input
if cond_latents is None:
cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
# Get logits
sub = -5 # don't ask me why 😄
@ -480,6 +483,7 @@ class GPT(nn.Module):
prompt=cond_latents,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_cond=attn_mask_cond,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)
@ -501,6 +505,13 @@ class GPT(nn.Module):
0
], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
# ignore the loss for the segment used for conditioning
# coin flip for the segment to be ignored
if cond_idxs is not None:
cond_start = cond_idxs[idx, 0]
cond_end = cond_idxs[idx, 1]
mel_targets[idx, cond_start:cond_end] = -1
# Compute losses
loss_text = F.cross_entropy(
text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
@ -548,7 +559,7 @@ class GPT(nn.Module):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
max_length=self.max_mel_tokens,
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
@ -561,7 +572,7 @@ class GPT(nn.Module):
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
max_length=self.max_mel_tokens,
do_stream=True,
**hf_generate_kwargs,
)
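A note on the cond_idxs handling introduced above: the conditioning segment is excluded from the mel loss by writing the ignore index (-1) into mel_targets, which F.cross_entropy then skips via ignore_index=-1. A minimal sketch of that effect, with made-up shapes and a hypothetical cond_idxs pair (not the repository's code):

import torch
import torch.nn.functional as F

# Toy setup: 1 sample, 8 mel-code targets, vocabulary of 4 codes.
logits = torch.randn(1, 4, 8)          # (b, vocab, t)
targets = torch.randint(0, 4, (1, 8))  # (b, t)

# Pretend frames 2..5 were the conditioning slice (hypothetical cond_idxs = [2, 5],
# already rescaled to code frames).
cond_start, cond_end = 2, 5
targets[0, cond_start:cond_end] = -1   # mark the conditioning span with the ignore index

# ignore_index=-1 drops the masked frames from the loss, so the model is never
# penalized on the segment it was conditioned on.
loss = F.cross_entropy(logits, targets, ignore_index=-1)
print(loss)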


@ -1,13 +1,12 @@
import torch
import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module):
self.cond_in_each_up_layer = cond_in_each_up_layer
# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module):
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)
def remove_weight_norm(self):
@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module):
assert not self.training
self.remove_weight_norm()
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@ -425,10 +417,8 @@ class PreEmphasis(nn.Module):
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
"""This is copied from 🐸TTS to remove it from the dependencies."""
# pylint: disable=W0102
def __init__(
@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module):
return criterion, state["step"]
return criterion
class HifiDecoder(torch.nn.Module):
def __init__(
self,
@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module):
"""
return self.forward(c, g=g)
def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]


@ -1,26 +1,27 @@
# Adapted from: https://github.com/LowinLi/transformers-stream-generator
import copy
import inspect
import random
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from transformers import (
BeamSearchScorer,
ConstrainedBeamSearchScorer,
DisjunctiveConstraint,
GenerationConfig,
GenerationMixin,
LogitsProcessorList,
StoppingCriteriaList,
DisjunctiveConstraint,
BeamSearchScorer,
PhrasalConstraint,
ConstrainedBeamSearchScorer,
PreTrainedModel,
StoppingCriteriaList,
)
import numpy as np
import random
import warnings
import inspect
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
import torch
from typing import Callable, List, Optional, Union
from torch import nn
import torch.distributed as dist
import copy
def setup_seed(seed):
@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin):
generation_config: Optional[StreamGenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = False,
seed=0,
**kwargs,
@ -125,7 +124,7 @@ class NewGenerationMixin(GenerationMixin):
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
#setup_seed(seed)
# setup_seed(seed)
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin):
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = StreamGenerationConfig.from_model_config(
self.config
)
new_generation_config = StreamGenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin):
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(
**kwargs
) # All unused kwargs must be model kwargs
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if (
generation_config.pad_token_id is None
and generation_config.eos_token_id is not None
):
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin):
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin):
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(
inspect.signature(self.forward).parameters.keys()
)
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if (
model_kwargs.get("attention_mask", None) is None
and requires_attention_mask
and accepts_attention_mask
):
model_kwargs[
"attention_mask"
] = self._prepare_attention_mask_for_generation(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin):
if not self.config.is_encoder_decoder:
if (
generation_config.pad_token_id is not None
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
> 0
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin):
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin):
UserWarning,
)
elif has_default_max_length and generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
elif (
not has_default_max_length and generation_config.max_new_tokens is not None
):
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin):
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
if (
generation_config.min_length is not None
and generation_config.min_length > generation_config.max_length
):
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = (
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
)
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin):
# 7. determine generation mode
is_constraint_gen_mode = (
generation_config.constraints is not None
or generation_config.force_words_ids is not None
generation_config.constraints is not None or generation_config.force_words_ids is not None
)
is_contrastive_search_gen_mode = (
@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin):
)
if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError(
"`num_beam_groups` has to be smaller or equal to `num_beams`"
)
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and generation_config.do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin):
)
elif is_beam_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin):
logits_warper = self._get_logits_warper(generation_config)
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size * generation_config.num_return_sequences,
@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin):
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams
* generation_config.num_return_sequences,
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin):
elif is_group_beam_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if generation_config.num_beams % generation_config.num_beam_groups != 0:
raise ValueError(
"`num_beams` should be divisible by `num_beam_groups` for group beam search."
)
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
has_default_typical_p = (
kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
)
has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
if not has_default_typical_p:
raise ValueError(
"Decoder argument `typical_p` is not supported with beam groups."
)
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin):
elif is_constraint_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if generation_config.num_beams <= 1:
raise ValueError(
"`num_beams` needs to be greater than 1 for constrained generation."
)
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
if generation_config.do_sample:
raise ValueError(
"`do_sample` needs to be false for constrained generation."
)
raise ValueError("`do_sample` needs to be false for constrained generation.")
if (
generation_config.num_beam_groups is not None
and generation_config.num_beam_groups > 1
):
raise ValueError(
"`num_beam_groups` not supported yet for constrained generation."
)
if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
final_constraints = []
if generation_config.constraints is not None:
@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin):
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(
not isinstance(token_ids, list) for token_ids in word_ids
):
if any(not isinstance(token_ids, list) for token_ids in word_ids):
typeerror()
if any(
any(
(not isinstance(token_id, int) or token_id < 0)
for token_id in token_ids
)
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
):
typeerror()
@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin):
else:
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(
(not isinstance(token_id, int) or token_id < 0)
for token_id in word_ids
):
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
typeerror()
constraint = PhrasalConstraint(word_ids)
@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin):
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
```"""
# init values
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(
stopping_criteria, max_length
)
logits_warper = (
logits_warper if logits_warper is not None else LogitsProcessorList()
)
pad_token_id = (
pad_token_id
if pad_token_id is not None
else self.generation_config.pad_token_id
)
eos_token_id = (
eos_token_id
if eos_token_id is not None
else self.generation_config.eos_token_id
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = (
output_scores
if output_scores is not None
else self.generation_config.output_scores
)
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions
if output_attentions is not None
else self.generation_config.output_attentions
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.generation_config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate
@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin):
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin):
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(
0.0 if this_peer_finished else 1.0
).to(input_ids.device)
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin):
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (outputs.attentions,)
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
(outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
)
# sample
@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin):
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin):
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul(
(sum(next_tokens != i for i in eos_token_id)).long()
)
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@ -1007,22 +885,17 @@ def init_stream_support():
if __name__ == "__main__":
from transformers import PreTrainedModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
PreTrainedModel.generate = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
model = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-560m", torch_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = model.to("cuda:0")
model = model.eval()
prompt_text = "hello? \n"
input_ids = tokenizer(
prompt_text, return_tensors="pt", add_special_tokens=False
).input_ids
input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = input_ids.to("cuda:0")
with torch.no_grad():


@ -0,0 +1,242 @@
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
torch.set_num_threads(1)
def key_samples_by_col(samples, col):
"""Returns a dictionary of samples keyed by language."""
samples_by_col = {}
for sample in samples:
col_val = sample[col]
assert isinstance(col_val, str)
if col_val not in samples_by_col:
samples_by_col[col_val] = []
samples_by_col[col_val].append(sample)
return samples_by_col
def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
rel_clip = load_audio(gt_path, sample_rate)
# in eval, use a middle-sized sample whenever possible, to be more reproducible
if is_eval:
sample_length = int((min_sample_length + max_sample_length) / 2)
else:
sample_length = random.randint(min_sample_length, max_sample_length)
gap = rel_clip.shape[-1] - sample_length
if gap < 0:
sample_length = rel_clip.shape[-1] // 2
gap = rel_clip.shape[-1] - sample_length
# in eval, always start from position 0 to be more reproducible
if is_eval:
rand_start = 0
else:
rand_start = random.randint(0, gap)
rand_end = rand_start + sample_length
rel_clip = rel_clip[:, rand_start:rand_end]
rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
cond_idxs = [rand_start, rand_end]
return rel_clip, rel_clip.shape[-1], cond_idxs
def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == ".mp3":
# use torchaudio with the sox backend to load mp3 files
audio, lsr = torchaudio_sox_load(audiopath)
else:
# use the torchaudio soundfile backend to load all other data types
audio, lsr = torchaudio_soundfile_load(audiopath)
# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)
if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip audio invalid values
audio.clip_(-1, 1)
return audio
class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
model_args = config.model_args
self.failed_samples = set()
self.debug_failures = model_args.debug_loading_failures
self.max_conditioning_length = model_args.max_conditioning_length
self.min_conditioning_length = model_args.min_conditioning_length
self.is_eval = is_eval
self.tokenizer = tokenizer
self.sample_rate = sample_rate
self.max_wav_len = model_args.max_wav_length
self.max_text_len = model_args.max_text_length
assert self.max_wav_len is not None and self.max_text_len is not None
self.samples = samples
if not is_eval:
random.seed(config.training_seed)
# random.shuffle(self.samples)
random.shuffle(self.samples)
# order by language
self.samples = key_samples_by_col(self.samples, "language")
print(" > Sampling by language:", self.samples.keys())
else:
# for evaluation, load and check the samples up front and drop corrupted ones to ensure reproducibility
self.check_eval_samples()
def check_eval_samples(self):
print(" > Filtering invalid eval samples!!")
new_samples = []
for sample in self.samples:
try:
tseq, _, wav, _, _, _ = self.load_item(sample)
except:
pass
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
if (
wav is None
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
):
continue
new_samples.append(sample)
self.samples = new_samples
print(" > Total eval samples after filtering:", len(self.samples))
def get_text(self, text, lang):
tokens = self.tokenizer.encode(text, lang)
tokens = torch.IntTensor(tokens)
assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}"
# The stop token should always be sacred.
assert not torch.any(tokens == 0), f"Stop token found in {text}"
return tokens
def load_item(self, sample):
text = str(sample["text"])
tseq = self.get_text(text, sample["language"])
audiopath = sample["audio_file"]
wav = load_audio(audiopath, self.sample_rate)
if text is None or len(text.strip()) == 0:
raise ValueError
if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
# Ultra short clips are also useless (and can cause problems within some models).
raise ValueError
# get a slice from GT to condition the model
cond, cond_len, cond_idxs = get_prompt_slice(
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
)
return tseq, audiopath, wav, cond, cond_len, cond_idxs
def __getitem__(self, index):
if self.is_eval:
sample = self.samples[index]
sample_id = str(index)
else:
# select a random language
lang = random.choice(list(self.samples.keys()))
# select random sample
index = random.randint(0, len(self.samples[lang]) - 1)
sample = self.samples[lang][index]
# a unique id for each sample to deal with load failures
sample_id = lang + "_" + str(index)
# ignore samples that we already know are not valid
if sample_id in self.failed_samples:
if self.debug_failures:
print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
# call __getitem__ again to get another sample
return self[1]
# try to load the sample; if it fails, add it to the failed samples list
try:
tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
except:
if self.debug_failures:
print(f"error loading {sample['audio_file']} {sys.exc_info()}")
self.failed_samples.add(sample_id)
return self[1]
# check the audio and text size limits; if out of bounds, add the sample to failed_samples
if (
wav is None
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid sample and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None:
print(
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
)
self.failed_samples.add(sample_id)
return self[1]
res = {
# 'real_text': text,
"text": tseq,
"text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long),
"wav": wav,
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
"filenames": audiopath,
"conditioning": cond.unsqueeze(1),
"cond_lens": torch.tensor(cond_len, dtype=torch.long),
"cond_idxs": torch.tensor(cond_idxs),
}
return res
def __len__(self):
if self.is_eval:
return len(self.samples)
return sum([len(v) for v in self.samples.values()])
def collate_fn(self, batch):
# convert list of dicts to dict of lists
B = len(batch)
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
# stack for features that already have the same shape
batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
batch["text_lengths"] = torch.stack(batch["text_lengths"])
batch["conditioning"] = torch.stack(batch["conditioning"])
batch["cond_lens"] = torch.stack(batch["cond_lens"])
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
max_text_len = batch["text_lengths"].max()
max_wav_len = batch["wav_lengths"].max()
# create padding tensors
text_padded = torch.IntTensor(B, max_text_len)
wav_padded = torch.FloatTensor(B, 1, max_wav_len)
# initialize tensors for zero padding
text_padded = text_padded.zero_()
wav_padded = wav_padded.zero_()
for i in range(B):
text = batch["text"][i]
text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
wav = batch["wav"][i]
wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
batch["wav"] = wav_padded
batch["padded_text"] = text_padded
return batch
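For reference, a dry run of the eval branch of get_prompt_slice above, with a zero tensor standing in for the decoded audio and an arbitrary 1-2 second prompt window (the real bounds come from min_conditioning_length / max_conditioning_length in the model args):

import torch
import torch.nn.functional as F

sample_rate = 22050
rel_clip = torch.zeros(1, 3 * sample_rate)            # stand-in for a 3 s mono clip

min_len, max_len = sample_rate, 2 * sample_rate        # hypothetical 1-2 s window, in samples
sample_length = (min_len + max_len) // 2               # eval: fixed middle-sized slice
rand_start = 0                                          # eval: always start at position 0
rand_end = rand_start + sample_length

prompt = rel_clip[:, rand_start:rand_end]
prompt = F.pad(prompt, pad=(0, max_len - prompt.shape[-1]))  # right-pad to max_len
cond_idxs = [rand_start, rand_end]                      # later rescaled by code_stride_len in GPT.forward

print(prompt.shape, cond_idxs)                          # torch.Size([1, 44100]) [0, 33075]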


@ -0,0 +1,473 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
import torchaudio
from coqpit import Coqpit
from torch.nn import functional as F
from torch.utils.data import DataLoader
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec
@dataclass
class GPTTrainerConfig(XttsConfig):
lr: float = 5e-06
training_seed: int = 1
optimizer_wd_only_on_weights: bool = False
weighted_loss_attrs: dict = field(default_factory=lambda: {})
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
test_sentences: List[dict] = field(default_factory=lambda: [])
@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150
max_conditioning_length: int = 132300
gpt_loss_text_ce_weight: float = 0.01
gpt_loss_mel_ce_weight: float = 1.0
gpt_num_audio_tokens: int = 8194
debug_loading_failures: bool = False
max_wav_length: int = 255995 # ~11.6 seconds
max_text_length: int = 200
tokenizer_file: str = ""
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
dvae_checkpoint: str = ""
xtts_checkpoint: str = ""
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ("load", "save")
# print(operation_type, model_info.__dict__)
if "similarities.pth" in model_info.__dict__["local_model_path"]:
return None
return model_info
class GPTTrainer(BaseTTS):
def __init__(self, config: Coqpit):
"""
Tortoise GPT training class
"""
super().__init__(config, ap=None, tokenizer=None)
self.config = config
# init XTTS model
self.xtts = Xtts(self.config)
# create the tokenizer with the target vocabulary
self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
# init gpt encoder and hifigan decoder
self.xtts.init_models()
if self.args.xtts_checkpoint:
self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)
# set mel stats
if self.args.mel_norm_file:
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
# load GPT if available
if self.args.gpt_checkpoint:
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!")
gpt_checkpoint = gpt_checkpoint["model"]
states_keys = list(gpt_checkpoint.keys())
for key in states_keys:
if "gpt." in key:
new_key = key.replace("gpt.", "")
gpt_checkpoint[new_key] = gpt_checkpoint[key]
del gpt_checkpoint[key]
else:
del gpt_checkpoint[key]
# edit the checkpoint if the number of tokens has changed, to ensure the best possible transfer learning
if (
"text_embedding.weight" in gpt_checkpoint
and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
):
num_new_tokens = (
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
)
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
# add new rows to the embedding layer (text_embedding)
emb_g = gpt_checkpoint["text_embedding.weight"]
new_row = torch.randn(num_new_tokens, emb_g.shape[1])
start_token_row = emb_g[-1, :]
emb_g = torch.cat([emb_g, new_row], axis=0)
emb_g[-1, :] = start_token_row
gpt_checkpoint["text_embedding.weight"] = emb_g
# add new weights to the linear layer (text_head)
text_head_weight = gpt_checkpoint["text_head.weight"]
start_token_row = text_head_weight[-1, :]
new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
text_head_weight[-1, :] = start_token_row
gpt_checkpoint["text_head.weight"] = text_head_weight
# add new biases to the linear layer (text_head)
text_head_bias = gpt_checkpoint["text_head.bias"]
start_token_row = text_head_bias[-1]
new_bias_entry = torch.zeros(num_new_tokens)
text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
text_head_bias[-1] = start_token_row
gpt_checkpoint["text_head.bias"] = text_head_bias
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
# Mel spectrogram extractor for conditioning
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
filter_length=4096,
hop_length=1024,
win_length=4096,
normalize=False,
sampling_rate=config.audio.sample_rate,
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file,
)
# Load DVAE
self.dvae = DiscreteVAE(
channels=80,
normalization=None,
positional_dims=1,
num_tokens=self.args.gpt_num_audio_tokens - 2,
codebook_dim=512,
hidden_dim=512,
num_resnet_blocks=3,
kernel_size=3,
num_layers=2,
use_transposed_convs=False,
)
self.dvae.eval()
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
else:
raise RuntimeError(
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
)
# Mel spectrogram extractor for DVAE
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
)
@property
def device(self):
return next(self.parameters()).device
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
cond_idxs: cond start and end indices, (b, 2)
"""
losses = self.xtts.gpt(
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
)
return losses
@torch.no_grad()
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
if self.config.test_sentences:
# init gpt for inference mode
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.xtts.gpt.eval()
test_audios = {}
print(" | > Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(
s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
)["wav"]
test_audios["{}-audio".format(idx)] = wav
# delete inference layers
del self.xtts.gpt.gpt_inference
del self.xtts.gpt.gpt.wte
return {"audios": test_audios}
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)
def format_batch(self, batch: Dict) -> Dict:
return batch
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
def format_batch_on_device(self, batch):
"""Compute spectrograms on the device."""
batch["text_lengths"] = batch["text_lengths"]
batch["wav_lengths"] = batch["wav_lengths"]
batch["text_inputs"] = batch["padded_text"]
batch["cond_idxs"] = batch["cond_idxs"]
# compute conditioning mel specs
# transform waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating over the tensor
B, num_cond_samples, C, T = batch["conditioning"].size()
conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
# transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel])
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
T_mel = paired_conditioning_mel.size(2)
paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
# get the conditioning embeddings
batch["cond_mels"] = paired_conditioning_mel
# compute codes using DVAE
if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate:
dvae_wav = torchaudio.functional.resample(
batch["wav"],
orig_freq=self.config.audio.sample_rate,
new_freq=self.config.audio.dvae_sample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
else:
dvae_wav = batch["wav"]
dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
codes = self.dvae.get_codebook_indices(dvae_mel_spec)
batch["audio_codes"] = codes
# delete useless batch tensors
del batch["padded_text"]
del batch["wav"]
del batch["conditioning"]
del batch["cond_lens"]
return batch
def train_step(self, batch, criterion):
loss_dict = {}
cond_mels = batch["cond_mels"]
text_inputs = batch["text_inputs"]
text_lengths = batch["text_lengths"]
audio_codes = batch["audio_codes"]
wav_lengths = batch["wav_lengths"]
cond_idxs = batch["cond_idxs"]
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
return {"model_outputs": None}, loss_dict
def eval_step(self, batch, criterion):
# ignore masking for more consistent evaluation
batch["cond_idxs"] = None
return self.train_step(batch, criterion)
def on_epoch_start(self, trainer): # pylint: disable=W0613
# guarantee that the DVAE will be in eval mode after .train() is called at evaluation end
self.dvae = self.dvae.eval()
def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload
if self.config.dashboard_logger.lower() == "clearml":
from clearml.binding.frameworks import WeightsFileHandler
WeightsFileHandler.add_pre_callback(callback_clearml_load_save)
@torch.no_grad()
def inference(
self,
x,
aux_input=None,
): # pylint: disable=dangerous-default-value
return None
@staticmethod
def get_criterion():
return None
def get_sampler(self, dataset: TTSDataset, num_gpus=1):
# sampler for DDP
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
return batch_sampler
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
is_eval: bool,
samples: Union[List[Dict], List[List]],
verbose: bool,
num_gpus: int,
rank: int = None,
) -> "DataLoader": # pylint: disable=W0613
if is_eval and not config.run_eval:
loader = None
else:
# init dataloader
dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)
# wait all the DDP process to be ready
if num_gpus > 1:
torch.distributed.barrier()
# sort input sequences from short to long
# dataset.preprocess_samples()
# get samplers
sampler = self.get_sampler(dataset, num_gpus)
# ignore the sampler in eval; if the sampler parameters changed we would not be able to compare with previous runs
if sampler is None or is_eval:
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False,
drop_last=False,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_optimizer(self) -> List:
"""Initiate and return the optimizer based on the config parameters."""
# ToDo: deal with multi GPU training
if self.config.optimizer_wd_only_on_weights:
# parameters to only GPT model
net = self.xtts.gpt
# normalizations
norm_modules = (
nn.BatchNorm2d,
nn.InstanceNorm2d,
nn.BatchNorm1d,
nn.InstanceNorm1d,
nn.BatchNorm3d,
nn.InstanceNorm3d,
nn.GroupNorm,
nn.LayerNorm,
)
# nn.Embedding
emb_modules = (nn.Embedding, nn.EmbeddingBag)
param_names_notweights = set()
all_param_names = set()
param_map = {}
for mn, m in net.named_modules():
for k, v in m.named_parameters():
v.is_bias = k.endswith(".bias")
v.is_weight = k.endswith(".weight")
v.is_norm = isinstance(m, norm_modules)
v.is_emb = isinstance(m, emb_modules)
fpn = "%s.%s" % (mn, k) if mn else k # full param name
all_param_names.add(fpn)
param_map[fpn] = v
if v.is_bias or v.is_norm or v.is_emb:
param_names_notweights.add(fpn)
params_names_notweights = sorted(list(param_names_notweights))
params_notweights = [param_map[k] for k in params_names_notweights]
params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
params_weights = [param_map[k] for k in params_names_weights]
groups = [
{"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
{"params": params_notweights, "weight_decay": 0},
]
# torch.optim.AdamW
opt = get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
parameters=groups,
)
opt._group_names = [params_names_weights, params_names_notweights]
return opt
return get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.xtts.gpt.parameters(),
)
def get_scheduler(self, optimizer) -> List:
"""Set the scheduler for the optimizer.
Args:
optimizer: `torch.optim.Optimizer`.
"""
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
def load_checkpoint(
self,
config,
checkpoint_path,
eval=False,
strict=True,
cache_storage="/tmp/tts_cache",
target_protocol="s3",
target_options={"anon": True},
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
# load the model weights
self.xtts.load_state_dict(state, strict=strict)
if eval:
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
self.eval()
assert not self.training
@staticmethod
def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (GPTTrainerConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
return GPTTrainer(config)
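The optimizer_wd_only_on_weights branch of get_optimizer above splits parameters so that biases, normalization and embedding parameters receive no weight decay. A compact, self-contained sketch of the same idea on a toy module (the module and the decay value are placeholders, not values from this trainer):

import torch
from torch import nn

net = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))

norm_modules = (nn.LayerNorm, nn.BatchNorm1d, nn.GroupNorm)
emb_modules = (nn.Embedding, nn.EmbeddingBag)

decay, no_decay = [], []
for _, module in net.named_modules():
    for param_name, param in module.named_parameters(recurse=False):
        if param_name.endswith("bias") or isinstance(module, norm_modules + emb_modules):
            no_decay.append(param)   # biases, norms and embeddings: no decay
        else:
            decay.append(param)      # everything else keeps the configured decay

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 1e-2},   # placeholder decay value
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=5e-6,
)
print(len(decay), len(no_decay))   # 1 decayed weight, 4 decay-free parameters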


@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
init_stream_support()
def load_audio(audiopath, sr=22050):
"""
Load an audio file from disk and resample it to the specified sampling rate.
@ -332,7 +333,6 @@ class Xtts(BaseTTS):
stop_audio_token=self.args.gpt_stop_audio_token,
)
if self.args.use_hifigan:
self.hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
@ -387,7 +387,7 @@ class Xtts(BaseTTS):
audio = load_audio(audio_path)
audio = audio[:, : 22050 * length]
mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)
@torch.inference_mode()
@ -414,21 +414,20 @@ class Xtts(BaseTTS):
return diffusion_latent
@torch.inference_mode()
def get_speaker_embedding(
self,
audio_path
):
def get_speaker_embedding(self, audio_path):
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
audio.to(self.device), l2_norm=True
).unsqueeze(-1).to(self.device)
speaker_embedding = (
self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
.unsqueeze(-1)
.to(self.device)
)
return speaker_embedding
def get_conditioning_latents(
self,
audio_path,
gpt_cond_len=3,
):
):
speaker_embedding = None
diffusion_cond_latents = None
if self.args.use_hifigan:
@ -563,11 +562,9 @@ class Xtts(BaseTTS):
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz.
"""
(
gpt_cond_latent,
diffusion_conditioning,
speaker_embedding
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
)
return self.inference(
text,
language,
@ -588,7 +585,7 @@ class Xtts(BaseTTS):
decoder=decoder,
**hf_generate_kwargs,
)
@torch.inference_mode()
def inference(
self,
@ -646,6 +643,7 @@ class Xtts(BaseTTS):
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = self.gpt(
text_tokens,
@ -666,7 +664,7 @@ class Xtts(BaseTTS):
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
if decoder == "hifigan":
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
@ -721,7 +719,9 @@ class Xtts(BaseTTS):
decoder="hifigan",
**hf_generate_kwargs,
):
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
assert hasattr(
self, "hifigan_decoder"
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@ -775,10 +775,10 @@ class Xtts(BaseTTS):
yield wav_chunk
def forward(self):
raise NotImplementedError("XTTS Training is not implemented")
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
def eval_step(self):
raise NotImplementedError("XTTS Training is not implemented")
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
@staticmethod
def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument
@ -789,11 +789,32 @@ class Xtts(BaseTTS):
self.gpt.init_gpt_for_inference()
super().eval()
def get_compatible_checkpoint_state_dict(self, model_path):
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
# remove xtts gpt trainer extra keys
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer; if so, convert it
if key.startswith("xtts."):
new_key = key.replace("xtts.", "")
checkpoint[new_key] = checkpoint[key]
del checkpoint[key]
key = new_key
# remove unused keys
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
return checkpoint
def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
checkpoint_path=None,
vocab_path=None,
eval=True,
strict=True,
@ -822,13 +843,7 @@ class Xtts(BaseTTS):
self.init_models()
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
for key in list(checkpoint.keys()):
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
checkpoint = self.get_compatible_checkpoint_state_dict(model_path)
# deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 does not
try:
@ -847,4 +862,4 @@ class Xtts(BaseTTS):
self.gpt.eval()
def train_step(self):
raise NotImplementedError("XTTS Training is not implemented")
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")

View File

@ -16,8 +16,8 @@ a few tricks to make it faster and support streaming inference.
Current implementation only supports inference.
### Languages
As of now, XTTS-v1 supports 13 languages: English, Spanish, French, German, Italian, Portuguese,
Polish, Turkish, Russian, Dutch, Czech, Arabic, and Chinese (Simplified).
As of now, XTTS-v1.1 supports 14 languages: English, Spanish, French, German, Italian, Portuguese,
Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified) and Japanese.
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
@ -33,7 +33,7 @@ You can also mail us at info@coqui.ai.
```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)
# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
@ -45,7 +45,7 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
#### 🐸TTS Command line
```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 \
--text "Bugün okula gitmek istemiyorum." \
--speaker_wav /path/to/target/speaker.wav \
--language_idx tr \
@ -134,6 +134,56 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
### Training
A recipe for `XTTS_v1.1` GPT encoder training using the `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
You need to change the `BaseDatasetConfig` fields to match your dataset and then update the `GPTArgs` and `GPTTrainerConfig` fields as needed (a sketch follows below). By default, the recipe uses the same parameters the XTTS v1.1 model was trained with and, to speed up convergence, it downloads and loads the XTTS v1.1 checkpoint.
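For illustration, below is a minimal sketch of the kind of overrides involved; the dataset name, paths, and values are placeholders rather than part of the recipe, and anything omitted (such as the DVAE, mel-stats, tokenizer, and checkpoint paths the recipe wires in) is assumed to keep the recipe's defaults.
```python
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainerConfig

# Point the formatter and paths at your own data (placeholder values).
config_dataset = BaseDatasetConfig(
    formatter="ljspeech",        # or another formatter supported by load_tts_samples
    dataset_name="my_dataset",   # placeholder
    path="/data/my_dataset/",    # placeholder
    meta_file_train="metadata.csv",
    language="en",
)

# Override only the GPT/trainer fields that differ from the XTTS v1.1 defaults.
model_args = GPTArgs(
    max_wav_length=255995,  # ~11.6 seconds at 22050 Hz
    max_text_length=200,
)
config = GPTTrainerConfig(
    model_args=model_args,
    batch_size=3,
    lr=5e-06,
)
```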
After training, you can run inference following the code below.
```python
import os
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# Add here the xtts_config path
CONFIG_PATH = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT-October-23-2023_10+36AM-653f2e75/config.json"
# Add here the vocab file that you have used to train the model
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v1.1_original_model_files/vocab.json"
# Add here the checkpoint that you want to do inference with
XTTS_CHECKPOINT = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT/best_model.pth"
# Add here the speaker reference
SPEAKER_REFERENCE = "LjSpeech_reference.wav"
# output wav path
OUTPUT_WAV_PATH = "xtts-ft.wav"
print("Loading model...")
config = XttsConfig()
config.load_json(CONFIG_PATH)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENIZER_PATH, use_deepspeed=False)
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)
print("Inference...")
out = model.inference(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding,
diffusion_conditioning,
temperature=0.7, # Add custom parameters here
)
torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
```
## Important resources & papers
- VallE: https://arxiv.org/abs/2301.02111
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts

View File

@ -0,0 +1,178 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
# Logging parameters
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
# Set here the path where the checkpoints will be saved. Default: ./run/training/
OUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run", "training")
# Training Parameters
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # set to False for multi-GPU training
START_WITH_EVAL = True  # if True it will start with evaluation
BATCH_SIZE = 3  # set here the batch size
GRAD_ACUMM_STEPS = 84  # set here the grad accumulation steps
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE, but then set GRAD_ACUMM_STEPS accordingly.
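# For example, with BATCH_SIZE = 3 and GRAD_ACUMM_STEPS = 84 above, the effective batch size is 3 * 84 = 252.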
# Define here the dataset that you want to use for fine-tuning.
config_dataset = BaseDatasetConfig(
formatter="ljspeech",
dataset_name="ljspeech",
path="/raid/datasets/LJSpeech-1.1_24khz/",
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
language="en",
)
# Add here the configs of the datasets
DATASETS_CONFIG_LIST = [config_dataset]
# Define the path where XTTS v1.1.1 files will be downloaded
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v1.1_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
# DVAE files
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
# Set the path to the downloaded files
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, MEL_NORM_LINK.split("/")[-1])
# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
print(" > Downloading DVAE files!")
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
# Download XTTS v1.1 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
# XTTS transfer learning parameters: you need to provide the paths of the XTTS checkpoint that you want to fine-tune.
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file
# download XTTS v1.1 files if needed
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
print(" > Downloading XTTS v1.1 files!")
ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
# Training sentence generation
SPEAKER_REFERENCE = (
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
)
LANGUAGE = config_dataset.language
def main():
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=False,
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
mel_norm_file=MEL_NORM_FILE,
dvae_checkpoint=DVAE_CHECKPOINT,
# tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
# xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
tokenizer_file=TOKENIZER_FILE,
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
use_ne_hifigan=True,  # if True, it will keep the non-enhanced keys in the output checkpoint
)
# define audio config
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
)
# training parameters config
config = GPTTrainerConfig(
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=1000,
save_step=10000,
save_n_checkpoints=1,
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values as in Tortoise; PyTorch implementation modified to not apply WD to non-weight parameters.
optimizer="AdamW",
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
lr_scheduler="MultiStepLR",
# it was adjusted accordingly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"speaker_wav": SPEAKER_REFERENCE,
"language": LANGUAGE,
},
{
"text": "This cake is great. It's so delicious and moist.",
"speaker_wav": SPEAKER_REFERENCE,
"language": LANGUAGE,
},
],
)
# init the model from config
model = GPTTrainer.init_from_config(config)
# load training samples
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(
restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it with the Trainer restore_path parameter
skip_train_epoch=False,
start_with_eval=START_WITH_EVAL,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
trainer.fit()
if __name__ == "__main__":
main()

12669
tests/inputs/xtts_vocab.json Normal file

File diff suppressed because it is too large

View File

@ -28,7 +28,7 @@ tokenizer, config = TTSTokenizer.init_from_config(config)
def test_acoustic_model():
dummy_tokens = torch.rand((1, 41)).long().to(device)
dummy_text_lens = torch.tensor([41]).to(device)
dummy_text_lens = torch.tensor([41]).long().to(device)
dummy_spec = torch.rand((1, 100, 207)).to(device)
dummy_spec_lens = torch.tensor([207]).to(device)
dummy_pitch = torch.rand((1, 1, 207)).long().to(device)
@ -38,6 +38,7 @@ def test_acoustic_model():
args.num_mels = 100
acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device)
acoustic_model = acoustic_model.train()
output = acoustic_model(
tokens=dummy_tokens,
@ -51,16 +52,12 @@ def test_acoustic_model():
speaker_idx=None,
)
assert list(output["model_outputs"].shape) == [1, 207, 100]
output["model_outputs"].sum().backward()
# output["model_outputs"].sum().backward()
def test_hifi_decoder():
dummy_input = torch.rand((1, 207, 100)).to(device)
dummy_text_lens = torch.tensor([41]).to(device)
dummy_spec = torch.rand((1, 100, 207)).to(device)
dummy_spec_lens = torch.tensor([207]).to(device)
dummy_pitch = torch.rand((1, 1, 207)).long().to(device)
dummy_energy = torch.rand((1, 1, 207)).long().to(device)
waveform_decoder = HifiganGenerator(
100,
@ -77,6 +74,7 @@ def test_hifi_decoder():
conv_post_weight_norm=False,
conv_post_bias=False,
).to(device)
waveform_decoder = waveform_decoder.train()
vocoder_input_slices, slice_ids = rand_segments( # pylint: disable=unused-variable
x=dummy_input.transpose(1, 2),
@ -88,4 +86,4 @@ def test_hifi_decoder():
outputs = waveform_decoder(x=vocoder_input_slices.detach())
assert list(outputs.shape) == [1, 1, 8192]
outputs.sum().backward()
# outputs.sum().backward()

View File

@ -0,0 +1,163 @@
import os
import shutil
import torch
from trainer import Trainer, TrainerArgs
from tests import get_tests_output_path
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
config_dataset = BaseDatasetConfig(
formatter="ljspeech",
dataset_name="ljspeech",
path="tests/data/ljspeech/",
meta_file_train="metadata.csv",
meta_file_val="metadata.csv",
language="en",
)
DATASETS_CONFIG_LIST = [config_dataset]
# Logging parameters
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
PROJECT_NAME = "XTTS_trainer"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
# Set here the path where the checkpoints will be saved. Default: ./run/training/
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
os.makedirs(OUT_PATH, exist_ok=True)
# Create DVAE checkpoint and mel_norms on test time
# DVAE parameters: training needs the DVAE to extract the DVAE tokens, so you must provide the paths to this model
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
MEL_NORM_FILE = os.path.join(
OUT_PATH, "mel_stats.pth"
) # Mel spectrogram norms, required for dvae mel spectrogram extraction
dvae = DiscreteVAE(
channels=80,
normalization=None,
positional_dims=1,
num_tokens=8192,
codebook_dim=512,
hidden_dim=512,
num_resnet_blocks=3,
kernel_size=3,
num_layers=2,
use_transposed_convs=False,
)
torch.save(dvae.state_dict(), DVAE_CHECKPOINT)
mel_stats = torch.ones(80)
torch.save(mel_stats, MEL_NORM_FILE)
# XTTS transfer learning parameters: you need to provide the paths of the XTTS checkpoint that you want to fine-tune.
TOKENIZER_FILE = "tests/inputs/xtts_vocab.json" # vocab.json file
XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth" # model.pth file
# Training sentence generation
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
LANGUAGE = config_dataset.language
# Training Parameters
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # set to False for multi-GPU training
START_WITH_EVAL = False  # if True it will start with evaluation
BATCH_SIZE = 2  # set here the batch size
GRAD_ACUMM_STEPS = 1  # set here the grad accumulation steps
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE, but then set GRAD_ACUMM_STEPS accordingly.
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=False,
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
mel_norm_file=MEL_NORM_FILE,
dvae_checkpoint=DVAE_CHECKPOINT,
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
tokenizer_file=TOKENIZER_FILE,
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
)
config = GPTTrainerConfig(
epochs=1,
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=1000,
save_step=10000,
save_n_checkpoints=1,
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values as in Tortoise; PyTorch implementation modified to not apply WD to non-weight parameters.
optimizer="AdamW",
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
lr_scheduler="MultiStepLR",
# it was adjusted accordingly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{
"text": "This cake is great. It's so delicious and moist.",
"speaker_wav": SPEAKER_REFERENCE,
"language": LANGUAGE,
},
],
)
# init the model from config
model = GPTTrainer.init_from_config(config)
# load training samples
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(
restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it with the Trainer restore_path parameter
skip_train_epoch=False,
start_with_eval=True,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
trainer.fit()
# remove output path
shutil.rmtree(OUT_PATH)

View File

@ -99,6 +99,7 @@ def test_xtts_streaming():
"""Testing the new inference_stream method"""
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
@ -115,7 +116,7 @@ def test_xtts_streaming():
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
speaker_embedding,
)
wav_chuncks = []
for i, chunk in enumerate(chunks):