Bug Fix on inference using XTTS trainer checkpoint

Edresson Casanova 2023-10-18 09:42:00 -03:00
parent c4ceaabe2c
commit 9e3598c3b7
9 changed files with 419 additions and 371 deletions
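
For reference, the fix below makes checkpoints saved by the GPT trainer (whose state_dict keys are prefixed with "xtts.") loadable for inference. A minimal usage sketch with placeholder paths, assuming the usual XttsConfig.load_json / Xtts.init_from_config entry points; the load_checkpoint and synthesize signatures follow the ones shown in this diff:

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# placeholder paths: point these at your fine-tuned run
config = XttsConfig()
config.load_json("/path/to/config.json")
model = Xtts.init_from_config(config)
# load_checkpoint now also accepts trainer checkpoints ("xtts."-prefixed keys)
model.load_checkpoint(config, checkpoint_path="/path/to/checkpoint.pth", vocab_path="/path/to/vocab.json", eval=True)
model.cuda()

out = model.synthesize(
    "It took me quite a long time to develop a voice.",
    config,
    "/path/to/reference.wav",  # speaker_wav
    "en",  # language
    gpt_cond_len=3,
)
wav = out["wav"]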

View File

@ -197,6 +197,7 @@ class GPT(nn.Module):
if use_deepspeed:
import deepspeed
self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPUs
@ -451,7 +452,7 @@ class GPT(nn.Module):
if cond_idxs is not None:
for idx, r in enumerate(cond_idxs.squeeze()):
l = r[1] - r[0]
attn_mask_cond[idx, l : ] = 0.0
attn_mask_cond[idx, l:] = 0.0
for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
@ -498,7 +499,7 @@ class GPT(nn.Module):
for idx, l in enumerate(code_lengths):
mel_targets[idx, l + 1 :] = -1
# check if stoptoken is in every row of mel_targets
assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
0

View File

@ -1,13 +1,12 @@
import torch
import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio
from TTS.utils.io import load_fsspec
LRELU_SLOPE = 0.1
@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module):
self.cond_in_each_up_layer = cond_in_each_up_layer
# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module):
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)
def remove_weight_norm(self):
@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module):
assert not self.training
self.remove_weight_norm()
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@ -425,10 +417,8 @@ class PreEmphasis(nn.Module):
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
"""This is copied from 🐸TTS to remove it from the dependencies."""
# pylint: disable=W0102
def __init__(
@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module):
return criterion, state["step"]
return criterion
class HifiDecoder(torch.nn.Module):
def __init__(
self,
@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module):
"""
return self.forward(c, g=g)
def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]

View File

@ -1,26 +1,27 @@
# Adapted from: https://github.com/LowinLi/transformers-stream-generator
import copy
import inspect
import random
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from transformers import (
BeamSearchScorer,
ConstrainedBeamSearchScorer,
DisjunctiveConstraint,
GenerationConfig,
GenerationMixin,
LogitsProcessorList,
StoppingCriteriaList,
DisjunctiveConstraint,
BeamSearchScorer,
PhrasalConstraint,
ConstrainedBeamSearchScorer,
PreTrainedModel,
StoppingCriteriaList,
)
import numpy as np
import random
import warnings
import inspect
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
import torch
from typing import Callable, List, Optional, Union
from torch import nn
import torch.distributed as dist
import copy
def setup_seed(seed):
@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin):
generation_config: Optional[StreamGenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = False,
seed=0,
**kwargs,
@ -125,7 +124,7 @@ class NewGenerationMixin(GenerationMixin):
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
#setup_seed(seed)
# setup_seed(seed)
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin):
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = StreamGenerationConfig.from_model_config(
self.config
)
new_generation_config = StreamGenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin):
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(
**kwargs
) # All unused kwargs must be model kwargs
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if (
generation_config.pad_token_id is None
and generation_config.eos_token_id is not None
):
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin):
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin):
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(
inspect.signature(self.forward).parameters.keys()
)
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if (
model_kwargs.get("attention_mask", None) is None
and requires_attention_mask
and accepts_attention_mask
):
model_kwargs[
"attention_mask"
] = self._prepare_attention_mask_for_generation(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin):
if not self.config.is_encoder_decoder:
if (
generation_config.pad_token_id is not None
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
> 0
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin):
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin):
UserWarning,
)
elif has_default_max_length and generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
elif (
not has_default_max_length and generation_config.max_new_tokens is not None
):
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin):
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
if (
generation_config.min_length is not None
and generation_config.min_length > generation_config.max_length
):
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = (
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
)
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin):
# 7. determine generation mode
is_constraint_gen_mode = (
generation_config.constraints is not None
or generation_config.force_words_ids is not None
generation_config.constraints is not None or generation_config.force_words_ids is not None
)
is_contrastive_search_gen_mode = (
@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin):
)
if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError(
"`num_beam_groups` has to be smaller or equal to `num_beams`"
)
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and generation_config.do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin):
)
elif is_beam_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin):
logits_warper = self._get_logits_warper(generation_config)
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size * generation_config.num_return_sequences,
@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin):
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams
* generation_config.num_return_sequences,
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin):
elif is_group_beam_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if generation_config.num_beams % generation_config.num_beam_groups != 0:
raise ValueError(
"`num_beams` should be divisible by `num_beam_groups` for group beam search."
)
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
has_default_typical_p = (
kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
)
has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
if not has_default_typical_p:
raise ValueError(
"Decoder argument `typical_p` is not supported with beam groups."
)
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin):
elif is_constraint_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if generation_config.num_beams <= 1:
raise ValueError(
"`num_beams` needs to be greater than 1 for constrained generation."
)
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
if generation_config.do_sample:
raise ValueError(
"`do_sample` needs to be false for constrained generation."
)
raise ValueError("`do_sample` needs to be false for constrained generation.")
if (
generation_config.num_beam_groups is not None
and generation_config.num_beam_groups > 1
):
raise ValueError(
"`num_beam_groups` not supported yet for constrained generation."
)
if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
final_constraints = []
if generation_config.constraints is not None:
@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin):
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(
not isinstance(token_ids, list) for token_ids in word_ids
):
if any(not isinstance(token_ids, list) for token_ids in word_ids):
typeerror()
if any(
any(
(not isinstance(token_id, int) or token_id < 0)
for token_id in token_ids
)
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
):
typeerror()
@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin):
else:
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(
(not isinstance(token_id, int) or token_id < 0)
for token_id in word_ids
):
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
typeerror()
constraint = PhrasalConstraint(word_ids)
@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin):
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
```"""
# init values
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(
stopping_criteria, max_length
)
logits_warper = (
logits_warper if logits_warper is not None else LogitsProcessorList()
)
pad_token_id = (
pad_token_id
if pad_token_id is not None
else self.generation_config.pad_token_id
)
eos_token_id = (
eos_token_id
if eos_token_id is not None
else self.generation_config.eos_token_id
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = (
output_scores
if output_scores is not None
else self.generation_config.output_scores
)
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions
if output_attentions is not None
else self.generation_config.output_attentions
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.generation_config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate
@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin):
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin):
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(
0.0 if this_peer_finished else 1.0
).to(input_ids.device)
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin):
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (outputs.attentions,)
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
(outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
)
# sample
@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin):
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin):
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul(
(sum(next_tokens != i for i in eos_token_id)).long()
)
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@ -1007,22 +885,17 @@ def init_stream_support():
if __name__ == "__main__":
from transformers import PreTrainedModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
PreTrainedModel.generate = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
model = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-560m", torch_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = model.to("cuda:0")
model = model.eval()
prompt_text = "hello? \n"
input_ids = tokenizer(
prompt_text, return_tensors="pt", add_special_tokens=False
).input_ids
input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = input_ids.to("cuda:0")
with torch.no_grad():

View File

@ -1,16 +1,18 @@
import os
import random
import sys
import numpy as np
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
torch.set_num_threads(1)
def key_samples_by_col(samples, col):
"""Returns a dictionary of samples keyed by language."""
samples_by_col = {}
@ -23,11 +25,11 @@ def key_samples_by_col(samples, col):
return samples_by_col
def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
rel_clip = load_audio(gt_path, sample_rate)
# for eval, use a middle-sized sample when possible, to make it more reproducible
if is_eval:
sample_length = int((min_sample_length + max_sample_length)/2)
sample_length = int((min_sample_length + max_sample_length) / 2)
else:
sample_length = random.randint(min_sample_length, max_sample_length)
gap = rel_clip.shape[-1] - sample_length
@ -41,7 +43,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
else:
rand_start = random.randint(0, gap)
rand_end = rand_start+sample_length
rand_end = rand_start + sample_length
rel_clip = rel_clip[:, rand_start:rand_end]
rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
cond_idxs = [rand_start, rand_end]
@ -50,7 +52,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == '.mp3':
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio_sox_load(audiopath)
else:
@ -72,6 +74,7 @@ def load_audio(audiopath, sampling_rate):
audio.clip_(-1, 1)
return audio
class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
@ -103,16 +106,18 @@ class XTTSDataset(torch.utils.data.Dataset):
print(" > Filtering invalid eval samples!!")
new_samples = []
for sample in self.samples:
try:
tseq, _, wav, _, _, _ = self.load_item(sample)
except:
pass
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
continue
new_samples.append(sample)
try:
tseq, _, wav, _, _, _ = self.load_item(sample)
except:
pass
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
if (
wav is None
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
):
continue
new_samples.append(sample)
self.samples = new_samples
print(" > Total eval samples after filtering:", len(self.samples))
@ -125,9 +130,9 @@ class XTTSDataset(torch.utils.data.Dataset):
return tokens
def load_item(self, sample):
text = str(sample['text'])
text = str(sample["text"])
tseq = self.get_text(text, sample["language"])
audiopath = sample['audio_file']
audiopath = sample["audio_file"]
wav = load_audio(audiopath, self.sample_rate)
if text is None or len(text.strip()) == 0:
raise ValueError
@ -136,7 +141,9 @@ class XTTSDataset(torch.utils.data.Dataset):
raise ValueError
# get a slice from GT to condition the model
cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)
cond, cond_len, cond_idxs = get_prompt_slice(
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
)
return tseq, audiopath, wav, cond, cond_len, cond_idxs
@ -151,7 +158,7 @@ class XTTSDataset(torch.utils.data.Dataset):
index = random.randint(0, len(self.samples[lang]) - 1)
sample = self.samples[lang][index]
# a unique id for each sample to deal with failures
sample_id = lang+"_"+str(index)
sample_id = lang + "_" + str(index)
# ignore samples that we already know that is not valid ones
if sample_id in self.failed_samples:
@ -170,26 +177,30 @@ class XTTSDataset(torch.utils.data.Dataset):
return self[1]
# check the audio and text size limits and, if they are exceeded, add the sample to failed_samples
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
if (
wav is None
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid token and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None:
print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
print(
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
)
self.failed_samples.add(sample_id)
return self[1]
res = {
# 'real_text': text,
'text': tseq,
'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
'wav': wav,
'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
'filenames': audiopath,
'conditioning': cond.unsqueeze(1),
'cond_lens': torch.tensor(cond_len, dtype=torch.long),
'cond_idxs': torch.tensor(cond_idxs),
"text": tseq,
"text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long),
"wav": wav,
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
"filenames": audiopath,
"conditioning": cond.unsqueeze(1),
"cond_lens": torch.tensor(cond_len, dtype=torch.long),
"cond_idxs": torch.tensor(cond_idxs),
}
return res
@ -223,8 +234,8 @@ class XTTSDataset(torch.utils.data.Dataset):
for i in range(B):
text = batch["text"][i]
text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
wav = batch['wav'][i]
wav_padded[i, :, :batch["wav_lengths"][i]] = torch.FloatTensor(wav)
wav = batch["wav"][i]
wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
batch["wav"] = wav_padded
batch["padded_text"] = text_padded

View File

@ -1,37 +1,30 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torchaudio
import torch.nn as nn
import torchaudio
from coqpit import Coqpit
from torch.nn import functional as F
from torch.utils.data import DataLoader
import sys
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.base_tts import BaseTTS
from coqpit import Coqpit
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.datasets.dataset import TTSDataset
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
@dataclass
class GPTTrainerConfig(XttsConfig):
@ -42,6 +35,7 @@ class GPTTrainerConfig(XttsConfig):
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
test_sentences: List[dict] = field(default_factory=lambda: [])
@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@ -55,27 +49,28 @@ class GPTArgs(XttsArgs):
gpt_loss_mel_ce_weight: float = 1.0
gpt_num_audio_tokens: int = 8194
debug_loading_failures: bool = False
max_wav_length: int = 255995 # ~11.6 seconds
max_wav_length: int = 255995 # ~11.6 seconds
max_text_length: int = 200
tokenizer_file: str = ""
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
dvae_checkpoint: str = ""
xtts_checkpoint: str = ""
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # override vocoder key on the config to avoid json write issues
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # override vocoder key on the config to avoid json write issues
def callback_clearml_load_save(operation_type, model_info):
# returning None skips the file upload/log; returning model_info continues with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ('load', 'save')
assert operation_type in ("load", "save")
# print(operation_type, model_info.__dict__)
if "similarities.pth" in model_info.__dict__['local_model_path']:
if "similarities.pth" in model_info.__dict__["local_model_path"]:
return None
return model_info
class GPTTrainer(BaseTTS):
def __init__(self, config: Coqpit):
"""
@ -89,18 +84,17 @@ class GPTTrainer(BaseTTS):
self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
# init gpt encoder and hifigan decoder
self.xtts.init_models()
# set mel stats
if self.args.mel_norm_file:
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
if self.args.xtts_checkpoint:
self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)
# set mel stats
if self.args.mel_norm_file:
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
# load GPT if available
if self.args.gpt_checkpoint:
gpt_checkpoint = torch.load(
self.args.gpt_checkpoint, map_location=torch.device("cpu")
)
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!")
@ -113,10 +107,15 @@ class GPTTrainer(BaseTTS):
del gpt_checkpoint[key]
else:
del gpt_checkpoint[key]
# edit the checkpoint if the number of tokens has changed, to ensure the best possible transfer learning
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
if (
"text_embedding.weight" in gpt_checkpoint
and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
):
num_new_tokens = (
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
)
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
# add new tokens to a linear layer (text_head)
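The hunk ends before the code that actually resizes the layers. As a rough, hypothetical illustration of one way to grow an embedding checkpoint by num_new_tokens rows (a sketch only, not the commit's elided code):

# hypothetical sketch: append num_new_tokens randomly initialized rows to the
# stored embedding weight so its shape matches the new vocabulary size
old_w = gpt_checkpoint["text_embedding.weight"]  # (old_vocab, dim)
new_rows = old_w.new_empty(num_new_tokens, old_w.shape[1]).normal_(std=0.02)
gpt_checkpoint["text_embedding.weight"] = torch.cat([old_w, new_rows], dim=0)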
@ -156,7 +155,7 @@ class GPTTrainer(BaseTTS):
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file
mel_norm_file=self.args.mel_norm_file,
)
# Load DVAE
@ -175,17 +174,18 @@ class GPTTrainer(BaseTTS):
self.dvae.eval()
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(
self.args.dvae_checkpoint, map_location=torch.device("cpu")
)
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
else:
raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!")
raise RuntimeError(
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
)
# Mel spectrogram extractor for DVAE
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
)
@property
def device(self):
@ -203,7 +203,9 @@ class GPTTrainer(BaseTTS):
cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
cond_idxs: cond start and end indices, (b, 2)
"""
losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
losses = self.xtts.gpt(
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
)
return losses
@torch.no_grad()
@ -215,7 +217,9 @@ class GPTTrainer(BaseTTS):
test_audios = {}
print(" | > Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
wav = self.xtts.synthesize(
s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
)["wav"]
test_audios["{}-audio".format(idx)] = wav
# delete inference layers
@ -231,7 +235,7 @@ class GPTTrainer(BaseTTS):
def format_batch(self, batch: Dict) -> Dict:
return batch
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
def format_batch_on_device(self, batch):
"""Compute spectrograms on the device."""
batch["text_lengths"] = batch["text_lengths"]
@ -241,10 +245,10 @@ class GPTTrainer(BaseTTS):
# compute conditioning mel specs
# transform waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating over the tensor
B, num_cond_samples, C, T = batch["conditioning"].size()
conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T)
conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
# transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel])
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
T_mel = paired_conditioning_mel.size(2)
paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
# get the conditioning embeddings
@ -300,6 +304,7 @@ class GPTTrainer(BaseTTS):
# ignore similarities.pth on clearml save/upload
if self.config.dashboard_logger.lower() == "clearml":
from clearml.binding.frameworks import WeightsFileHandler
WeightsFileHandler.add_pre_callback(callback_clearml_load_save)
@torch.no_grad()
@ -367,16 +372,23 @@ class GPTTrainer(BaseTTS):
return loader
def get_optimizer(self) -> List:
"""Initiate and return the optimizer based on the config parameters.
"""
"""Initiate and return the optimizer based on the config parameters."""
# ToDo: deal with multi GPU training
if self.config.optimizer_wd_only_on_weights:
# parameters to only GPT model
# parameters to only GPT model
net = self.xtts.gpt
# normalizations
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
norm_modules = (
nn.BatchNorm2d,
nn.InstanceNorm2d,
nn.BatchNorm1d,
nn.InstanceNorm1d,
nn.BatchNorm3d,
nn.InstanceNorm3d,
nn.GroupNorm,
nn.LayerNorm,
)
# nn.Embedding
emb_modules = (nn.Embedding, nn.EmbeddingBag)
@ -390,7 +402,7 @@ class GPTTrainer(BaseTTS):
v.is_norm = isinstance(m, norm_modules)
v.is_emb = isinstance(m, emb_modules)
fpn = '%s.%s' % (mn, k) if mn else k # full param name
fpn = "%s.%s" % (mn, k) if mn else k # full param name
all_param_names.add(fpn)
param_map[fpn] = v
if v.is_bias or v.is_norm or v.is_emb:
@ -402,26 +414,26 @@ class GPTTrainer(BaseTTS):
params_weights = [param_map[k] for k in params_names_weights]
groups = [
{ 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]},
{ 'params': params_notweights, 'weight_decay': 0}
{"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
{"params": params_notweights, "weight_decay": 0},
]
# torch.optim.AdamW
opt = get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
parameters=groups,
)
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
parameters=groups,
)
opt._group_names = [params_names_weights, params_names_notweights]
return opt
return get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.xtts.gpt.parameters(),
)
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.xtts.gpt.parameters(),
)
def get_scheduler(self, optimizer) -> List:
"""Set the scheduler for the optimizer.
@ -461,4 +473,3 @@ class GPTTrainer(BaseTTS):
Defaults to None.
"""
return GPTTrainer(config)

View File

@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
init_stream_support()
def load_audio(audiopath, sr=22050):
"""
Load an audio file from disk and resample it to the specified sampling rate.
@ -332,7 +333,6 @@ class Xtts(BaseTTS):
stop_audio_token=self.args.gpt_stop_audio_token,
)
if self.args.use_hifigan:
self.hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
@ -414,21 +414,20 @@ class Xtts(BaseTTS):
return diffusion_latent
@torch.inference_mode()
def get_speaker_embedding(
self,
audio_path
):
def get_speaker_embedding(self, audio_path):
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
audio.to(self.device), l2_norm=True
).unsqueeze(-1).to(self.device)
speaker_embedding = (
self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
.unsqueeze(-1)
.to(self.device)
)
return speaker_embedding
def get_conditioning_latents(
self,
audio_path,
gpt_cond_len=3,
):
):
speaker_embedding = None
diffusion_cond_latents = None
if self.args.use_hifigan:
@ -563,11 +562,9 @@ class Xtts(BaseTTS):
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz.
"""
(
gpt_cond_latent,
diffusion_conditioning,
speaker_embedding
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
)
return self.inference(
text,
language,
@ -588,7 +585,7 @@ class Xtts(BaseTTS):
decoder=decoder,
**hf_generate_kwargs,
)
@torch.inference_mode()
def inference(
self,
@ -666,7 +663,7 @@ class Xtts(BaseTTS):
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
if decoder == "hifigan":
assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
@ -721,7 +718,9 @@ class Xtts(BaseTTS):
decoder="hifigan",
**hf_generate_kwargs,
):
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
assert hasattr(
self, "hifigan_decoder"
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@ -793,7 +792,7 @@ class Xtts(BaseTTS):
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
checkpoint_path=None,
vocab_path=None,
eval=True,
strict=True,
@ -827,6 +826,15 @@ class Xtts(BaseTTS):
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer; if so, convert it
if key.startswith("xtts."):
coqui_trainer_checkpoint = True
new_key = key.replace("xtts.", "")
checkpoint[new_key] = checkpoint[key]
del checkpoint[key]
key = new_key
# remove unused keys
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
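For clarity, the key conversion added above is roughly equivalent to the following one-liner (sketch only; `checkpoint` is the loaded state dict):

# strip the trainer's "xtts." prefix so keys match the inference model's state_dict
checkpoint = {k[len("xtts."):] if k.startswith("xtts.") else k: v for k, v in checkpoint.items()}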

View File

@ -0,0 +1,145 @@
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
# Define here the dataset used
config_ljspeech = BaseDatasetConfig(
formatter="ljspeech",
dataset_name="ljspeech",
path="/raid/datasets/LJSpeech-1.1_24khz/",
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
language="en",
)
DATASETS_CONFIG_LIST = [config_ljspeech]
def freeze_layers(trainer):
pass
def main():
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=False,
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
# tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
# xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
xtts_checkpoint="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth", # checkpoint path of the model that you want to fine-tune
tokenizer_file="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json",
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR
)
config = GPTTrainerConfig(
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=1000,
save_step=10000,
save_n_checkpoints=1,
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
optimizer="AdamW",
optimizer_wd_only_on_weights=True, # for multi-gpu training turn it off
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
lr_scheduler="MultiStepLR",
# it was adjusted accordingly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
"language": "en",
},
{
"text": "This cake is great. It's so delicious and moist.",
"speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
"language": "en",
},
{
"text": "Levei muito tempo para desenvolver uma voz e agora que a tenho não vou ficar calado .",
"speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
"language": "pt",
},
],
)
# init the model from config
model = GPTTrainer.init_from_config(config)
# load training samples
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(
restore_path=RESTORE_PATH,
skip_train_epoch=SKIP_TRAIN_EPOCH,
start_with_eval=START_WITH_EVAL,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
callbacks={"on_epoch_start": freeze_layers},
)
trainer.fit()
if __name__ == "__main__":
RUN_NAME = "GPT_XTTS_LJSpeech_fixed"
PROJECT_NAME = "XTTS_trainer"
OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_v1_FT/"
# DASHBOARD_LOGGER = "clearml"
# LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_v1/"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
RESTORE_PATH = None
SKIP_TRAIN_EPOCH = False
START_WITH_EVAL = True
BATCH_SIZE = 3
GRAD_ACUMM_STEPS = 28 * 3
# debug
# DASHBOARD_LOGGER = "tensorboard"
# LOGGER_URI = None
# RESTORE_PATH = None
# BATCH_SIZE = 2
# GRAD_ACUMM_STEPS = 1
main()

View File

@ -2,9 +2,7 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
formatter="coqui",
@ -252,32 +250,34 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
# DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
DATASETS_CONFIG_LIST = [
config_coqui_MLS_metadata_test_with_previous_audio_key_de,
config_coqui_mls_italian_metadata_with_previous_audio_key_it,
]
def freeze_layers(trainer):
pass
def main():
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=False,
max_wav_length=255995, # ~11.6 seconds
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", # checkpoint path of the model that you want to fine-tune
tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", # checkpoint path of the model that you want to fine-tune
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, # GPT SR
dvae_sample_rate=22050,
diffusion_sample_rate=24000,
output_sample_rate=24000
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR
)
config = GPTTrainerConfig(
output_path=OUT_PATH,
@ -303,20 +303,26 @@ def main():
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values like tortoise. However, they used pytorch implementation with modifications to not apply WD to non-weight parameters. We are using default Pytorch
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
optimizer="AdamW",
optimizer_wd_only_on_weights=True,
optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
# lr=1e-4, # learning rate
# ToDo: implement 500 step warmup like tortoise and EMA weights replaces LR decay with rate: .999
optimizer_wd_only_on_weights=True, # for multi-gpu training turn it off
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
lr_scheduler="MultiStepLR",
# it was adjusted accordingly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
{"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
]
{
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"speaker_wav": "/raid/edresson/dev/ref.wav",
"language": "en",
},
{
"text": "This cake is great. It's so delicious and moist.",
"speaker_wav": "/raid/edresson/dev/ref.wav",
"language": "en",
},
],
)
# init the model from config
@ -332,13 +338,18 @@ def main():
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS),
TrainerArgs(
restore_path=RESTORE_PATH,
skip_train_epoch=SKIP_TRAIN_EPOCH,
start_with_eval=START_WITH_EVAL,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
callbacks={"on_epoch_start": freeze_layers}
callbacks={"on_epoch_start": freeze_layers},
)
trainer.fit()
@ -351,17 +362,15 @@ if __name__ == "__main__":
LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
RESTORE_PATH = None
SKIP_TRAIN_EPOCH = False
START_WITH_EVAL = True
START_WITH_EVAL = True
BATCH_SIZE = 9
GRAD_ACUMM_STEPS = 28
# debug
# DASHBOARD_LOGGER = "tensorboard"
# LOGGER_URI = None
# LOGGER_URI = None
# RESTORE_PATH = None
BATCH_SIZE = 2
GRAD_ACUMM_STEPS = 1
main()

View File

@ -99,6 +99,7 @@ def test_xtts_streaming():
"""Testing the new inference_stream method"""
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig()
@ -115,7 +116,7 @@ def test_xtts_streaming():
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en",
gpt_cond_latent,
speaker_embedding
speaker_embedding,
)
wav_chuncks = []
for i, chunk in enumerate(chunks):