From 9e3598c3b714db5610728635bf5a1dc170a4dd21 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 18 Oct 2023 09:42:00 -0300 Subject: [PATCH] Bug Fix on inference using XTTS trainer checkpoint --- TTS/tts/layers/xtts/gpt.py | 5 +- TTS/tts/layers/xtts/hifigan_decoder.py | 29 +-- TTS/tts/layers/xtts/stream_generator.py | 271 ++++++--------------- TTS/tts/layers/xtts/trainer/dataset.py | 79 +++--- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 143 ++++++----- TTS/tts/models/xtts.py | 48 ++-- recipes/ljspeech/xtts_v1/train_xtts.py | 145 +++++++++++ recipes/multilingual/xtts_v1/train_xtts.py | 67 ++--- tests/zoo_tests/test_models.py | 3 +- 9 files changed, 419 insertions(+), 371 deletions(-) create mode 100644 recipes/ljspeech/xtts_v1/train_xtts.py diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index 52086f13..8f24ac01 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -197,6 +197,7 @@ class GPT(nn.Module): if use_deepspeed: import deepspeed + self.ds_engine = deepspeed.init_inference( model=self.gpt_inference.half(), # Transformers models mp_size=1, # Number of GPU @@ -451,7 +452,7 @@ class GPT(nn.Module): if cond_idxs is not None: for idx, r in enumerate(cond_idxs.squeeze()): l = r[1] - r[0] - attn_mask_cond[idx, l : ] = 0.0 + attn_mask_cond[idx, l:] = 0.0 for idx, l in enumerate(text_lengths): attn_mask_text[idx, l + 1 :] = 0.0 @@ -498,7 +499,7 @@ class GPT(nn.Module): for idx, l in enumerate(code_lengths): mel_targets[idx, l + 1 :] = -1 - + # check if stoptoken is in every row of mel_targets assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[ 0 diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py index 6439b455..5fcff870 100644 --- a/TTS/tts/layers/xtts/hifigan_decoder.py +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -1,13 +1,12 @@ import torch +import torchaudio from torch import nn from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm -import torchaudio from TTS.utils.io import load_fsspec - LRELU_SLOPE = 0.1 @@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module): self.cond_in_each_up_layer = cond_in_each_up_layer # initial upsampling layers - self.conv_pre = weight_norm( - Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - ) + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) resblock = ResBlock1 if resblock_type == "1" else ResBlock2 # upsampling layers self.ups = nn.ModuleList() @@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes) - ): + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock(ch, k, d)) # post convolution layer - self.conv_post = weight_norm( - Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias) - ) + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) if cond_channels > 0: self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) @@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module): Tensor: [B, 1, T] """ c = c.to(self.conv_pre.weight.device) - c = torch.nn.functional.pad( - c, (self.inference_padding, self.inference_padding), "replicate" - ) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") return self.forward(c) def remove_weight_norm(self): @@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module): assert not self.training self.remove_weight_norm() + class SELayer(nn.Module): def __init__(self, channel, reduction=8): super(SELayer, self).__init__() @@ -425,10 +417,8 @@ class PreEmphasis(nn.Module): return torch.nn.functional.conv1d(x, self.filter).squeeze(1) - class ResNetSpeakerEncoder(nn.Module): - """This is copied from 🐸TTS to remove it from the dependencies. - """ + """This is copied from 🐸TTS to remove it from the dependencies.""" # pylint: disable=W0102 def __init__( @@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module): return criterion, state["step"] return criterion + class HifiDecoder(torch.nn.Module): def __init__( self, @@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module): """ return self.forward(c, g=g) - def load_checkpoint( - self, checkpoint_path, eval=False - ): # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) # remove unused keys state = state["model"] diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index 8bdd2291..e12f8995 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -1,26 +1,27 @@ # Adapted from: https://github.com/LowinLi/transformers-stream-generator +import copy +import inspect +import random +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import nn from transformers import ( + BeamSearchScorer, + ConstrainedBeamSearchScorer, + DisjunctiveConstraint, GenerationConfig, GenerationMixin, LogitsProcessorList, - StoppingCriteriaList, - DisjunctiveConstraint, - BeamSearchScorer, PhrasalConstraint, - ConstrainedBeamSearchScorer, PreTrainedModel, + StoppingCriteriaList, ) -import numpy as np -import random -import warnings -import inspect from transformers.generation.utils import GenerateOutput, SampleOutput, logger -import torch -from typing import Callable, List, Optional, Union -from torch import nn -import torch.distributed as dist -import copy def setup_seed(seed): @@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin): generation_config: Optional[StreamGenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = False, seed=0, **kwargs, @@ -125,7 +124,7 @@ class NewGenerationMixin(GenerationMixin): - [`~generation.BeamSearchEncoderDecoderOutput`], - [`~generation.BeamSampleEncoderDecoderOutput`] """ - #setup_seed(seed) + # setup_seed(seed) # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() @@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin): # legacy: users may modify the model configuration to control generation -- update the generation config # model attribute accordingly, if it was created from the model config if self.generation_config._from_model_config: - new_generation_config = StreamGenerationConfig.from_model_config( - self.config - ) + new_generation_config = StreamGenerationConfig.from_model_config(self.config) if new_generation_config != self.generation_config: warnings.warn( "You have modified the pretrained model configuration to control generation. This is a" @@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin): generation_config = self.generation_config generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update( - **kwargs - ) # All unused kwargs must be model kwargs + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs # self._validate_model_kwargs(model_kwargs.copy()) # 2. Set generation parameters if not already defined - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if ( - generation_config.pad_token_id is None - and generation_config.eos_token_id is not None - ): + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: if model_kwargs.get("attention_mask", None) is None: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " @@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin): eos_token_id = generation_config.eos_token_id if isinstance(eos_token_id, list): eos_token_id = eos_token_id[0] - logger.warning( - f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." - ) + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") generation_config.pad_token_id = eos_token_id # 3. Define model inputs @@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin): model_kwargs["output_hidden_states"] = generation_config.output_hidden_states model_kwargs["use_cache"] = generation_config.use_cache - accepts_attention_mask = "attention_mask" in set( - inspect.signature(self.forward).parameters.keys() - ) + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs - if ( - model_kwargs.get("attention_mask", None) is None - and requires_attention_mask - and accepts_attention_mask - ): - model_kwargs[ - "attention_mask" - ] = self._prepare_attention_mask_for_generation( + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id, @@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin): if not self.config.is_encoder_decoder: if ( generation_config.pad_token_id is not None - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) - > 0 + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! For correct " @@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin): # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] - has_default_max_length = ( - kwargs.get("max_length") is None - and generation_config.max_length is not None - ) + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" @@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin): UserWarning, ) elif has_default_max_length and generation_config.max_new_tokens is not None: - generation_config.max_length = ( - generation_config.max_new_tokens + input_ids_seq_length - ) - elif ( - not has_default_max_length and generation_config.max_new_tokens is not None - ): + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: raise ValueError( "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" " limit to the generated output length. Remove one of those arguments. Please refer to the" @@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin): "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" ) - if ( - generation_config.min_length is not None - and generation_config.min_length > generation_config.max_length - ): + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: raise ValueError( f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" f" the maximum length ({generation_config.max_length})" ) if input_ids_seq_length >= generation_config.max_length: - input_ids_string = ( - "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - ) + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" logger.warning( f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" @@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin): # 7. determine generation mode is_constraint_gen_mode = ( - generation_config.constraints is not None - or generation_config.force_words_ids is not None + generation_config.constraints is not None or generation_config.force_words_ids is not None ) is_contrastive_search_gen_mode = ( @@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin): ) if generation_config.num_beam_groups > generation_config.num_beams: - raise ValueError( - "`num_beam_groups` has to be smaller or equal to `num_beams`" - ) + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and generation_config.do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." @@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin): ) elif is_beam_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`." - ) + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) + raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( @@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin): logits_warper = self._get_logits_warper(generation_config) if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) + raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 12. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size * generation_config.num_return_sequences, @@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin): # 13. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, - expand_size=generation_config.num_beams - * generation_config.num_return_sequences, + expand_size=generation_config.num_beams * generation_config.num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin): elif is_group_beam_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`." - ) + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if generation_config.num_beams % generation_config.num_beam_groups != 0: - raise ValueError( - "`num_beams` should be divisible by `num_beam_groups` for group beam search." - ) + raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) + raise ValueError("`max_length` needs to be a stopping_criteria for now.") - has_default_typical_p = ( - kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 - ) + has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 if not has_default_typical_p: - raise ValueError( - "Decoder argument `typical_p` is not supported with beam groups." - ) + raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( @@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin): elif is_constraint_gen_mode: if generation_config.num_return_sequences > generation_config.num_beams: - raise ValueError( - "`num_return_sequences` has to be smaller or equal to `num_beams`." - ) + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: - raise ValueError( - "`max_length` needs to be a stopping_criteria for now." - ) + raise ValueError("`max_length` needs to be a stopping_criteria for now.") if generation_config.num_beams <= 1: - raise ValueError( - "`num_beams` needs to be greater than 1 for constrained generation." - ) + raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") if generation_config.do_sample: - raise ValueError( - "`do_sample` needs to be false for constrained generation." - ) + raise ValueError("`do_sample` needs to be false for constrained generation.") - if ( - generation_config.num_beam_groups is not None - and generation_config.num_beam_groups > 1 - ): - raise ValueError( - "`num_beam_groups` not supported yet for constrained generation." - ) + if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: + raise ValueError("`num_beam_groups` not supported yet for constrained generation.") final_constraints = [] if generation_config.constraints is not None: @@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin): if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() - if any( - not isinstance(token_ids, list) for token_ids in word_ids - ): + if any(not isinstance(token_ids, list) for token_ids in word_ids): typeerror() if any( - any( - (not isinstance(token_id, int) or token_id < 0) - for token_id in token_ids - ) + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) for token_ids in word_ids ): typeerror() @@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin): else: if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() - if any( - (not isinstance(token_id, int) or token_id < 0) - for token_id in word_ids - ): + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): typeerror() constraint = PhrasalConstraint(word_ids) @@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin): ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - logits_warper = ( - logits_warper if logits_warper is not None else LogitsProcessorList() - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - eos_token_id = ( - eos_token_id - if eos_token_id is not None - else self.generation_config.eos_token_id - ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions + output_attentions if output_attentions is not None else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin): # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin): if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor( - 0.0 if this_peer_finished else 1.0 - ).to(input_ids.device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then @@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin): scores += (next_token_scores,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) + (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # sample @@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin): # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin): # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: - unfinished_sequences = unfinished_sequences.mul( - (sum(next_tokens != i for i in eos_token_id)).long() - ) + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): @@ -1007,22 +885,17 @@ def init_stream_support(): if __name__ == "__main__": - from transformers import PreTrainedModel - from transformers import AutoTokenizer, AutoModelForCausalLM + from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel PreTrainedModel.generate = NewGenerationMixin.generate PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream - model = AutoModelForCausalLM.from_pretrained( - "bigscience/bloom-560m", torch_dtype=torch.float16 - ) + model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16) tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") model = model.to("cuda:0") model = model.eval() prompt_text = "hello? \n" - input_ids = tokenizer( - prompt_text, return_tensors="pt", add_special_tokens=False - ).input_ids + input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids input_ids = input_ids.to("cuda:0") with torch.no_grad(): diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index b122fc8a..41401fd6 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -1,16 +1,18 @@ import os import random import sys -import numpy as np +import numpy as np import torch import torch.nn.functional as F import torch.utils.data import torchaudio -from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load +from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load + torch.set_num_threads(1) + def key_samples_by_col(samples, col): """Returns a dictionary of samples keyed by language.""" samples_by_col = {} @@ -23,11 +25,11 @@ def key_samples_by_col(samples, col): return samples_by_col -def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False): +def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False): rel_clip = load_audio(gt_path, sample_rate) # if eval uses a middle size sample when it is possible to be more reproducible if is_eval: - sample_length = int((min_sample_length + max_sample_length)/2) + sample_length = int((min_sample_length + max_sample_length) / 2) else: sample_length = random.randint(min_sample_length, max_sample_length) gap = rel_clip.shape[-1] - sample_length @@ -41,7 +43,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, else: rand_start = random.randint(0, gap) - rand_end = rand_start+sample_length + rand_end = rand_start + sample_length rel_clip = rel_clip[:, rand_start:rand_end] rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1])) cond_idxs = [rand_start, rand_end] @@ -50,7 +52,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, def load_audio(audiopath, sampling_rate): # better load setting following: https://github.com/faroit/python_audio_loading_benchmark - if audiopath[-4:] == '.mp3': + if audiopath[-4:] == ".mp3": # it uses torchaudio with sox backend to load mp3 audio, lsr = torchaudio_sox_load(audiopath) else: @@ -72,6 +74,7 @@ def load_audio(audiopath, sampling_rate): audio.clip_(-1, 1) return audio + class XTTSDataset(torch.utils.data.Dataset): def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False): self.config = config @@ -103,16 +106,18 @@ class XTTSDataset(torch.utils.data.Dataset): print(" > Filtering invalid eval samples!!") new_samples = [] for sample in self.samples: - try: - tseq, _, wav, _, _, _ = self.load_item(sample) - except: - pass - # Basically, this audio file is nonexistent or too long to be supported by the dataset. - if wav is None or \ - (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ - (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): - continue - new_samples.append(sample) + try: + tseq, _, wav, _, _, _ = self.load_item(sample) + except: + pass + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + if ( + wav is None + or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) + or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len) + ): + continue + new_samples.append(sample) self.samples = new_samples print(" > Total eval samples after filtering:", len(self.samples)) @@ -125,9 +130,9 @@ class XTTSDataset(torch.utils.data.Dataset): return tokens def load_item(self, sample): - text = str(sample['text']) + text = str(sample["text"]) tseq = self.get_text(text, sample["language"]) - audiopath = sample['audio_file'] + audiopath = sample["audio_file"] wav = load_audio(audiopath, self.sample_rate) if text is None or len(text.strip()) == 0: raise ValueError @@ -136,7 +141,9 @@ class XTTSDataset(torch.utils.data.Dataset): raise ValueError # get a slice from GT to condition the model - cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval) + cond, cond_len, cond_idxs = get_prompt_slice( + audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval + ) return tseq, audiopath, wav, cond, cond_len, cond_idxs @@ -151,7 +158,7 @@ class XTTSDataset(torch.utils.data.Dataset): index = random.randint(0, len(self.samples[lang]) - 1) sample = self.samples[lang][index] # a unique id for each sampel to deal with fails - sample_id = lang+"_"+str(index) + sample_id = lang + "_" + str(index) # ignore samples that we already know that is not valid ones if sample_id in self.failed_samples: @@ -170,26 +177,30 @@ class XTTSDataset(torch.utils.data.Dataset): return self[1] # check if the audio and text size limits and if it out of the limits, added it failed_samples - if wav is None or \ - (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ - (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + if ( + wav is None + or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) + or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len) + ): # Basically, this audio file is nonexistent or too long to be supported by the dataset. # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. if self.debug_failures and wav is not None and tseq is not None: - print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") + print( + f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}" + ) self.failed_samples.add(sample_id) return self[1] res = { # 'real_text': text, - 'text': tseq, - 'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long), - 'wav': wav, - 'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long), - 'filenames': audiopath, - 'conditioning': cond.unsqueeze(1), - 'cond_lens': torch.tensor(cond_len, dtype=torch.long), - 'cond_idxs': torch.tensor(cond_idxs), + "text": tseq, + "text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long), + "wav": wav, + "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long), + "filenames": audiopath, + "conditioning": cond.unsqueeze(1), + "cond_lens": torch.tensor(cond_len, dtype=torch.long), + "cond_idxs": torch.tensor(cond_idxs), } return res @@ -223,8 +234,8 @@ class XTTSDataset(torch.utils.data.Dataset): for i in range(B): text = batch["text"][i] text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text) - wav = batch['wav'][i] - wav_padded[i, :, :batch["wav_lengths"][i]] = torch.FloatTensor(wav) + wav = batch["wav"][i] + wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav) batch["wav"] = wav_padded batch["padded_text"] = text_padded diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 87b1228e..e4df2b90 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -1,37 +1,30 @@ import os +import sys from dataclasses import dataclass, field from typing import Callable, Dict, List, Optional, Tuple, Union import torch -import torchaudio import torch.nn as nn +import torchaudio +from coqpit import Coqpit from torch.nn import functional as F from torch.utils.data import DataLoader -import sys - - -from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer -from TTS.tts.layers.xtts.gpt import GPT -from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts -from TTS.tts.configs.xtts_config import XttsConfig - -from TTS.tts.models.base_tts import BaseTTS -from coqpit import Coqpit - -from TTS.tts.configs.tortoise_config import TortoiseConfig -from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram - -from TTS.tts.datasets.dataset import TTSDataset - from trainer.torch import DistributedSampler from trainer.trainer_utils import get_optimizer, get_scheduler +from TTS.tts.configs.tortoise_config import TortoiseConfig +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram +from TTS.tts.layers.xtts.dvae import DiscreteVAE +from TTS.tts.layers.xtts.gpt import GPT +from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig from TTS.utils.io import load_fsspec -from TTS.tts.layers.xtts.dvae import DiscreteVAE - -from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder @dataclass class GPTTrainerConfig(XttsConfig): @@ -42,6 +35,7 @@ class GPTTrainerConfig(XttsConfig): weighted_loss_multipliers: dict = field(default_factory=lambda: {}) test_sentences: List[dict] = field(default_factory=lambda: []) + @dataclass class XttsAudioConfig(XttsAudioConfig): dvae_sample_rate: int = 22050 @@ -55,27 +49,28 @@ class GPTArgs(XttsArgs): gpt_loss_mel_ce_weight: float = 1.0 gpt_num_audio_tokens: int = 8194 debug_loading_failures: bool = False - max_wav_length: int = 255995 # ~11.6 seconds + max_wav_length: int = 255995 # ~11.6 seconds max_text_length: int = 200 tokenizer_file: str = "" mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth" dvae_checkpoint: str = "" xtts_checkpoint: str = "" - gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model - vocoder: str = "" # overide vocoder key on the config to avoid json write issues + gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model + vocoder: str = "" # overide vocoder key on the config to avoid json write issues def callback_clearml_load_save(operation_type, model_info): # return None means skip the file upload/log, returning model_info will continue with the log/upload # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size - assert operation_type in ('load', 'save') + assert operation_type in ("load", "save") # print(operation_type, model_info.__dict__) - if "similarities.pth" in model_info.__dict__['local_model_path']: + if "similarities.pth" in model_info.__dict__["local_model_path"]: return None return model_info + class GPTTrainer(BaseTTS): def __init__(self, config: Coqpit): """ @@ -89,18 +84,17 @@ class GPTTrainer(BaseTTS): self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file) # init gpt encoder and hifigan decoder self.xtts.init_models() - # set mel stats - if self.args.mel_norm_file: - self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file) if self.args.xtts_checkpoint: self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False) + # set mel stats + if self.args.mel_norm_file: + self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file) + # load GPT if available if self.args.gpt_checkpoint: - gpt_checkpoint = torch.load( - self.args.gpt_checkpoint, map_location=torch.device("cpu") - ) + gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu")) # deal with coqui Trainer exported model if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): print("Coqui Trainer checkpoint detected! Converting it!") @@ -113,10 +107,15 @@ class GPTTrainer(BaseTTS): del gpt_checkpoint[key] else: del gpt_checkpoint[key] - + # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible - if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape: - num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] + if ( + "text_embedding.weight" in gpt_checkpoint + and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape + ): + num_new_tokens = ( + self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] + ) print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") # add new tokens to a linear layer (text_head) @@ -156,7 +155,7 @@ class GPTTrainer(BaseTTS): mel_fmin=0, mel_fmax=8000, n_mel_channels=80, - mel_norm_file=self.args.mel_norm_file + mel_norm_file=self.args.mel_norm_file, ) # Load DVAE @@ -175,17 +174,18 @@ class GPTTrainer(BaseTTS): self.dvae.eval() if self.args.dvae_checkpoint: - dvae_checkpoint = torch.load( - self.args.dvae_checkpoint, map_location=torch.device("cpu") - ) + dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu")) self.dvae.load_state_dict(dvae_checkpoint, strict=False) print(">> DVAE weights restored from:", self.args.dvae_checkpoint) else: - raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!") + raise RuntimeError( + "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" + ) # Mel spectrogram extractor for DVAE - self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate) - + self.torch_mel_spectrogram_dvae = TorchMelSpectrogram( + mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate + ) @property def device(self): @@ -203,7 +203,9 @@ class GPTTrainer(BaseTTS): cond_mels: MEL float tensor, (b, num_samples, 80,t_m) cond_idxs: cond start and end indexs, (b, 2) """ - losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs) + losses = self.xtts.gpt( + text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs + ) return losses @torch.no_grad() @@ -215,7 +217,9 @@ class GPTTrainer(BaseTTS): test_audios = {} print(" | > Synthesizing test sentences.") for idx, s_info in enumerate(self.config.test_sentences): - wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"] + wav = self.xtts.synthesize( + s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3 + )["wav"] test_audios["{}-audio".format(idx)] = wav # delete inference layers @@ -231,7 +235,7 @@ class GPTTrainer(BaseTTS): def format_batch(self, batch: Dict) -> Dict: return batch - @torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction + @torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction def format_batch_on_device(self, batch): """Compute spectrograms on the device.""" batch["text_lengths"] = batch["text_lengths"] @@ -241,10 +245,10 @@ class GPTTrainer(BaseTTS): # compute conditioning mel specs # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor B, num_cond_samples, C, T = batch["conditioning"].size() - conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T) + conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T) paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped) # transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel]) - n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1) + n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1) T_mel = paired_conditioning_mel.size(2) paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel) # get the conditioning embeddings @@ -300,6 +304,7 @@ class GPTTrainer(BaseTTS): # ignore similarities.pth on clearml save/upload if self.config.dashboard_logger.lower() == "clearml": from clearml.binding.frameworks import WeightsFileHandler + WeightsFileHandler.add_pre_callback(callback_clearml_load_save) @torch.no_grad() @@ -367,16 +372,23 @@ class GPTTrainer(BaseTTS): return loader def get_optimizer(self) -> List: - """Initiate and return the optimizer based on the config parameters. - """ + """Initiate and return the optimizer based on the config parameters.""" # ToDo: deal with multi GPU training if self.config.optimizer_wd_only_on_weights: - # parameters to only GPT model + # parameters to only GPT model net = self.xtts.gpt # normalizations - norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d, - nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm) + norm_modules = ( + nn.BatchNorm2d, + nn.InstanceNorm2d, + nn.BatchNorm1d, + nn.InstanceNorm1d, + nn.BatchNorm3d, + nn.InstanceNorm3d, + nn.GroupNorm, + nn.LayerNorm, + ) # nn.Embedding emb_modules = (nn.Embedding, nn.EmbeddingBag) @@ -390,7 +402,7 @@ class GPTTrainer(BaseTTS): v.is_norm = isinstance(m, norm_modules) v.is_emb = isinstance(m, emb_modules) - fpn = '%s.%s' % (mn, k) if mn else k # full param name + fpn = "%s.%s" % (mn, k) if mn else k # full param name all_param_names.add(fpn) param_map[fpn] = v if v.is_bias or v.is_norm or v.is_emb: @@ -402,26 +414,26 @@ class GPTTrainer(BaseTTS): params_weights = [param_map[k] for k in params_names_weights] groups = [ - { 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]}, - { 'params': params_notweights, 'weight_decay': 0} + {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]}, + {"params": params_notweights, "weight_decay": 0}, ] # torch.optim.AdamW opt = get_optimizer( - self.config.optimizer, - self.config.optimizer_params, - self.config.lr, - parameters=groups, - ) + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + parameters=groups, + ) opt._group_names = [params_names_weights, params_names_notweights] return opt return get_optimizer( - self.config.optimizer, - self.config.optimizer_params, - self.config.lr, - # optimize only for the GPT model - parameters=self.xtts.gpt.parameters(), - ) + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + # optimize only for the GPT model + parameters=self.xtts.gpt.parameters(), + ) def get_scheduler(self, optimizer) -> List: """Set the scheduler for the optimizer. @@ -461,4 +473,3 @@ class GPTTrainer(BaseTTS): Defaults to None. """ return GPTTrainer(config) - \ No newline at end of file diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 3e609799..e2c8ca4c 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps from TTS.tts.layers.xtts.gpt import GPT -from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer -from TTS.tts.layers.xtts.vocoder import UnivNetGenerator from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.vocoder import UnivNetGenerator from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec init_stream_support() + def load_audio(audiopath, sr=22050): """ Load an audio file from disk and resample it to the specified sampling rate. @@ -332,7 +333,6 @@ class Xtts(BaseTTS): stop_audio_token=self.args.gpt_stop_audio_token, ) - if self.args.use_hifigan: self.hifigan_decoder = HifiDecoder( input_sample_rate=self.args.input_sample_rate, @@ -414,21 +414,20 @@ class Xtts(BaseTTS): return diffusion_latent @torch.inference_mode() - def get_speaker_embedding( - self, - audio_path - ): + def get_speaker_embedding(self, audio_path): audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"]) - speaker_embedding = self.hifigan_decoder.speaker_encoder.forward( - audio.to(self.device), l2_norm=True - ).unsqueeze(-1).to(self.device) + speaker_embedding = ( + self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True) + .unsqueeze(-1) + .to(self.device) + ) return speaker_embedding def get_conditioning_latents( self, audio_path, gpt_cond_len=3, - ): + ): speaker_embedding = None diffusion_cond_latents = None if self.args.use_hifigan: @@ -563,11 +562,9 @@ class Xtts(BaseTTS): Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. """ - ( - gpt_cond_latent, - diffusion_conditioning, - speaker_embedding - ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len) + (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents( + audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len + ) return self.inference( text, language, @@ -588,7 +585,7 @@ class Xtts(BaseTTS): decoder=decoder, **hf_generate_kwargs, ) - + @torch.inference_mode() def inference( self, @@ -666,7 +663,7 @@ class Xtts(BaseTTS): if ctokens > 8: gpt_latents = gpt_latents[:, :k] break - + if decoder == "hifigan": assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) @@ -721,7 +718,9 @@ class Xtts(BaseTTS): decoder="hifigan", **hf_generate_kwargs, ): - assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." + assert hasattr( + self, "hifigan_decoder" + ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." text = f"[{language}]{text.strip().lower()}" text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) @@ -793,7 +792,7 @@ class Xtts(BaseTTS): self, config, checkpoint_dir=None, - checkpoint_path=None, + checkpoint_path=None, vocab_path=None, eval=True, strict=True, @@ -827,6 +826,15 @@ class Xtts(BaseTTS): ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"] ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"] for key in list(checkpoint.keys()): + # check if it is from the coqui Trainer if so convert it + if key.startswith("xtts."): + coqui_trainer_checkpoint = True + new_key = key.replace("xtts.", "") + checkpoint[new_key] = checkpoint[key] + del checkpoint[key] + key = new_key + + # remove unused keys if key.split(".")[0] in ignore_keys: del checkpoint[key] diff --git a/recipes/ljspeech/xtts_v1/train_xtts.py b/recipes/ljspeech/xtts_v1/train_xtts.py new file mode 100644 index 00000000..6c07053b --- /dev/null +++ b/recipes/ljspeech/xtts_v1/train_xtts.py @@ -0,0 +1,145 @@ +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig + +# Define here the dataset used +config_ljspeech = BaseDatasetConfig( + formatter="ljspeech", + dataset_name="ljspeech", + path="/raid/datasets/LJSpeech-1.1_24khz/", + meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv", + language="en", +) + +DATASETS_CONFIG_LIST = [config_ljspeech] + + +def freeze_layers(trainer): + pass + + +def main(): + # init args and config + model_args = GPTArgs( + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs + debug_loading_failures=False, + max_wav_length=255995, # ~11.6 seconds + max_text_length=200, + mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth", + dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth", + # tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune + # xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", + xtts_checkpoint="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth", # checkpoint path of the model that you want to fine-tune + tokenizer_file="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json", + gpt_num_audio_tokens=8194, + gpt_start_audio_token=8192, + gpt_stop_audio_token=8193, + ) + audio_config = XttsAudioConfig( + sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR + ) + config = GPTTrainerConfig( + output_path=OUT_PATH, + model_args=model_args, + run_name=RUN_NAME, + project_name=PROJECT_NAME, + run_description=""" + GPT XTTS training + """, + dashboard_logger=DASHBOARD_LOGGER, + logger_uri=LOGGER_URI, + audio=audio_config, + batch_size=BATCH_SIZE, + batch_group_size=48, + eval_batch_size=BATCH_SIZE, + num_loader_workers=8, + eval_split_max_size=256, + print_step=50, + plot_step=100, + log_model_step=1000, + save_step=10000, + save_n_checkpoints=1, + save_checkpoints=True, + # target_loss="loss", + print_eval=False, + # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. + optimizer="AdamW", + optimizer_wd_only_on_weights=True, # for multi-gpu training turn it off + optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate + lr_scheduler="MultiStepLR", + # it was adjusted accordly for the new step scheme + lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, + test_sentences=[ + { + "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav", + "language": "en", + }, + { + "text": "This cake is great. It's so delicious and moist.", + "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav", + "language": "en", + }, + { + "text": "Levei muito tempo para desenvolver uma voz e agora que a tenho não vou ficar calado .", + "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav", + "language": "pt", + }, + ], + ) + + # init the model from config + model = GPTTrainer.init_from_config(config) + + # load training samples + train_samples, eval_samples = load_tts_samples( + DATASETS_CONFIG_LIST, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the trainer and 🚀 + trainer = Trainer( + TrainerArgs( + restore_path=RESTORE_PATH, + skip_train_epoch=SKIP_TRAIN_EPOCH, + start_with_eval=START_WITH_EVAL, + grad_accum_steps=GRAD_ACUMM_STEPS, + ), + config, + output_path=OUT_PATH, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + callbacks={"on_epoch_start": freeze_layers}, + ) + trainer.fit() + + +if __name__ == "__main__": + RUN_NAME = "GPT_XTTS_LJSpeech_fixed" + PROJECT_NAME = "XTTS_trainer" + OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_v1_FT/" + # DASHBOARD_LOGGER = "clearml" + # LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_v1/" + DASHBOARD_LOGGER = "tensorboard" + LOGGER_URI = None + RESTORE_PATH = None + SKIP_TRAIN_EPOCH = False + START_WITH_EVAL = True + BATCH_SIZE = 3 + GRAD_ACUMM_STEPS = 28 * 3 + + # debug + # DASHBOARD_LOGGER = "tensorboard" + # LOGGER_URI = None + # RESTORE_PATH = None + # BATCH_SIZE = 2 + # GRAD_ACUMM_STEPS = 1 + + main() diff --git a/recipes/multilingual/xtts_v1/train_xtts.py b/recipes/multilingual/xtts_v1/train_xtts.py index f36bf1ae..fa13d8d4 100644 --- a/recipes/multilingual/xtts_v1/train_xtts.py +++ b/recipes/multilingual/xtts_v1/train_xtts.py @@ -2,9 +2,7 @@ from trainer import Trainer, TrainerArgs from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples - -from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig - +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig( formatter="coqui", @@ -252,32 +250,34 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig( # DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it] -DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it] - +DATASETS_CONFIG_LIST = [ + config_coqui_MLS_metadata_test_with_previous_audio_key_de, + config_coqui_mls_italian_metadata_with_previous_audio_key_it, +] + + def freeze_layers(trainer): pass + def main(): # init args and config model_args = GPTArgs( - max_conditioning_length=132300, # 6 secs - min_conditioning_length=66150, # 3 secs + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs debug_loading_failures=False, - max_wav_length=255995, # ~11.6 seconds + max_wav_length=255995, # ~11.6 seconds max_text_length=200, mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth", dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth", - tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune - xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", # checkpoint path of the model that you want to fine-tune + tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune + xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth", # checkpoint path of the model that you want to fine-tune gpt_num_audio_tokens=8194, gpt_start_audio_token=8192, gpt_stop_audio_token=8193, ) audio_config = XttsAudioConfig( - sample_rate=22050, # GPT SR - dvae_sample_rate=22050, - diffusion_sample_rate=24000, - output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR ) config = GPTTrainerConfig( output_path=OUT_PATH, @@ -303,20 +303,26 @@ def main(): save_checkpoints=True, # target_loss="loss", print_eval=False, - # Optimizer values like tortoise. However, they used pytorch implementation with modifications to not apply WD to non-weight parameters. We are using default Pytorch + # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. optimizer="AdamW", - optimizer_wd_only_on_weights=True, - optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2}, - lr=5e-06, # learning rate - # lr=1e-4, # learning rate - # ToDo: implement 500 step warmup like tortoise and EMA weights replaces LR decay with rate: .999 + optimizer_wd_only_on_weights=True, # for multi-gpu training turn it off + optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate lr_scheduler="MultiStepLR", # it was adjusted accordly for the new step scheme lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, test_sentences=[ - {"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"}, - {"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"}, - ] + { + "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "speaker_wav": "/raid/edresson/dev/ref.wav", + "language": "en", + }, + { + "text": "This cake is great. It's so delicious and moist.", + "speaker_wav": "/raid/edresson/dev/ref.wav", + "language": "en", + }, + ], ) # init the model from config @@ -332,13 +338,18 @@ def main(): # init the trainer and 🚀 trainer = Trainer( - TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS), + TrainerArgs( + restore_path=RESTORE_PATH, + skip_train_epoch=SKIP_TRAIN_EPOCH, + start_with_eval=START_WITH_EVAL, + grad_accum_steps=GRAD_ACUMM_STEPS, + ), config, output_path=OUT_PATH, model=model, train_samples=train_samples, eval_samples=eval_samples, - callbacks={"on_epoch_start": freeze_layers} + callbacks={"on_epoch_start": freeze_layers}, ) trainer.fit() @@ -351,17 +362,15 @@ if __name__ == "__main__": LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/" RESTORE_PATH = None SKIP_TRAIN_EPOCH = False - START_WITH_EVAL = True + START_WITH_EVAL = True BATCH_SIZE = 9 GRAD_ACUMM_STEPS = 28 # debug # DASHBOARD_LOGGER = "tensorboard" - # LOGGER_URI = None + # LOGGER_URI = None # RESTORE_PATH = None BATCH_SIZE = 2 GRAD_ACUMM_STEPS = 1 - - main() diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index db144f1c..7194ed5c 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -99,6 +99,7 @@ def test_xtts_streaming(): """Testing the new inference_stream method""" from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.models.xtts import Xtts + speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") config = XttsConfig() @@ -115,7 +116,7 @@ def test_xtts_streaming(): "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", "en", gpt_cond_latent, - speaker_embedding + speaker_embedding, ) wav_chuncks = [] for i, chunk in enumerate(chunks):