mirror of https://github.com/coqui-ai/TTS.git
commit
9c68992ccc
|
@ -0,0 +1,53 @@
|
|||
name: xtts-tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
jobs:
|
||||
check_skip:
|
||||
runs-on: ubuntu-latest
|
||||
if: "! contains(github.event.head_commit.message, '[ci skip]')"
|
||||
steps:
|
||||
- run: echo "${{ github.event.head_commit.message }}"
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: [3.9, "3.10", "3.11"]
|
||||
experimental: [false]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
architecture: x64
|
||||
cache: 'pip'
|
||||
cache-dependency-path: 'requirements*'
|
||||
- name: check OS
|
||||
run: cat /etc/os-release
|
||||
- name: set ENV
|
||||
run: export TRAINER_TELEMETRY=0
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y --no-install-recommends git make gcc
|
||||
sudo apt-get install espeak
|
||||
sudo apt-get install espeak-ng
|
||||
make system-deps
|
||||
- name: Install/upgrade Python setup deps
|
||||
run: python3 -m pip install --upgrade pip setuptools wheel
|
||||
- name: Replace scarf urls
|
||||
run: |
|
||||
sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
|
||||
- name: Install TTS
|
||||
run: |
|
||||
python3 -m pip install .[all]
|
||||
python3 setup.py egg_info
|
||||
- name: Unit tests
|
||||
run: make test_xtts
|
3
Makefile
3
Makefile
|
@ -22,6 +22,9 @@ test_tts: ## run tts tests.
|
|||
test_tts2: ## run tts tests.
|
||||
nose2 -F -v -B --with-coverage --coverage TTS tests.tts_tests2
|
||||
|
||||
test_xtts:
|
||||
nose2 -F -v -B --with-coverage --coverage TTS tests.xtts_tests
|
||||
|
||||
test_aux: ## run aux tests.
|
||||
nose2 -F -v -B --with-coverage --coverage TTS tests.aux_tests
|
||||
./run_bash_tests.sh
|
||||
|
|
|
@ -1 +1 @@
|
|||
0.18.2
|
||||
0.19.0
|
||||
|
|
|
@ -197,6 +197,7 @@ class GPT(nn.Module):
|
|||
|
||||
if use_deepspeed:
|
||||
import deepspeed
|
||||
|
||||
self.ds_engine = deepspeed.init_inference(
|
||||
model=self.gpt_inference.half(), # Transformers models
|
||||
mp_size=1, # Number of GPU
|
||||
|
@ -233,6 +234,7 @@ class GPT(nn.Module):
|
|||
prompt=None,
|
||||
get_attns=False,
|
||||
return_latent=False,
|
||||
attn_mask_cond=None,
|
||||
attn_mask_text=None,
|
||||
attn_mask_mel=None,
|
||||
):
|
||||
|
@ -248,8 +250,11 @@ class GPT(nn.Module):
|
|||
if attn_mask_text is not None:
|
||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||
if prompt is not None:
|
||||
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
|
||||
if attn_mask_cond is not None:
|
||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
||||
else:
|
||||
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb,
|
||||
|
@ -326,7 +331,7 @@ class GPT(nn.Module):
|
|||
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
|
||||
return prompt
|
||||
|
||||
def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
|
||||
def get_style_emb(self, cond_input, return_latent=False):
|
||||
"""
|
||||
cond_input: (b, 80, s) or (b, 1, 80, s)
|
||||
conds: (b, 1024, s)
|
||||
|
@ -335,26 +340,7 @@ class GPT(nn.Module):
|
|||
if not return_latent:
|
||||
if cond_input.ndim == 4:
|
||||
cond_input = cond_input.squeeze(1)
|
||||
if sample:
|
||||
_len_secs = random.randint(2, 6) # in secs
|
||||
cond_seg_len = int((22050 / 1024) * _len_secs) # in frames
|
||||
if cond_input.shape[-1] >= cond_seg_len:
|
||||
new_conds = []
|
||||
for i in range(cond_input.shape[0]):
|
||||
cond_len = int(cond_lens[i] / 1024)
|
||||
if cond_len < cond_seg_len:
|
||||
start = 0
|
||||
else:
|
||||
start = random.randint(0, cond_len - cond_seg_len)
|
||||
cond_vec = cond_input[i, :, start : start + cond_seg_len]
|
||||
new_conds.append(cond_vec)
|
||||
conds = torch.stack(new_conds, dim=0)
|
||||
else:
|
||||
cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
|
||||
cond_frame_len = int((22050 / 1024) * cond_seg_len)
|
||||
conds = cond_input[:, :, -cond_frame_len:]
|
||||
|
||||
conds = self.conditioning_encoder(conds)
|
||||
conds = self.conditioning_encoder(cond_input)
|
||||
else:
|
||||
# already computed
|
||||
conds = cond_input.unsqueeze(1)
|
||||
|
@ -366,10 +352,9 @@ class GPT(nn.Module):
|
|||
text_lengths,
|
||||
audio_codes,
|
||||
wav_lengths,
|
||||
cond_lens=None,
|
||||
cond_mels=None,
|
||||
cond_idxs=None,
|
||||
cond_latents=None,
|
||||
loss_weights=None,
|
||||
return_attentions=False,
|
||||
return_latent=False,
|
||||
):
|
||||
|
@ -377,11 +362,12 @@ class GPT(nn.Module):
|
|||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||
(actuated by `text_first`).
|
||||
|
||||
cond_mels: MEL float tensor, (b, 1, 80,s)
|
||||
text_inputs: long tensor, (b,t)
|
||||
text_lengths: long tensor, (b,)
|
||||
mel_inputs: long tensor, (b,m)
|
||||
wav_lengths: long tensor, (b,)
|
||||
cond_mels: MEL float tensor, (b, 1, 80,s)
|
||||
cond_idxs: cond start and end indexs, (b, 2)
|
||||
|
||||
If return_attentions is specified, only logits are returned.
|
||||
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
|
||||
|
@ -393,6 +379,11 @@ class GPT(nn.Module):
|
|||
max_text_len = text_lengths.max()
|
||||
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
|
||||
|
||||
if cond_idxs is not None:
|
||||
# recompute cond idxs for mel lengths
|
||||
for idx, l in enumerate(code_lengths):
|
||||
cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
|
||||
|
||||
# If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
|
||||
max_mel_len = code_lengths.max()
|
||||
|
||||
|
@ -435,9 +426,16 @@ class GPT(nn.Module):
|
|||
)
|
||||
|
||||
# Set attn_mask
|
||||
attn_mask_cond = None
|
||||
attn_mask_text = None
|
||||
attn_mask_mel = None
|
||||
if not return_latent:
|
||||
attn_mask_cond = torch.ones(
|
||||
cond_mels.shape[0],
|
||||
cond_mels.shape[-1],
|
||||
dtype=torch.bool,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
attn_mask_text = torch.ones(
|
||||
text_inputs.shape[0],
|
||||
text_inputs.shape[1],
|
||||
|
@ -451,6 +449,11 @@ class GPT(nn.Module):
|
|||
device=audio_codes.device,
|
||||
)
|
||||
|
||||
if cond_idxs is not None:
|
||||
for idx, r in enumerate(cond_idxs):
|
||||
l = r[1] - r[0]
|
||||
attn_mask_cond[idx, l:] = 0.0
|
||||
|
||||
for idx, l in enumerate(text_lengths):
|
||||
attn_mask_text[idx, l + 1 :] = 0.0
|
||||
|
||||
|
@ -465,7 +468,7 @@ class GPT(nn.Module):
|
|||
|
||||
# Compute speech conditioning input
|
||||
if cond_latents is None:
|
||||
cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
|
||||
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
|
||||
|
||||
# Get logits
|
||||
sub = -5 # don't ask me why 😄
|
||||
|
@ -480,6 +483,7 @@ class GPT(nn.Module):
|
|||
prompt=cond_latents,
|
||||
get_attns=return_attentions,
|
||||
return_latent=return_latent,
|
||||
attn_mask_cond=attn_mask_cond,
|
||||
attn_mask_text=attn_mask_text,
|
||||
attn_mask_mel=attn_mask_mel,
|
||||
)
|
||||
|
@ -501,6 +505,13 @@ class GPT(nn.Module):
|
|||
0
|
||||
], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
|
||||
|
||||
# ignore the loss for the segment used for conditioning
|
||||
# coin flip for the segment to be ignored
|
||||
if cond_idxs is not None:
|
||||
cond_start = cond_idxs[idx, 0]
|
||||
cond_end = cond_idxs[idx, 1]
|
||||
mel_targets[idx, cond_start:cond_end] = -1
|
||||
|
||||
# Compute losses
|
||||
loss_text = F.cross_entropy(
|
||||
text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
|
||||
|
@ -548,7 +559,7 @@ class GPT(nn.Module):
|
|||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
max_length=self.max_mel_tokens,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
|
@ -561,7 +572,7 @@ class GPT(nn.Module):
|
|||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
max_length=self.max_mel_tokens,
|
||||
do_stream=True,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
import torchaudio
|
||||
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
|
@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module):
|
|||
self.cond_in_each_up_layer = cond_in_each_up_layer
|
||||
|
||||
# initial upsampling layers
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
|
||||
# upsampling layers
|
||||
self.ups = nn.ModuleList()
|
||||
|
@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module):
|
|||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
# post convolution layer
|
||||
self.conv_post = weight_norm(
|
||||
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
|
||||
)
|
||||
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
|
||||
if cond_channels > 0:
|
||||
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
|
||||
|
||||
|
@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module):
|
|||
Tensor: [B, 1, T]
|
||||
"""
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(
|
||||
c, (self.inference_padding, self.inference_padding), "replicate"
|
||||
)
|
||||
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
|
@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module):
|
|||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
|
||||
|
||||
class SELayer(nn.Module):
|
||||
def __init__(self, channel, reduction=8):
|
||||
super(SELayer, self).__init__()
|
||||
|
@ -425,10 +417,8 @@ class PreEmphasis(nn.Module):
|
|||
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
|
||||
|
||||
|
||||
|
||||
class ResNetSpeakerEncoder(nn.Module):
|
||||
"""This is copied from 🐸TTS to remove it from the dependencies.
|
||||
"""
|
||||
"""This is copied from 🐸TTS to remove it from the dependencies."""
|
||||
|
||||
# pylint: disable=W0102
|
||||
def __init__(
|
||||
|
@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
return criterion, state["step"]
|
||||
return criterion
|
||||
|
||||
|
||||
class HifiDecoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module):
|
|||
"""
|
||||
return self.forward(c, g=g)
|
||||
|
||||
def load_checkpoint(
|
||||
self, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# remove unused keys
|
||||
state = state["model"]
|
||||
|
|
|
@ -1,26 +1,27 @@
|
|||
# Adapted from: https://github.com/LowinLi/transformers-stream-generator
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import random
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
BeamSearchScorer,
|
||||
ConstrainedBeamSearchScorer,
|
||||
DisjunctiveConstraint,
|
||||
GenerationConfig,
|
||||
GenerationMixin,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
DisjunctiveConstraint,
|
||||
BeamSearchScorer,
|
||||
PhrasalConstraint,
|
||||
ConstrainedBeamSearchScorer,
|
||||
PreTrainedModel,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import numpy as np
|
||||
import random
|
||||
import warnings
|
||||
import inspect
|
||||
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
|
||||
import torch
|
||||
from typing import Callable, List, Optional, Union
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
import copy
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
|
@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
generation_config: Optional[StreamGenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[
|
||||
Callable[[int, torch.Tensor], List[int]]
|
||||
] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
synced_gpus: Optional[bool] = False,
|
||||
seed=0,
|
||||
**kwargs,
|
||||
|
@ -125,7 +124,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
- [`~generation.BeamSearchEncoderDecoderOutput`],
|
||||
- [`~generation.BeamSampleEncoderDecoderOutput`]
|
||||
"""
|
||||
#setup_seed(seed)
|
||||
# setup_seed(seed)
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
|
||||
|
@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
# legacy: users may modify the model configuration to control generation -- update the generation config
|
||||
# model attribute accordingly, if it was created from the model config
|
||||
if self.generation_config._from_model_config:
|
||||
new_generation_config = StreamGenerationConfig.from_model_config(
|
||||
self.config
|
||||
)
|
||||
new_generation_config = StreamGenerationConfig.from_model_config(self.config)
|
||||
if new_generation_config != self.generation_config:
|
||||
warnings.warn(
|
||||
"You have modified the pretrained model configuration to control generation. This is a"
|
||||
|
@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin):
|
|||
generation_config = self.generation_config
|
||||
|
||||
generation_config = copy.deepcopy(generation_config)
|
||||
model_kwargs = generation_config.update(
|
||||
**kwargs
|
||||
) # All unused kwargs must be model kwargs
|
||||
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||
# self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
logits_processor = (
|
||||
logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
)
|
||||
stopping_criteria = (
|
||||
stopping_criteria
|
||||
if stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
if (
|
||||
generation_config.pad_token_id is None
|
||||
and generation_config.eos_token_id is not None
|
||||
):
|
||||
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||
if model_kwargs.get("attention_mask", None) is None:
|
||||
logger.warning(
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
|
@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(
|
||||
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
|
||||
)
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
# 3. Define model inputs
|
||||
|
@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin):
|
|||
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
accepts_attention_mask = "attention_mask" in set(
|
||||
inspect.signature(self.forward).parameters.keys()
|
||||
)
|
||||
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||
|
||||
if (
|
||||
model_kwargs.get("attention_mask", None) is None
|
||||
and requires_attention_mask
|
||||
and accepts_attention_mask
|
||||
):
|
||||
model_kwargs[
|
||||
"attention_mask"
|
||||
] = self._prepare_attention_mask_for_generation(
|
||||
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||
inputs_tensor,
|
||||
generation_config.pad_token_id,
|
||||
generation_config.eos_token_id,
|
||||
|
@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
if not self.config.is_encoder_decoder:
|
||||
if (
|
||||
generation_config.pad_token_id is not None
|
||||
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
|
||||
> 0
|
||||
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
||||
):
|
||||
logger.warning(
|
||||
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||
|
@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
has_default_max_length = (
|
||||
kwargs.get("max_length") is None
|
||||
and generation_config.max_length is not None
|
||||
)
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||
warnings.warn(
|
||||
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
|
||||
|
@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin):
|
|||
UserWarning,
|
||||
)
|
||||
elif has_default_max_length and generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
elif (
|
||||
not has_default_max_length and generation_config.max_new_tokens is not None
|
||||
):
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
elif not has_default_max_length and generation_config.max_new_tokens is not None:
|
||||
raise ValueError(
|
||||
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
|
||||
" limit to the generated output length. Remove one of those arguments. Please refer to the"
|
||||
|
@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin):
|
|||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
|
||||
if (
|
||||
generation_config.min_length is not None
|
||||
and generation_config.min_length > generation_config.max_length
|
||||
):
|
||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||
raise ValueError(
|
||||
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
|
||||
f" the maximum length ({generation_config.max_length})"
|
||||
)
|
||||
if input_ids_seq_length >= generation_config.max_length:
|
||||
input_ids_string = (
|
||||
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
)
|
||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||
|
@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
# 7. determine generation mode
|
||||
is_constraint_gen_mode = (
|
||||
generation_config.constraints is not None
|
||||
or generation_config.force_words_ids is not None
|
||||
generation_config.constraints is not None or generation_config.force_words_ids is not None
|
||||
)
|
||||
|
||||
is_contrastive_search_gen_mode = (
|
||||
|
@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
)
|
||||
|
||||
if generation_config.num_beam_groups > generation_config.num_beams:
|
||||
raise ValueError(
|
||||
"`num_beam_groups` has to be smaller or equal to `num_beams`"
|
||||
)
|
||||
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
|
||||
if is_group_beam_gen_mode and generation_config.do_sample is True:
|
||||
raise ValueError(
|
||||
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
|
||||
|
@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin):
|
|||
)
|
||||
elif is_beam_gen_mode:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError(
|
||||
"`num_return_sequences` has to be smaller or equal to `num_beams`."
|
||||
)
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError(
|
||||
"`max_length` needs to be a stopping_criteria for now."
|
||||
)
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
# 11. prepare beam search scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
|
@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
logits_warper = self._get_logits_warper(generation_config)
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError(
|
||||
"`max_length` needs to be a stopping_criteria for now."
|
||||
)
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
# 12. prepare beam search scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size * generation_config.num_return_sequences,
|
||||
|
@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
# 13. interleave input_ids with `num_beams` additional sequences per batch
|
||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_beams
|
||||
* generation_config.num_return_sequences,
|
||||
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
elif is_group_beam_gen_mode:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError(
|
||||
"`num_return_sequences` has to be smaller or equal to `num_beams`."
|
||||
)
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if generation_config.num_beams % generation_config.num_beam_groups != 0:
|
||||
raise ValueError(
|
||||
"`num_beams` should be divisible by `num_beam_groups` for group beam search."
|
||||
)
|
||||
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError(
|
||||
"`max_length` needs to be a stopping_criteria for now."
|
||||
)
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
has_default_typical_p = (
|
||||
kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
|
||||
)
|
||||
has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
|
||||
if not has_default_typical_p:
|
||||
raise ValueError(
|
||||
"Decoder argument `typical_p` is not supported with beam groups."
|
||||
)
|
||||
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
|
||||
|
||||
# 11. prepare beam search scorer
|
||||
beam_scorer = BeamSearchScorer(
|
||||
|
@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
elif is_constraint_gen_mode:
|
||||
if generation_config.num_return_sequences > generation_config.num_beams:
|
||||
raise ValueError(
|
||||
"`num_return_sequences` has to be smaller or equal to `num_beams`."
|
||||
)
|
||||
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
raise ValueError(
|
||||
"`max_length` needs to be a stopping_criteria for now."
|
||||
)
|
||||
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
|
||||
|
||||
if generation_config.num_beams <= 1:
|
||||
raise ValueError(
|
||||
"`num_beams` needs to be greater than 1 for constrained generation."
|
||||
)
|
||||
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
|
||||
|
||||
if generation_config.do_sample:
|
||||
raise ValueError(
|
||||
"`do_sample` needs to be false for constrained generation."
|
||||
)
|
||||
raise ValueError("`do_sample` needs to be false for constrained generation.")
|
||||
|
||||
if (
|
||||
generation_config.num_beam_groups is not None
|
||||
and generation_config.num_beam_groups > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"`num_beam_groups` not supported yet for constrained generation."
|
||||
)
|
||||
if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
|
||||
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
|
||||
|
||||
final_constraints = []
|
||||
if generation_config.constraints is not None:
|
||||
|
@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin):
|
|||
if isinstance(word_ids[0], list):
|
||||
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
||||
typeerror()
|
||||
if any(
|
||||
not isinstance(token_ids, list) for token_ids in word_ids
|
||||
):
|
||||
if any(not isinstance(token_ids, list) for token_ids in word_ids):
|
||||
typeerror()
|
||||
if any(
|
||||
any(
|
||||
(not isinstance(token_id, int) or token_id < 0)
|
||||
for token_id in token_ids
|
||||
)
|
||||
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
|
||||
for token_ids in word_ids
|
||||
):
|
||||
typeerror()
|
||||
|
@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
else:
|
||||
if not isinstance(word_ids, list) or len(word_ids) == 0:
|
||||
typeerror()
|
||||
if any(
|
||||
(not isinstance(token_id, int) or token_id < 0)
|
||||
for token_id in word_ids
|
||||
):
|
||||
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
|
||||
typeerror()
|
||||
|
||||
constraint = PhrasalConstraint(word_ids)
|
||||
|
@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin):
|
|||
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
|
||||
```"""
|
||||
# init values
|
||||
logits_processor = (
|
||||
logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
)
|
||||
stopping_criteria = (
|
||||
stopping_criteria
|
||||
if stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use"
|
||||
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
|
||||
UserWarning,
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(
|
||||
stopping_criteria, max_length
|
||||
)
|
||||
logits_warper = (
|
||||
logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
)
|
||||
pad_token_id = (
|
||||
pad_token_id
|
||||
if pad_token_id is not None
|
||||
else self.generation_config.pad_token_id
|
||||
)
|
||||
eos_token_id = (
|
||||
eos_token_id
|
||||
if eos_token_id is not None
|
||||
else self.generation_config.eos_token_id
|
||||
)
|
||||
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
output_scores = (
|
||||
output_scores
|
||||
if output_scores is not None
|
||||
else self.generation_config.output_scores
|
||||
)
|
||||
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.generation_config.output_attentions
|
||||
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.generation_config.output_hidden_states
|
||||
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
||||
)
|
||||
return_dict_in_generate = (
|
||||
return_dict_in_generate
|
||||
|
@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(
|
||||
0.0 if this_peer_finished else 1.0
|
||||
).to(input_ids.device)
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
|
@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin):
|
|||
scores += (next_token_scores,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.attentions,)
|
||||
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
(outputs.decoder_hidden_states,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.hidden_states,)
|
||||
(outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
|
||||
)
|
||||
|
||||
# sample
|
||||
|
@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin):
|
|||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
|
||||
1 - unfinished_sequences
|
||||
)
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
|
||||
# update generated ids, model inputs, and length for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
|
@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin):
|
|||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
(sum(next_tokens != i for i in eos_token_id)).long()
|
||||
)
|
||||
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
||||
|
||||
# stop when each sentence is finished, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
|
@ -1007,22 +885,17 @@ def init_stream_support():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
|
||||
|
||||
PreTrainedModel.generate = NewGenerationMixin.generate
|
||||
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"bigscience/bloom-560m", torch_dtype=torch.float16
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
|
||||
model = model.to("cuda:0")
|
||||
model = model.eval()
|
||||
prompt_text = "hello? \n"
|
||||
input_ids = tokenizer(
|
||||
prompt_text, return_tensors="pt", add_special_tokens=False
|
||||
).input_ids
|
||||
input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
|
||||
input_ids = input_ids.to("cuda:0")
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -0,0 +1,242 @@
|
|||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import torchaudio
|
||||
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
|
||||
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
|
||||
|
||||
torch.set_num_threads(1)
|
||||
|
||||
|
||||
def key_samples_by_col(samples, col):
|
||||
"""Returns a dictionary of samples keyed by language."""
|
||||
samples_by_col = {}
|
||||
for sample in samples:
|
||||
col_val = sample[col]
|
||||
assert isinstance(col_val, str)
|
||||
if col_val not in samples_by_col:
|
||||
samples_by_col[col_val] = []
|
||||
samples_by_col[col_val].append(sample)
|
||||
return samples_by_col
|
||||
|
||||
|
||||
def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
|
||||
rel_clip = load_audio(gt_path, sample_rate)
|
||||
# if eval uses a middle size sample when it is possible to be more reproducible
|
||||
if is_eval:
|
||||
sample_length = int((min_sample_length + max_sample_length) / 2)
|
||||
else:
|
||||
sample_length = random.randint(min_sample_length, max_sample_length)
|
||||
gap = rel_clip.shape[-1] - sample_length
|
||||
if gap < 0:
|
||||
sample_length = rel_clip.shape[-1] // 2
|
||||
gap = rel_clip.shape[-1] - sample_length
|
||||
|
||||
# if eval start always from the position 0 to be more reproducible
|
||||
if is_eval:
|
||||
rand_start = 0
|
||||
else:
|
||||
rand_start = random.randint(0, gap)
|
||||
|
||||
rand_end = rand_start + sample_length
|
||||
rel_clip = rel_clip[:, rand_start:rand_end]
|
||||
rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
|
||||
cond_idxs = [rand_start, rand_end]
|
||||
return rel_clip, rel_clip.shape[-1], cond_idxs
|
||||
|
||||
|
||||
def load_audio(audiopath, sampling_rate):
|
||||
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
|
||||
if audiopath[-4:] == ".mp3":
|
||||
# it uses torchaudio with sox backend to load mp3
|
||||
audio, lsr = torchaudio_sox_load(audiopath)
|
||||
else:
|
||||
# it uses torchaudio soundfile backend to load all the others data type
|
||||
audio, lsr = torchaudio_soundfile_load(audiopath)
|
||||
|
||||
# stereo to mono if needed
|
||||
if audio.size(0) != 1:
|
||||
audio = torch.mean(audio, dim=0, keepdim=True)
|
||||
|
||||
if lsr != sampling_rate:
|
||||
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
|
||||
|
||||
# Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
|
||||
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
|
||||
if torch.any(audio > 10) or not torch.any(audio < 0):
|
||||
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
|
||||
# clip audio invalid values
|
||||
audio.clip_(-1, 1)
|
||||
return audio
|
||||
|
||||
|
||||
class XTTSDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
|
||||
self.config = config
|
||||
model_args = config.model_args
|
||||
self.failed_samples = set()
|
||||
self.debug_failures = model_args.debug_loading_failures
|
||||
self.max_conditioning_length = model_args.max_conditioning_length
|
||||
self.min_conditioning_length = model_args.min_conditioning_length
|
||||
self.is_eval = is_eval
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_rate = sample_rate
|
||||
self.max_wav_len = model_args.max_wav_length
|
||||
self.max_text_len = model_args.max_text_length
|
||||
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||
|
||||
self.samples = samples
|
||||
if not is_eval:
|
||||
random.seed(config.training_seed)
|
||||
# random.shuffle(self.samples)
|
||||
random.shuffle(self.samples)
|
||||
# order by language
|
||||
self.samples = key_samples_by_col(self.samples, "language")
|
||||
print(" > Sampling by language:", self.samples.keys())
|
||||
else:
|
||||
# for evaluation load and check samples that are corrupted to ensures the reproducibility
|
||||
self.check_eval_samples()
|
||||
|
||||
def check_eval_samples(self):
|
||||
print(" > Filtering invalid eval samples!!")
|
||||
new_samples = []
|
||||
for sample in self.samples:
|
||||
try:
|
||||
tseq, _, wav, _, _, _ = self.load_item(sample)
|
||||
except:
|
||||
pass
|
||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||
if (
|
||||
wav is None
|
||||
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
|
||||
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
|
||||
):
|
||||
continue
|
||||
new_samples.append(sample)
|
||||
self.samples = new_samples
|
||||
print(" > Total eval samples after filtering:", len(self.samples))
|
||||
|
||||
def get_text(self, text, lang):
|
||||
tokens = self.tokenizer.encode(text, lang)
|
||||
tokens = torch.IntTensor(tokens)
|
||||
assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}"
|
||||
# The stop token should always be sacred.
|
||||
assert not torch.any(tokens == 0), f"Stop token found in {text}"
|
||||
return tokens
|
||||
|
||||
def load_item(self, sample):
|
||||
text = str(sample["text"])
|
||||
tseq = self.get_text(text, sample["language"])
|
||||
audiopath = sample["audio_file"]
|
||||
wav = load_audio(audiopath, self.sample_rate)
|
||||
if text is None or len(text.strip()) == 0:
|
||||
raise ValueError
|
||||
if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
|
||||
# Ultra short clips are also useless (and can cause problems within some models).
|
||||
raise ValueError
|
||||
|
||||
# get a slice from GT to condition the model
|
||||
cond, cond_len, cond_idxs = get_prompt_slice(
|
||||
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||
)
|
||||
|
||||
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.is_eval:
|
||||
sample = self.samples[index]
|
||||
sample_id = str(index)
|
||||
else:
|
||||
# select a random language
|
||||
lang = random.choice(list(self.samples.keys()))
|
||||
# select random sample
|
||||
index = random.randint(0, len(self.samples[lang]) - 1)
|
||||
sample = self.samples[lang][index]
|
||||
# a unique id for each sampel to deal with fails
|
||||
sample_id = lang + "_" + str(index)
|
||||
|
||||
# ignore samples that we already know that is not valid ones
|
||||
if sample_id in self.failed_samples:
|
||||
if self.debug_failures:
|
||||
print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
|
||||
# call get item again to get other sample
|
||||
return self[1]
|
||||
|
||||
# try to load the sample, if fails added it to the failed samples list
|
||||
try:
|
||||
tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
|
||||
except:
|
||||
if self.debug_failures:
|
||||
print(f"error loading {sample['audio_file']} {sys.exc_info()}")
|
||||
self.failed_samples.add(sample_id)
|
||||
return self[1]
|
||||
|
||||
# check if the audio and text size limits and if it out of the limits, added it failed_samples
|
||||
if (
|
||||
wav is None
|
||||
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
|
||||
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
|
||||
):
|
||||
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
|
||||
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
|
||||
if self.debug_failures and wav is not None and tseq is not None:
|
||||
print(
|
||||
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
|
||||
)
|
||||
self.failed_samples.add(sample_id)
|
||||
return self[1]
|
||||
|
||||
res = {
|
||||
# 'real_text': text,
|
||||
"text": tseq,
|
||||
"text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long),
|
||||
"wav": wav,
|
||||
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
|
||||
"filenames": audiopath,
|
||||
"conditioning": cond.unsqueeze(1),
|
||||
"cond_lens": torch.tensor(cond_len, dtype=torch.long),
|
||||
"cond_idxs": torch.tensor(cond_idxs),
|
||||
}
|
||||
return res
|
||||
|
||||
def __len__(self):
|
||||
if self.is_eval:
|
||||
return len(self.samples)
|
||||
return sum([len(v) for v in self.samples.values()])
|
||||
|
||||
def collate_fn(self, batch):
|
||||
# convert list of dicts to dict of lists
|
||||
B = len(batch)
|
||||
|
||||
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
||||
|
||||
# stack for features that already have the same shape
|
||||
batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
|
||||
batch["text_lengths"] = torch.stack(batch["text_lengths"])
|
||||
batch["conditioning"] = torch.stack(batch["conditioning"])
|
||||
batch["cond_lens"] = torch.stack(batch["cond_lens"])
|
||||
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
||||
max_text_len = batch["text_lengths"].max()
|
||||
max_wav_len = batch["wav_lengths"].max()
|
||||
|
||||
# create padding tensors
|
||||
text_padded = torch.IntTensor(B, max_text_len)
|
||||
wav_padded = torch.FloatTensor(B, 1, max_wav_len)
|
||||
|
||||
# initialize tensors for zero padding
|
||||
text_padded = text_padded.zero_()
|
||||
wav_padded = wav_padded.zero_()
|
||||
for i in range(B):
|
||||
text = batch["text"][i]
|
||||
text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
|
||||
wav = batch["wav"][i]
|
||||
wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
|
||||
|
||||
batch["wav"] = wav_padded
|
||||
batch["padded_text"] = text_padded
|
||||
return batch
|
|
@ -0,0 +1,473 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from trainer.torch import DistributedSampler
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
|
||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTTrainerConfig(XttsConfig):
|
||||
lr: float = 5e-06
|
||||
training_seed: int = 1
|
||||
optimizer_wd_only_on_weights: bool = False
|
||||
weighted_loss_attrs: dict = field(default_factory=lambda: {})
|
||||
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
|
||||
test_sentences: List[dict] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
@dataclass
|
||||
class XttsAudioConfig(XttsAudioConfig):
|
||||
dvae_sample_rate: int = 22050
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTArgs(XttsArgs):
|
||||
min_conditioning_length: int = 66150
|
||||
max_conditioning_length: int = 132300
|
||||
gpt_loss_text_ce_weight: float = 0.01
|
||||
gpt_loss_mel_ce_weight: float = 1.0
|
||||
gpt_num_audio_tokens: int = 8194
|
||||
debug_loading_failures: bool = False
|
||||
max_wav_length: int = 255995 # ~11.6 seconds
|
||||
max_text_length: int = 200
|
||||
tokenizer_file: str = ""
|
||||
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
|
||||
dvae_checkpoint: str = ""
|
||||
xtts_checkpoint: str = ""
|
||||
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
|
||||
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
|
||||
|
||||
|
||||
def callback_clearml_load_save(operation_type, model_info):
|
||||
# return None means skip the file upload/log, returning model_info will continue with the log/upload
|
||||
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
|
||||
assert operation_type in ("load", "save")
|
||||
# print(operation_type, model_info.__dict__)
|
||||
|
||||
if "similarities.pth" in model_info.__dict__["local_model_path"]:
|
||||
return None
|
||||
|
||||
return model_info
|
||||
|
||||
|
||||
class GPTTrainer(BaseTTS):
|
||||
def __init__(self, config: Coqpit):
|
||||
"""
|
||||
Tortoise GPT training class
|
||||
"""
|
||||
super().__init__(config, ap=None, tokenizer=None)
|
||||
self.config = config
|
||||
# init XTTS model
|
||||
self.xtts = Xtts(self.config)
|
||||
# create the tokenizer with the target vocabulary
|
||||
self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
|
||||
# init gpt encoder and hifigan decoder
|
||||
self.xtts.init_models()
|
||||
|
||||
if self.args.xtts_checkpoint:
|
||||
self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)
|
||||
|
||||
# set mel stats
|
||||
if self.args.mel_norm_file:
|
||||
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
|
||||
|
||||
# load GPT if available
|
||||
if self.args.gpt_checkpoint:
|
||||
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
|
||||
# deal with coqui Trainer exported model
|
||||
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
|
||||
print("Coqui Trainer checkpoint detected! Converting it!")
|
||||
gpt_checkpoint = gpt_checkpoint["model"]
|
||||
states_keys = list(gpt_checkpoint.keys())
|
||||
for key in states_keys:
|
||||
if "gpt." in key:
|
||||
new_key = key.replace("gpt.", "")
|
||||
gpt_checkpoint[new_key] = gpt_checkpoint[key]
|
||||
del gpt_checkpoint[key]
|
||||
else:
|
||||
del gpt_checkpoint[key]
|
||||
|
||||
# edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
|
||||
if (
|
||||
"text_embedding.weight" in gpt_checkpoint
|
||||
and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
|
||||
):
|
||||
num_new_tokens = (
|
||||
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
|
||||
)
|
||||
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
|
||||
|
||||
# add new tokens to a linear layer (text_head)
|
||||
emb_g = gpt_checkpoint["text_embedding.weight"]
|
||||
new_row = torch.randn(num_new_tokens, emb_g.shape[1])
|
||||
start_token_row = emb_g[-1, :]
|
||||
emb_g = torch.cat([emb_g, new_row], axis=0)
|
||||
emb_g[-1, :] = start_token_row
|
||||
gpt_checkpoint["text_embedding.weight"] = emb_g
|
||||
|
||||
# add new weights to the linear layer (text_head)
|
||||
text_head_weight = gpt_checkpoint["text_head.weight"]
|
||||
start_token_row = text_head_weight[-1, :]
|
||||
new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
|
||||
text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
|
||||
text_head_weight[-1, :] = start_token_row
|
||||
gpt_checkpoint["text_head.weight"] = text_head_weight
|
||||
|
||||
# add new biases to the linear layer (text_head)
|
||||
text_head_bias = gpt_checkpoint["text_head.bias"]
|
||||
start_token_row = text_head_bias[-1]
|
||||
new_bias_entry = torch.zeros(num_new_tokens)
|
||||
text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
|
||||
text_head_bias[-1] = start_token_row
|
||||
gpt_checkpoint["text_head.bias"] = text_head_bias
|
||||
|
||||
self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
|
||||
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
|
||||
|
||||
# Mel spectrogram extractor for conditioning
|
||||
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
|
||||
filter_length=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
normalize=False,
|
||||
sampling_rate=config.audio.sample_rate,
|
||||
mel_fmin=0,
|
||||
mel_fmax=8000,
|
||||
n_mel_channels=80,
|
||||
mel_norm_file=self.args.mel_norm_file,
|
||||
)
|
||||
|
||||
# Load DVAE
|
||||
self.dvae = DiscreteVAE(
|
||||
channels=80,
|
||||
normalization=None,
|
||||
positional_dims=1,
|
||||
num_tokens=self.args.gpt_num_audio_tokens - 2,
|
||||
codebook_dim=512,
|
||||
hidden_dim=512,
|
||||
num_resnet_blocks=3,
|
||||
kernel_size=3,
|
||||
num_layers=2,
|
||||
use_transposed_convs=False,
|
||||
)
|
||||
|
||||
self.dvae.eval()
|
||||
if self.args.dvae_checkpoint:
|
||||
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
|
||||
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
|
||||
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
|
||||
)
|
||||
|
||||
# Mel spectrogram extractor for DVAE
|
||||
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
|
||||
mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
|
||||
"""
|
||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||
(actuated by `text_first`).
|
||||
|
||||
text_inputs: long tensor, (b,t)
|
||||
text_lengths: long tensor, (b,)
|
||||
mel_inputs: long tensor, (b,m)
|
||||
wav_lengths: long tensor, (b,)
|
||||
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
|
||||
cond_idxs: cond start and end indexs, (b, 2)
|
||||
"""
|
||||
losses = self.xtts.gpt(
|
||||
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
|
||||
)
|
||||
return losses
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
|
||||
if self.config.test_sentences:
|
||||
# init gpt for inference mode
|
||||
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
|
||||
self.xtts.gpt.eval()
|
||||
test_audios = {}
|
||||
print(" | > Synthesizing test sentences.")
|
||||
for idx, s_info in enumerate(self.config.test_sentences):
|
||||
wav = self.xtts.synthesize(
|
||||
s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
|
||||
)["wav"]
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
|
||||
# delete inference layers
|
||||
del self.xtts.gpt.gpt_inference
|
||||
del self.xtts.gpt.gpt.wte
|
||||
return {"audios": test_audios}
|
||||
|
||||
def test_log(
|
||||
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
return batch
|
||||
|
||||
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
|
||||
def format_batch_on_device(self, batch):
|
||||
"""Compute spectrograms on the device."""
|
||||
batch["text_lengths"] = batch["text_lengths"]
|
||||
batch["wav_lengths"] = batch["wav_lengths"]
|
||||
batch["text_inputs"] = batch["padded_text"]
|
||||
batch["cond_idxs"] = batch["cond_idxs"]
|
||||
# compute conditioning mel specs
|
||||
# transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor
|
||||
B, num_cond_samples, C, T = batch["conditioning"].size()
|
||||
conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
|
||||
paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
|
||||
# transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel])
|
||||
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
|
||||
T_mel = paired_conditioning_mel.size(2)
|
||||
paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
|
||||
# get the conditioning embeddings
|
||||
batch["cond_mels"] = paired_conditioning_mel
|
||||
# compute codes using DVAE
|
||||
if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate:
|
||||
dvae_wav = torchaudio.functional.resample(
|
||||
batch["wav"],
|
||||
orig_freq=self.config.audio.sample_rate,
|
||||
new_freq=self.config.audio.dvae_sample_rate,
|
||||
lowpass_filter_width=64,
|
||||
rolloff=0.9475937167399596,
|
||||
resampling_method="kaiser_window",
|
||||
beta=14.769656459379492,
|
||||
)
|
||||
else:
|
||||
dvae_wav = batch["wav"]
|
||||
dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
|
||||
codes = self.dvae.get_codebook_indices(dvae_mel_spec)
|
||||
|
||||
batch["audio_codes"] = codes
|
||||
# delete useless batch tensors
|
||||
del batch["padded_text"]
|
||||
del batch["wav"]
|
||||
del batch["conditioning"]
|
||||
del batch["cond_lens"]
|
||||
return batch
|
||||
|
||||
def train_step(self, batch, criterion):
|
||||
loss_dict = {}
|
||||
cond_mels = batch["cond_mels"]
|
||||
text_inputs = batch["text_inputs"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
audio_codes = batch["audio_codes"]
|
||||
wav_lengths = batch["wav_lengths"]
|
||||
cond_idxs = batch["cond_idxs"]
|
||||
|
||||
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
|
||||
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
|
||||
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
|
||||
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
|
||||
return {"model_outputs": None}, loss_dict
|
||||
|
||||
def eval_step(self, batch, criterion):
|
||||
# ignore masking for more consistent evaluation
|
||||
batch["cond_idxs"] = None
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def on_epoch_start(self, trainer): # pylint: disable=W0613
|
||||
# guarante that dvae will be in eval mode after .train() on evaluation end
|
||||
self.dvae = self.dvae.eval()
|
||||
|
||||
def on_init_end(self, trainer): # pylint: disable=W0613
|
||||
# ignore similarities.pth on clearml save/upload
|
||||
if self.config.dashboard_logger.lower() == "clearml":
|
||||
from clearml.binding.frameworks import WeightsFileHandler
|
||||
|
||||
WeightsFileHandler.add_pre_callback(callback_clearml_load_save)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(
|
||||
self,
|
||||
x,
|
||||
aux_input=None,
|
||||
): # pylint: disable=dangerous-default-value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_criterion():
|
||||
return None
|
||||
|
||||
def get_sampler(self, dataset: TTSDataset, num_gpus=1):
|
||||
# sampler for DDP
|
||||
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
return batch_sampler
|
||||
|
||||
def get_data_loader(
|
||||
self,
|
||||
config: Coqpit,
|
||||
assets: Dict,
|
||||
is_eval: bool,
|
||||
samples: Union[List[Dict], List[List]],
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
rank: int = None,
|
||||
) -> "DataLoader": # pylint: disable=W0613
|
||||
if is_eval and not config.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
# init dataloader
|
||||
dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)
|
||||
|
||||
# wait all the DDP process to be ready
|
||||
if num_gpus > 1:
|
||||
torch.distributed.barrier()
|
||||
|
||||
# sort input sequences from short to long
|
||||
# dataset.preprocess_samples()
|
||||
|
||||
# get samplers
|
||||
sampler = self.get_sampler(dataset, num_gpus)
|
||||
|
||||
# ignore sampler when is eval because if we changed the sampler parameter we will not be able to compare previous runs
|
||||
if sampler is None or is_eval:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
else:
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
"""Initiate and return the optimizer based on the config parameters."""
|
||||
# ToDo: deal with multi GPU training
|
||||
if self.config.optimizer_wd_only_on_weights:
|
||||
# parameters to only GPT model
|
||||
net = self.xtts.gpt
|
||||
|
||||
# normalizations
|
||||
norm_modules = (
|
||||
nn.BatchNorm2d,
|
||||
nn.InstanceNorm2d,
|
||||
nn.BatchNorm1d,
|
||||
nn.InstanceNorm1d,
|
||||
nn.BatchNorm3d,
|
||||
nn.InstanceNorm3d,
|
||||
nn.GroupNorm,
|
||||
nn.LayerNorm,
|
||||
)
|
||||
# nn.Embedding
|
||||
emb_modules = (nn.Embedding, nn.EmbeddingBag)
|
||||
|
||||
param_names_notweights = set()
|
||||
all_param_names = set()
|
||||
param_map = {}
|
||||
for mn, m in net.named_modules():
|
||||
for k, v in m.named_parameters():
|
||||
v.is_bias = k.endswith(".bias")
|
||||
v.is_weight = k.endswith(".weight")
|
||||
v.is_norm = isinstance(m, norm_modules)
|
||||
v.is_emb = isinstance(m, emb_modules)
|
||||
|
||||
fpn = "%s.%s" % (mn, k) if mn else k # full param name
|
||||
all_param_names.add(fpn)
|
||||
param_map[fpn] = v
|
||||
if v.is_bias or v.is_norm or v.is_emb:
|
||||
param_names_notweights.add(fpn)
|
||||
|
||||
params_names_notweights = sorted(list(param_names_notweights))
|
||||
params_notweights = [param_map[k] for k in params_names_notweights]
|
||||
params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
|
||||
params_weights = [param_map[k] for k in params_names_weights]
|
||||
|
||||
groups = [
|
||||
{"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
|
||||
{"params": params_notweights, "weight_decay": 0},
|
||||
]
|
||||
# torch.optim.AdamW
|
||||
opt = get_optimizer(
|
||||
self.config.optimizer,
|
||||
self.config.optimizer_params,
|
||||
self.config.lr,
|
||||
parameters=groups,
|
||||
)
|
||||
opt._group_names = [params_names_weights, params_names_notweights]
|
||||
return opt
|
||||
|
||||
return get_optimizer(
|
||||
self.config.optimizer,
|
||||
self.config.optimizer_params,
|
||||
self.config.lr,
|
||||
# optimize only for the GPT model
|
||||
parameters=self.xtts.gpt.parameters(),
|
||||
)
|
||||
|
||||
def get_scheduler(self, optimizer) -> List:
|
||||
"""Set the scheduler for the optimizer.
|
||||
|
||||
Args:
|
||||
optimizer: `torch.optim.Optimizer`.
|
||||
"""
|
||||
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
config,
|
||||
checkpoint_path,
|
||||
eval=False,
|
||||
strict=True,
|
||||
cache_storage="/tmp/tts_cache",
|
||||
target_protocol="s3",
|
||||
target_options={"anon": True},
|
||||
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
|
||||
"""Load the model checkpoint and setup for training or inference"""
|
||||
|
||||
state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
|
||||
|
||||
# load the model weights
|
||||
self.xtts.load_state_dict(state, strict=strict)
|
||||
|
||||
if eval:
|
||||
self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
|
||||
"""Initiate model from config
|
||||
|
||||
Args:
|
||||
config (GPTTrainerConfig): Model config.
|
||||
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
||||
Defaults to None.
|
||||
"""
|
||||
return GPTTrainer(config)
|
|
@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to
|
|||
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
|
||||
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
|
||||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
|
||||
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
|
||||
from TTS.tts.layers.xtts.stream_generator import init_stream_support
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
init_stream_support()
|
||||
|
||||
|
||||
def load_audio(audiopath, sr=22050):
|
||||
"""
|
||||
Load an audio file from disk and resample it to the specified sampling rate.
|
||||
|
@ -332,7 +333,6 @@ class Xtts(BaseTTS):
|
|||
stop_audio_token=self.args.gpt_stop_audio_token,
|
||||
)
|
||||
|
||||
|
||||
if self.args.use_hifigan:
|
||||
self.hifigan_decoder = HifiDecoder(
|
||||
input_sample_rate=self.args.input_sample_rate,
|
||||
|
@ -387,7 +387,7 @@ class Xtts(BaseTTS):
|
|||
audio = load_audio(audio_path)
|
||||
audio = audio[:, : 22050 * length]
|
||||
mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
|
||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
|
||||
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
|
||||
return cond_latent.transpose(1, 2)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
@ -414,21 +414,20 @@ class Xtts(BaseTTS):
|
|||
return diffusion_latent
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_speaker_embedding(
|
||||
self,
|
||||
audio_path
|
||||
):
|
||||
def get_speaker_embedding(self, audio_path):
|
||||
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
|
||||
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
|
||||
audio.to(self.device), l2_norm=True
|
||||
).unsqueeze(-1).to(self.device)
|
||||
speaker_embedding = (
|
||||
self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
|
||||
.unsqueeze(-1)
|
||||
.to(self.device)
|
||||
)
|
||||
return speaker_embedding
|
||||
|
||||
def get_conditioning_latents(
|
||||
self,
|
||||
audio_path,
|
||||
gpt_cond_len=3,
|
||||
):
|
||||
):
|
||||
speaker_embedding = None
|
||||
diffusion_cond_latents = None
|
||||
if self.args.use_hifigan:
|
||||
|
@ -563,11 +562,9 @@ class Xtts(BaseTTS):
|
|||
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
speaker_embedding
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
|
||||
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
|
||||
)
|
||||
return self.inference(
|
||||
text,
|
||||
language,
|
||||
|
@ -588,7 +585,7 @@ class Xtts(BaseTTS):
|
|||
decoder=decoder,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
|
@ -646,6 +643,7 @@ class Xtts(BaseTTS):
|
|||
expected_output_len = torch.tensor(
|
||||
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
|
||||
)
|
||||
|
||||
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
|
||||
gpt_latents = self.gpt(
|
||||
text_tokens,
|
||||
|
@ -666,7 +664,7 @@ class Xtts(BaseTTS):
|
|||
if ctokens > 8:
|
||||
gpt_latents = gpt_latents[:, :k]
|
||||
break
|
||||
|
||||
|
||||
if decoder == "hifigan":
|
||||
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
|
||||
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
|
||||
|
@ -721,7 +719,9 @@ class Xtts(BaseTTS):
|
|||
decoder="hifigan",
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
|
||||
assert hasattr(
|
||||
self, "hifigan_decoder"
|
||||
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
|
||||
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
|
@ -775,10 +775,10 @@ class Xtts(BaseTTS):
|
|||
yield wav_chunk
|
||||
|
||||
def forward(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
|
||||
|
||||
def eval_step(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument
|
||||
|
@ -789,11 +789,32 @@ class Xtts(BaseTTS):
|
|||
self.gpt.init_gpt_for_inference()
|
||||
super().eval()
|
||||
|
||||
def get_compatible_checkpoint_state_dict(self, model_path):
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
||||
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
|
||||
# remove xtts gpt trainer extra keys
|
||||
ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
|
||||
for key in list(checkpoint.keys()):
|
||||
# check if it is from the coqui Trainer if so convert it
|
||||
if key.startswith("xtts."):
|
||||
new_key = key.replace("xtts.", "")
|
||||
checkpoint[new_key] = checkpoint[key]
|
||||
del checkpoint[key]
|
||||
key = new_key
|
||||
|
||||
# remove unused keys
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
|
||||
return checkpoint
|
||||
|
||||
def load_checkpoint(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir=None,
|
||||
checkpoint_path=None,
|
||||
checkpoint_path=None,
|
||||
vocab_path=None,
|
||||
eval=True,
|
||||
strict=True,
|
||||
|
@ -822,13 +843,7 @@ class Xtts(BaseTTS):
|
|||
|
||||
self.init_models()
|
||||
|
||||
checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
|
||||
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
|
||||
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
|
||||
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
|
||||
for key in list(checkpoint.keys()):
|
||||
if key.split(".")[0] in ignore_keys:
|
||||
del checkpoint[key]
|
||||
checkpoint = self.get_compatible_checkpoint_state_dict(model_path)
|
||||
|
||||
# deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
|
||||
try:
|
||||
|
@ -847,4 +862,4 @@ class Xtts(BaseTTS):
|
|||
self.gpt.eval()
|
||||
|
||||
def train_step(self):
|
||||
raise NotImplementedError("XTTS Training is not implemented")
|
||||
raise NotImplementedError("XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training")
|
||||
|
|
|
@ -16,8 +16,8 @@ a few tricks to make it faster and support streaming inference.
|
|||
Current implementation only supports inference.
|
||||
|
||||
### Languages
|
||||
As of now, XTTS-v1 supports 13 languages: English, Spanish, French, German, Italian, Portuguese,
|
||||
Polish, Turkish, Russian, Dutch, Czech, Arabic, and Chinese (Simplified).
|
||||
As of now, XTTS-v1.1 supports 14 languages: English, Spanish, French, German, Italian, Portuguese,
|
||||
Polish, Turkish, Russian, Dutch, Czech, Arabic, Chinese (Simplified) and Japanese.
|
||||
|
||||
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
|
||||
|
||||
|
@ -33,7 +33,7 @@ You can also mail us at info@coqui.ai.
|
|||
|
||||
```python
|
||||
from TTS.api import TTS
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
|
||||
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1", gpu=True)
|
||||
|
||||
# generate speech by cloning a voice using default settings
|
||||
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
|
@ -45,7 +45,7 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
|
|||
#### 🐸TTS Command line
|
||||
|
||||
```console
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
|
||||
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1.1 \
|
||||
--text "Bugün okula gitmek istemiyorum." \
|
||||
--speaker_wav /path/to/target/speaker.wav \
|
||||
--language_idx tr \
|
||||
|
@ -134,6 +134,56 @@ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
|
|||
```
|
||||
|
||||
|
||||
### Training
|
||||
|
||||
A recipe for `XTTS_v1.1` GPT encoder training using `LJSpeech` dataset is available at https://github.com/coqui-ai/TTS/tree/dev/recipes/ljspeech/xtts_v1/train_gpt_xtts.py
|
||||
|
||||
You need to change the fields of the `BaseDatasetConfig` to match your dataset and then update `GPTArgs` and `GPTTrainerConfig` fields as you need. By default, it will use the same parameters that XTTS v1.1 model was trained with. To speed up the model convergence, as default, it will also download the XTTS v1.1 checkpoint and load it.
|
||||
|
||||
After training you can do inference following the code bellow.
|
||||
|
||||
```python
|
||||
import os
|
||||
import torch
|
||||
import torchaudio
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
# Add here the xtts_config path
|
||||
CONFIG_PATH = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT-October-23-2023_10+36AM-653f2e75/config.json"
|
||||
# Add here the vocab file that you have used to train the model
|
||||
TOKENIZER_PATH = "recipes/ljspeech/xtts_v1/run/training/XTTS_v1.1_original_model_files/vocab.json"
|
||||
# Add here the checkpoint that you want to do inference with
|
||||
XTTS_CHECKPOINT = "recipes/ljspeech/xtts_v1/run/training/GPT_XTTS_LJSpeech_FT/best_model.pth"
|
||||
# Add here the speaker reference
|
||||
SPEAKER_REFERENCE = "LjSpeech_reference.wav"
|
||||
|
||||
# output wav path
|
||||
OUTPUT_WAV_PATH = "xtts-ft.wav"
|
||||
|
||||
print("Loading model...")
|
||||
config = XttsConfig()
|
||||
config.load_json(CONFIG_PATH)
|
||||
model = Xtts.init_from_config(config)
|
||||
model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=TOKENIZER_PATH, use_deepspeed=False)
|
||||
model.cuda()
|
||||
|
||||
print("Computing speaker latents...")
|
||||
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_REFERENCE)
|
||||
|
||||
print("Inference...")
|
||||
out = model.inference(
|
||||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding,
|
||||
diffusion_conditioning,
|
||||
temperature=0.7, # Add custom parameters here
|
||||
)
|
||||
torchaudio.save(OUTPUT_WAV_PATH, torch.tensor(out["wav"]).unsqueeze(0), 24000)
|
||||
```
|
||||
|
||||
|
||||
## Important resources & papers
|
||||
- VallE: https://arxiv.org/abs/2301.02111
|
||||
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
import os
|
||||
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
|
||||
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
||||
OUT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run", "training")
|
||||
|
||||
# Training Parameters
|
||||
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||
START_WITH_EVAL = True # if True it will star with evaluation
|
||||
BATCH_SIZE = 3 # set here the batch size
|
||||
GRAD_ACUMM_STEPS = 84 # set here the grad accumulation steps
|
||||
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||
|
||||
# Define here the dataset that you want to use for the fine-tuning on.
|
||||
config_dataset = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
dataset_name="ljspeech",
|
||||
path="/raid/datasets/LJSpeech-1.1_24khz/",
|
||||
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
|
||||
language="en",
|
||||
)
|
||||
|
||||
# Add here the configs of the datasets
|
||||
DATASETS_CONFIG_LIST = [config_dataset]
|
||||
|
||||
# Define the path where XTTS v1.1.1 files will be downloaded
|
||||
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v1.1_original_model_files/")
|
||||
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
||||
|
||||
|
||||
# DVAE files
|
||||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/dvae.pth"
|
||||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/mel_stats.pth"
|
||||
|
||||
# Set the path to the downloaded files
|
||||
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, DVAE_CHECKPOINT_LINK.split("/")[-1])
|
||||
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, MEL_NORM_LINK.split("/")[-1])
|
||||
|
||||
# download DVAE files if needed
|
||||
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
||||
print(" > Downloading DVAE files!")
|
||||
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
||||
|
||||
|
||||
# Download XTTS v1.1 checkpoint if needed
|
||||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json"
|
||||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth"
|
||||
|
||||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
||||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, TOKENIZER_FILE_LINK.split("/")[-1]) # vocab.json file
|
||||
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, XTTS_CHECKPOINT_LINK.split("/")[-1]) # model.pth file
|
||||
|
||||
# download XTTS v1.1 files if needed
|
||||
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
||||
print(" > Downloading XTTS v1.1 files!")
|
||||
ModelManager._download_model_files([TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
||||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = (
|
||||
"./tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
)
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
def main():
|
||||
# init args and config
|
||||
model_args = GPTArgs(
|
||||
max_conditioning_length=132300, # 6 secs
|
||||
min_conditioning_length=66150, # 3 secs
|
||||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_text_length=200,
|
||||
mel_norm_file=MEL_NORM_FILE,
|
||||
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||
# tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
|
||||
# xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
|
||||
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||
tokenizer_file=TOKENIZER_FILE,
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint
|
||||
)
|
||||
# define audio config
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
||||
)
|
||||
# training parameters config
|
||||
config = GPTTrainerConfig(
|
||||
output_path=OUT_PATH,
|
||||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
project_name=PROJECT_NAME,
|
||||
run_description="""
|
||||
GPT XTTS training
|
||||
""",
|
||||
dashboard_logger=DASHBOARD_LOGGER,
|
||||
logger_uri=LOGGER_URI,
|
||||
audio=audio_config,
|
||||
batch_size=BATCH_SIZE,
|
||||
batch_group_size=48,
|
||||
eval_batch_size=BATCH_SIZE,
|
||||
num_loader_workers=8,
|
||||
eval_split_max_size=256,
|
||||
print_step=50,
|
||||
plot_step=100,
|
||||
log_model_step=1000,
|
||||
save_step=10000,
|
||||
save_n_checkpoints=1,
|
||||
save_checkpoints=True,
|
||||
# target_loss="loss",
|
||||
print_eval=False,
|
||||
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||
optimizer="AdamW",
|
||||
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
||||
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||
lr=5e-06, # learning rate
|
||||
lr_scheduler="MultiStepLR",
|
||||
# it was adjusted accordly for the new step scheme
|
||||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||
test_sentences=[
|
||||
{
|
||||
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"speaker_wav": SPEAKER_REFERENCE,
|
||||
"language": LANGUAGE,
|
||||
},
|
||||
{
|
||||
"text": "This cake is great. It's so delicious and moist.",
|
||||
"speaker_wav": SPEAKER_REFERENCE,
|
||||
"language": LANGUAGE,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# init the model from config
|
||||
model = GPTTrainer.init_from_config(config)
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
DATASETS_CONFIG_LIST,
|
||||
eval_split=True,
|
||||
eval_split_max_size=config.eval_split_max_size,
|
||||
eval_split_size=config.eval_split_size,
|
||||
)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
TrainerArgs(
|
||||
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
|
||||
skip_train_epoch=False,
|
||||
start_with_eval=START_WITH_EVAL,
|
||||
grad_accum_steps=GRAD_ACUMM_STEPS,
|
||||
),
|
||||
config,
|
||||
output_path=OUT_PATH,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
File diff suppressed because it is too large
Load Diff
|
@ -28,7 +28,7 @@ tokenizer, config = TTSTokenizer.init_from_config(config)
|
|||
|
||||
def test_acoustic_model():
|
||||
dummy_tokens = torch.rand((1, 41)).long().to(device)
|
||||
dummy_text_lens = torch.tensor([41]).to(device)
|
||||
dummy_text_lens = torch.tensor([41]).long().to(device)
|
||||
dummy_spec = torch.rand((1, 100, 207)).to(device)
|
||||
dummy_spec_lens = torch.tensor([207]).to(device)
|
||||
dummy_pitch = torch.rand((1, 1, 207)).long().to(device)
|
||||
|
@ -38,6 +38,7 @@ def test_acoustic_model():
|
|||
args.num_mels = 100
|
||||
|
||||
acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device)
|
||||
acoustic_model = acoustic_model.train()
|
||||
|
||||
output = acoustic_model(
|
||||
tokens=dummy_tokens,
|
||||
|
@ -51,16 +52,12 @@ def test_acoustic_model():
|
|||
speaker_idx=None,
|
||||
)
|
||||
assert list(output["model_outputs"].shape) == [1, 207, 100]
|
||||
output["model_outputs"].sum().backward()
|
||||
# output["model_outputs"].sum().backward()
|
||||
|
||||
|
||||
def test_hifi_decoder():
|
||||
dummy_input = torch.rand((1, 207, 100)).to(device)
|
||||
dummy_text_lens = torch.tensor([41]).to(device)
|
||||
dummy_spec = torch.rand((1, 100, 207)).to(device)
|
||||
dummy_spec_lens = torch.tensor([207]).to(device)
|
||||
dummy_pitch = torch.rand((1, 1, 207)).long().to(device)
|
||||
dummy_energy = torch.rand((1, 1, 207)).long().to(device)
|
||||
|
||||
waveform_decoder = HifiganGenerator(
|
||||
100,
|
||||
|
@ -77,6 +74,7 @@ def test_hifi_decoder():
|
|||
conv_post_weight_norm=False,
|
||||
conv_post_bias=False,
|
||||
).to(device)
|
||||
waveform_decoder = waveform_decoder.train()
|
||||
|
||||
vocoder_input_slices, slice_ids = rand_segments( # pylint: disable=unused-variable
|
||||
x=dummy_input.transpose(1, 2),
|
||||
|
@ -88,4 +86,4 @@ def test_hifi_decoder():
|
|||
|
||||
outputs = waveform_decoder(x=vocoder_input_slices.detach())
|
||||
assert list(outputs.shape) == [1, 1, 8192]
|
||||
outputs.sum().backward()
|
||||
# outputs.sum().backward()
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from tests import get_tests_output_path
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.dvae import DiscreteVAE
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
|
||||
config_dataset = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
dataset_name="ljspeech",
|
||||
path="tests/data/ljspeech/",
|
||||
meta_file_train="metadata.csv",
|
||||
meta_file_val="metadata.csv",
|
||||
language="en",
|
||||
)
|
||||
|
||||
DATASETS_CONFIG_LIST = [config_dataset]
|
||||
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTS_LJSpeech_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
DASHBOARD_LOGGER = "tensorboard"
|
||||
LOGGER_URI = None
|
||||
|
||||
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
||||
OUT_PATH = os.path.join(get_tests_output_path(), "train_outputs", "xtts_tests")
|
||||
os.makedirs(OUT_PATH, exist_ok=True)
|
||||
|
||||
# Create DVAE checkpoint and mel_norms on test time
|
||||
# DVAE parameters: For the training we need the dvae to extract the dvae tokens, given that you must provide the paths for this model
|
||||
DVAE_CHECKPOINT = os.path.join(OUT_PATH, "dvae.pth") # DVAE checkpoint
|
||||
MEL_NORM_FILE = os.path.join(
|
||||
OUT_PATH, "mel_stats.pth"
|
||||
) # Mel spectrogram norms, required for dvae mel spectrogram extraction
|
||||
dvae = DiscreteVAE(
|
||||
channels=80,
|
||||
normalization=None,
|
||||
positional_dims=1,
|
||||
num_tokens=8192,
|
||||
codebook_dim=512,
|
||||
hidden_dim=512,
|
||||
num_resnet_blocks=3,
|
||||
kernel_size=3,
|
||||
num_layers=2,
|
||||
use_transposed_convs=False,
|
||||
)
|
||||
torch.save(dvae.state_dict(), DVAE_CHECKPOINT)
|
||||
mel_stats = torch.ones(80)
|
||||
torch.save(mel_stats, MEL_NORM_FILE)
|
||||
|
||||
|
||||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
||||
TOKENIZER_FILE = "tests/inputs/xtts_vocab.json" # vocab.json file
|
||||
XTTS_CHECKPOINT = None # "/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth" # model.pth file
|
||||
|
||||
|
||||
# Training sentences generations
|
||||
SPEAKER_REFERENCE = "tests/data/ljspeech/wavs/LJ001-0002.wav" # speaker reference to be used in training test sentences
|
||||
LANGUAGE = config_dataset.language
|
||||
|
||||
|
||||
# Training Parameters
|
||||
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
||||
START_WITH_EVAL = False # if True it will star with evaluation
|
||||
BATCH_SIZE = 2 # set here the batch size
|
||||
GRAD_ACUMM_STEPS = 1 # set here the grad accumulation steps
|
||||
# Note: we recommend that BATCH_SIZE * GRAD_ACUMM_STEPS need to be at least 252 for more efficient training. You can increase/decrease BATCH_SIZE but then set GRAD_ACUMM_STEPS accordingly.
|
||||
|
||||
|
||||
# init args and config
|
||||
model_args = GPTArgs(
|
||||
max_conditioning_length=132300, # 6 secs
|
||||
min_conditioning_length=66150, # 3 secs
|
||||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_text_length=200,
|
||||
mel_norm_file=MEL_NORM_FILE,
|
||||
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
||||
tokenizer_file=TOKENIZER_FILE,
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
)
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000
|
||||
)
|
||||
config = GPTTrainerConfig(
|
||||
epochs=1,
|
||||
output_path=OUT_PATH,
|
||||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
project_name=PROJECT_NAME,
|
||||
run_description="""
|
||||
GPT XTTS training
|
||||
""",
|
||||
dashboard_logger=DASHBOARD_LOGGER,
|
||||
logger_uri=LOGGER_URI,
|
||||
audio=audio_config,
|
||||
batch_size=BATCH_SIZE,
|
||||
batch_group_size=48,
|
||||
eval_batch_size=BATCH_SIZE,
|
||||
num_loader_workers=8,
|
||||
eval_split_max_size=256,
|
||||
print_step=50,
|
||||
plot_step=100,
|
||||
log_model_step=1000,
|
||||
save_step=10000,
|
||||
save_n_checkpoints=1,
|
||||
save_checkpoints=True,
|
||||
# target_loss="loss",
|
||||
print_eval=False,
|
||||
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||
optimizer="AdamW",
|
||||
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
||||
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||
lr=5e-06, # learning rate
|
||||
lr_scheduler="MultiStepLR",
|
||||
# it was adjusted accordly for the new step scheme
|
||||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||
test_sentences=[
|
||||
{
|
||||
"text": "This cake is great. It's so delicious and moist.",
|
||||
"speaker_wav": SPEAKER_REFERENCE,
|
||||
"language": LANGUAGE,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# init the model from config
|
||||
model = GPTTrainer.init_from_config(config)
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(
|
||||
DATASETS_CONFIG_LIST,
|
||||
eval_split=True,
|
||||
eval_split_max_size=config.eval_split_max_size,
|
||||
eval_split_size=config.eval_split_size,
|
||||
)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
TrainerArgs(
|
||||
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
|
||||
skip_train_epoch=False,
|
||||
start_with_eval=True,
|
||||
grad_accum_steps=GRAD_ACUMM_STEPS,
|
||||
),
|
||||
config,
|
||||
output_path=OUT_PATH,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
# remove output path
|
||||
shutil.rmtree(OUT_PATH)
|
|
@ -99,6 +99,7 @@ def test_xtts_streaming():
|
|||
"""Testing the new inference_stream method"""
|
||||
from TTS.tts.configs.xtts_config import XttsConfig
|
||||
from TTS.tts.models.xtts import Xtts
|
||||
|
||||
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
|
||||
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
|
||||
config = XttsConfig()
|
||||
|
@ -115,7 +116,7 @@ def test_xtts_streaming():
|
|||
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
|
||||
"en",
|
||||
gpt_cond_latent,
|
||||
speaker_embedding
|
||||
speaker_embedding,
|
||||
)
|
||||
wav_chuncks = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
|
|
Loading…
Reference in New Issue