Bug Fix on inference using XTTS trainer checkpoint

Edresson Casanova 2023-10-18 09:42:00 -03:00
parent c4ceaabe2c
commit 9e3598c3b7
9 changed files with 419 additions and 371 deletions

View File: TTS/tts/layers/xtts/gpt.py

@@ -197,6 +197,7 @@ class GPT(nn.Module):
         if use_deepspeed:
             import deepspeed
+
             self.ds_engine = deepspeed.init_inference(
                 model=self.gpt_inference.half(),  # Transformers models
                 mp_size=1,  # Number of GPU
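Note on the touched code: `deepspeed.init_inference` wraps an already-built module in an inference engine, and the `.half()` call hands it fp16 weights. A minimal, self-contained sketch of the same pattern, assuming a CUDA machine with the `deepspeed` package and using a tiny GPT-2 as a stand-in for the XTTS GPT module (which is not reproduced here):

# Sketch only: GPT-2 stands in for the XTTS GPT module.
import deepspeed
import torch
from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64)).cuda()
ds_engine = deepspeed.init_inference(
    model=model.half(),  # fp16 weights, as in the diff above
    mp_size=1,  # single-GPU model parallelism
    dtype=torch.float16,
)
with torch.no_grad():
    logits = ds_engine.module(torch.randint(0, 100, (1, 8)).cuda()).logits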

View File: TTS/tts/layers/xtts/hifigan_decoder.py

@@ -1,13 +1,12 @@
 import torch
+import torchaudio
 from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm
-import torchaudio
 from TTS.utils.io import load_fsspec

 LRELU_SLOPE = 0.1
@@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module):
         self.cond_in_each_up_layer = cond_in_each_up_layer

         # initial upsampling layers
-        self.conv_pre = weight_norm(
-            Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
-        )
+        self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
         resblock = ResBlock1 if resblock_type == "1" else ResBlock2
         # upsampling layers
         self.ups = nn.ModuleList()
@@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))
         # post convolution layer
-        self.conv_post = weight_norm(
-            Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
-        )
+        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
         if cond_channels > 0:
             self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
@@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module):
             Tensor: [B, 1, T]
         """
         c = c.to(self.conv_pre.weight.device)
-        c = torch.nn.functional.pad(
-            c, (self.inference_padding, self.inference_padding), "replicate"
-        )
+        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
         return self.forward(c)

     def remove_weight_norm(self):
@@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module):
         assert not self.training
         self.remove_weight_norm()

+
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -425,10 +417,8 @@ class PreEmphasis(nn.Module):
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


 class ResNetSpeakerEncoder(nn.Module):
-    """This is copied from 🐸TTS to remove it from the dependencies.
-    """
+    """This is copied from 🐸TTS to remove it from the dependencies."""

     # pylint: disable=W0102
     def __init__(
@@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module):
                 return criterion, state["step"]
         return criterion

+
 class HifiDecoder(torch.nn.Module):
     def __init__(
         self,
@@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module):
         """
         return self.forward(c, g=g)

-    def load_checkpoint(
-        self, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
+    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # remove unused keys
         state = state["model"]
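The `load_checkpoint` method collapsed into one line above follows the checkpoint-loading pattern used throughout this file: fetch the file with `load_fsspec`, unwrap the "model" key, then load the state dict. A sketch of that pattern in isolation (the helper name and the `strict=False` flag are illustrative assumptions, not taken from this diff):

# Hypothetical helper mirroring HifiDecoder.load_checkpoint's loading steps.
import torch
from TTS.utils.io import load_fsspec

def load_decoder_weights(decoder, checkpoint_path):
    state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
    state = state["model"]  # trainer checkpoints nest the weights under "model"
    decoder.load_state_dict(state, strict=False)  # strict=False is an assumption here
    decoder.eval()
    return decoder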

View File: TTS/tts/layers/xtts/stream_generator.py

@@ -1,26 +1,27 @@
 # Adapted from: https://github.com/LowinLi/transformers-stream-generator
+import copy
+import inspect
+import random
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import nn
 from transformers import (
+    BeamSearchScorer,
+    ConstrainedBeamSearchScorer,
+    DisjunctiveConstraint,
     GenerationConfig,
     GenerationMixin,
     LogitsProcessorList,
-    StoppingCriteriaList,
-    DisjunctiveConstraint,
-    BeamSearchScorer,
     PhrasalConstraint,
-    ConstrainedBeamSearchScorer,
     PreTrainedModel,
+    StoppingCriteriaList,
 )
-import numpy as np
-import random
-import warnings
-import inspect
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
-import torch
-from typing import Callable, List, Optional, Union
-from torch import nn
-import torch.distributed as dist
-import copy


 def setup_seed(seed):
@@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin):
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[
-            Callable[[int, torch.Tensor], List[int]]
-        ] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         synced_gpus: Optional[bool] = False,
         seed=0,
         **kwargs,
@@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin):
         # legacy: users may modify the model configuration to control generation -- update the generation config
         # model attribute accordingly, if it was created from the model config
         if self.generation_config._from_model_config:
-            new_generation_config = StreamGenerationConfig.from_model_config(
-                self.config
-            )
+            new_generation_config = StreamGenerationConfig.from_model_config(self.config)
             if new_generation_config != self.generation_config:
                 warnings.warn(
                     "You have modified the pretrained model configuration to control generation. This is a"
@@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin):
             generation_config = self.generation_config

         generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(
-            **kwargs
-        )  # All unused kwargs must be model kwargs
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         # self._validate_model_kwargs(model_kwargs.copy())

         # 2. Set generation parameters if not already defined
-        logits_processor = (
-            logits_processor if logits_processor is not None else LogitsProcessorList()
-        )
-        stopping_criteria = (
-            stopping_criteria
-            if stopping_criteria is not None
-            else StoppingCriteriaList()
-        )
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

-        if (
-            generation_config.pad_token_id is None
-            and generation_config.eos_token_id is not None
-        ):
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin):
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(
-                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
-            )
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id

         # 3. Define model inputs
@@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin):
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache

-        accepts_attention_mask = "attention_mask" in set(
-            inspect.signature(self.forward).parameters.keys()
-        )
+        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if (
-            model_kwargs.get("attention_mask", None) is None
-            and requires_attention_mask
-            and accepts_attention_mask
-        ):
-            model_kwargs[
-                "attention_mask"
-            ] = self._prepare_attention_mask_for_generation(
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
                 generation_config.eos_token_id,
@@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin):
         if not self.config.is_encoder_decoder:
             if (
                 generation_config.pad_token_id is not None
-                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
-                > 0
+                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
             ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin):
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
-        has_default_max_length = (
-            kwargs.get("max_length") is None
-            and generation_config.max_length is not None
-        )
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
                 "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
@@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin):
                 UserWarning,
             )
         elif has_default_max_length and generation_config.max_new_tokens is not None:
-            generation_config.max_length = (
-                generation_config.max_new_tokens + input_ids_seq_length
-            )
-        elif (
-            not has_default_max_length and generation_config.max_new_tokens is not None
-        ):
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        elif not has_default_max_length and generation_config.max_new_tokens is not None:
             raise ValueError(
                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
@@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin):
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
             )

-        if (
-            generation_config.min_length is not None
-            and generation_config.min_length > generation_config.max_length
-        ):
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
                 f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                 f" the maximum length ({generation_config.max_length})"
             )
         if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = (
-                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-            )
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
@@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin):
         # 7. determine generation mode
         is_constraint_gen_mode = (
-            generation_config.constraints is not None
-            or generation_config.force_words_ids is not None
+            generation_config.constraints is not None or generation_config.force_words_ids is not None
         )

         is_contrastive_search_gen_mode = (
@@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin):
             )

         if generation_config.num_beam_groups > generation_config.num_beams:
-            raise ValueError(
-                "`num_beam_groups` has to be smaller or equal to `num_beams`"
-            )
+            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
         if is_group_beam_gen_mode and generation_config.do_sample is True:
             raise ValueError(
                 "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
@@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_beam_gen_mode:
             if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError(
-                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
-                )
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin):
             logits_warper = self._get_logits_warper(generation_config)

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
@@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin):
             # 13. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
-                expand_size=generation_config.num_beams
-                * generation_config.num_return_sequences,
+                expand_size=generation_config.num_beams * generation_config.num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
@@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin):
         elif is_group_beam_gen_mode:
             if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError(
-                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
-                )
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if generation_config.num_beams % generation_config.num_beam_groups != 0:
-                raise ValueError(
-                    "`num_beams` should be divisible by `num_beam_groups` for group beam search."
-                )
+                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

-            has_default_typical_p = (
-                kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
-            )
+            has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
             if not has_default_typical_p:
-                raise ValueError(
-                    "Decoder argument `typical_p` is not supported with beam groups."
-                )
+                raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin):
         elif is_constraint_gen_mode:
             if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError(
-                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
-                )
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

             if generation_config.num_beams <= 1:
-                raise ValueError(
-                    "`num_beams` needs to be greater than 1 for constrained generation."
-                )
+                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

             if generation_config.do_sample:
-                raise ValueError(
-                    "`do_sample` needs to be false for constrained generation."
-                )
+                raise ValueError("`do_sample` needs to be false for constrained generation.")

-            if (
-                generation_config.num_beam_groups is not None
-                and generation_config.num_beam_groups > 1
-            ):
-                raise ValueError(
-                    "`num_beam_groups` not supported yet for constrained generation."
-                )
+            if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
+                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

             final_constraints = []
             if generation_config.constraints is not None:
@@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin):
                     if isinstance(word_ids[0], list):
                         if not isinstance(word_ids, list) or len(word_ids) == 0:
                             typeerror()
-                        if any(
-                            not isinstance(token_ids, list) for token_ids in word_ids
-                        ):
+                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                             typeerror()
                         if any(
-                            any(
-                                (not isinstance(token_id, int) or token_id < 0)
-                                for token_id in token_ids
-                            )
+                            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                             for token_ids in word_ids
                         ):
                             typeerror()
@@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin):
                     else:
                         if not isinstance(word_ids, list) or len(word_ids) == 0:
                             typeerror()
-                        if any(
-                            (not isinstance(token_id, int) or token_id < 0)
-                            for token_id in word_ids
-                        ):
+                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                             typeerror()

                         constraint = PhrasalConstraint(word_ids)
@@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin):
         ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
         ```"""
         # init values
-        logits_processor = (
-            logits_processor if logits_processor is not None else LogitsProcessorList()
-        )
-        stopping_criteria = (
-            stopping_criteria
-            if stopping_criteria is not None
-            else StoppingCriteriaList()
-        )
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
             warnings.warn(
                 "`max_length` is deprecated in this function, use"
                 " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )
-            stopping_criteria = validate_stopping_criteria(
-                stopping_criteria, max_length
-            )
-        logits_warper = (
-            logits_warper if logits_warper is not None else LogitsProcessorList()
-        )
-        pad_token_id = (
-            pad_token_id
-            if pad_token_id is not None
-            else self.generation_config.pad_token_id
-        )
-        eos_token_id = (
-            eos_token_id
-            if eos_token_id is not None
-            else self.generation_config.eos_token_id
-        )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        output_scores = (
-            output_scores
-            if output_scores is not None
-            else self.generation_config.output_scores
-        )
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.generation_config.output_attentions
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.generation_config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
             return_dict_in_generate
@@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin):

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        cross_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        decoder_hidden_states = (
-            () if (return_dict_in_generate and output_hidden_states) else None
-        )
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin):
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                 # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(
-                    0.0 if this_peer_finished else 1.0
-                ).to(input_ids.device)
+                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                 # send 0.0 if we finished, 1.0 otherwise
                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                 # did all peers finish? the reduced sum will be 0.0 then
@@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin):
                     scores += (next_token_scores,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.attentions,)
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)

                 if output_hidden_states:
                     decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
+                        (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
                     )

             # sample
@@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin):
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
-                    raise ValueError(
-                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
-                    )
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                    1 - unfinished_sequences
-                )
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
             yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin):

             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    (sum(next_tokens != i for i in eos_token_id)).long()
-                )
+                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -1007,22 +885,17 @@ def init_stream_support():

 if __name__ == "__main__":
-    from transformers import PreTrainedModel
-    from transformers import AutoTokenizer, AutoModelForCausalLM
+    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

     PreTrainedModel.generate = NewGenerationMixin.generate
     PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
-    model = AutoModelForCausalLM.from_pretrained(
-        "bigscience/bloom-560m", torch_dtype=torch.float16
-    )
+    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
     model = model.to("cuda:0")
     model = model.eval()
     prompt_text = "hello? \n"
-    input_ids = tokenizer(
-        prompt_text, return_tensors="pt", add_special_tokens=False
-    ).input_ids
+    input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
     input_ids = input_ids.to("cuda:0")

     with torch.no_grad():
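One line in the reformatted sampling loop deserves a worked example: `next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)` forces already-finished rows to the pad token while leaving active rows untouched. A standalone illustration with made-up token ids:

import torch

# Row 0 is still generating (flag 1); row 1 already hit EOS (flag 0).
unfinished_sequences = torch.tensor([1, 0])
next_tokens = torch.tensor([42, 7])  # freshly sampled tokens, made-up ids
pad_token_id = 0

# Finished rows are overwritten with the pad token; active rows keep their sample.
masked = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
print(masked)  # tensor([42, 0])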

View File: TTS/tts/layers/xtts/trainer/dataset.py

@@ -1,16 +1,18 @@
 import os
 import random
 import sys

 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
 from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
+from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load

 torch.set_num_threads(1)


 def key_samples_by_col(samples, col):
     """Returns a dictionary of samples keyed by language."""
     samples_by_col = {}
@@ -50,7 +52,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,

 def load_audio(audiopath, sampling_rate):
     # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
-    if audiopath[-4:] == '.mp3':
+    if audiopath[-4:] == ".mp3":
         # it uses torchaudio with sox backend to load mp3
         audio, lsr = torchaudio_sox_load(audiopath)
     else:
@@ -72,6 +74,7 @@ def load_audio(audiopath, sampling_rate):
     audio.clip_(-1, 1)
     return audio

+
 class XTTSDataset(torch.utils.data.Dataset):
     def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
@@ -108,9 +111,11 @@ class XTTSDataset(torch.utils.data.Dataset):
                 except:
                     pass
                 # Basically, this audio file is nonexistent or too long to be supported by the dataset.
-                if wav is None or \
-                    (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
-                    (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+                if (
+                    wav is None
+                    or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
+                    or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
+                ):
                     continue
                 new_samples.append(sample)
             self.samples = new_samples
@@ -125,9 +130,9 @@ class XTTSDataset(torch.utils.data.Dataset):
         return tokens

     def load_item(self, sample):
-        text = str(sample['text'])
+        text = str(sample["text"])
         tseq = self.get_text(text, sample["language"])
-        audiopath = sample['audio_file']
+        audiopath = sample["audio_file"]
         wav = load_audio(audiopath, self.sample_rate)
         if text is None or len(text.strip()) == 0:
             raise ValueError
@@ -136,7 +141,9 @@ class XTTSDataset(torch.utils.data.Dataset):
             raise ValueError

         # get a slice from GT to condition the model
-        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)
+        cond, cond_len, cond_idxs = get_prompt_slice(
+            audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+        )

         return tseq, audiopath, wav, cond, cond_len, cond_idxs
@@ -170,26 +177,30 @@ class XTTSDataset(torch.utils.data.Dataset):
             return self[1]

         # check if the audio and text size limits and if it out of the limits, added it failed_samples
-        if wav is None or \
-            (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
-            (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+        if (
+            wav is None
+            or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
+            or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
+        ):
             # Basically, this audio file is nonexistent or too long to be supported by the dataset.
             # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
             if self.debug_failures and wav is not None and tseq is not None:
-                print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
+                print(
+                    f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
+                )
             self.failed_samples.add(sample_id)
             return self[1]

         res = {
             # 'real_text': text,
-            'text': tseq,
-            'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
-            'wav': wav,
-            'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
-            'filenames': audiopath,
-            'conditioning': cond.unsqueeze(1),
-            'cond_lens': torch.tensor(cond_len, dtype=torch.long),
-            'cond_idxs': torch.tensor(cond_idxs),
+            "text": tseq,
+            "text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long),
+            "wav": wav,
+            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
+            "filenames": audiopath,
+            "conditioning": cond.unsqueeze(1),
+            "cond_lens": torch.tensor(cond_len, dtype=torch.long),
+            "cond_idxs": torch.tensor(cond_idxs),
         }
         return res
@@ -223,7 +234,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         for i in range(B):
             text = batch["text"][i]
             text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
-            wav = batch['wav'][i]
+            wav = batch["wav"][i]
             wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
         batch["wav"] = wav_padded

View File: TTS/tts/layers/xtts/trainer/gpt_trainer.py

@@ -1,37 +1,30 @@
 import os
+import sys
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-import torchaudio
 import torch.nn as nn
+import torchaudio
+from coqpit import Coqpit
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
-import sys
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
-from TTS.tts.configs.xtts_config import XttsConfig
-from TTS.tts.models.base_tts import BaseTTS
-from coqpit import Coqpit
-from TTS.tts.configs.tortoise_config import TortoiseConfig
-from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
-from TTS.tts.datasets.dataset import TTSDataset
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler

+from TTS.tts.configs.tortoise_config import TortoiseConfig
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.datasets.dataset import TTSDataset
+from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
+from TTS.tts.layers.xtts.dvae import DiscreteVAE
+from TTS.tts.layers.xtts.gpt import GPT
+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
 from TTS.utils.io import load_fsspec
-from TTS.tts.layers.xtts.dvae import DiscreteVAE
-from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder


 @dataclass
 class GPTTrainerConfig(XttsConfig):
@@ -42,6 +35,7 @@ class GPTTrainerConfig(XttsConfig):
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
     test_sentences: List[dict] = field(default_factory=lambda: [])

+
 @dataclass
 class XttsAudioConfig(XttsAudioConfig):
     dvae_sample_rate: int = 22050
@@ -68,14 +62,15 @@ class GPTArgs(XttsArgs):

 def callback_clearml_load_save(operation_type, model_info):
     # return None means skip the file upload/log, returning model_info will continue with the log/upload
     # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
-    assert operation_type in ('load', 'save')
+    assert operation_type in ("load", "save")
     # print(operation_type, model_info.__dict__)

-    if "similarities.pth" in model_info.__dict__['local_model_path']:
+    if "similarities.pth" in model_info.__dict__["local_model_path"]:
         return None

     return model_info


 class GPTTrainer(BaseTTS):
     def __init__(self, config: Coqpit):
         """
@@ -89,18 +84,17 @@ class GPTTrainer(BaseTTS):
         self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
         # init gpt encoder and hifigan decoder
         self.xtts.init_models()
-        # set mel stats
-        if self.args.mel_norm_file:
-            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)

         if self.args.xtts_checkpoint:
             self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

+        # set mel stats
+        if self.args.mel_norm_file:
+            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
+
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(
-                self.args.gpt_checkpoint, map_location=torch.device("cpu")
-            )
+            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 print("Coqui Trainer checkpoint detected! Converting it!")
@@ -115,8 +109,13 @@ class GPTTrainer(BaseTTS):
                         del gpt_checkpoint[key]

             # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
-            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
-                num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+            if (
+                "text_embedding.weight" in gpt_checkpoint
+                and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
+            ):
+                num_new_tokens = (
+                    self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+                )
                 print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                 # add new tokens to a linear layer (text_head)
@@ -156,7 +155,7 @@ class GPTTrainer(BaseTTS):
                 mel_fmin=0,
                 mel_fmax=8000,
                 n_mel_channels=80,
-                mel_norm_file=self.args.mel_norm_file
+                mel_norm_file=self.args.mel_norm_file,
             )

             # Load DVAE
@@ -175,17 +174,18 @@ class GPTTrainer(BaseTTS):
             self.dvae.eval()
             if self.args.dvae_checkpoint:
-                dvae_checkpoint = torch.load(
-                    self.args.dvae_checkpoint, map_location=torch.device("cpu")
-                )
+                dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
                 self.dvae.load_state_dict(dvae_checkpoint, strict=False)
                 print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
             else:
-                raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!")
+                raise RuntimeError(
+                    "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
+                )

             # Mel spectrogram extractor for DVAE
-            self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
+            self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
+                mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
+            )

     @property
     def device(self):
@@ -203,7 +203,9 @@ class GPTTrainer(BaseTTS):
             cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
             cond_idxs: cond start and end indexs, (b, 2)
         """
-        losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
+        losses = self.xtts.gpt(
+            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
+        )
         return losses

     @torch.no_grad()
@@ -215,7 +217,9 @@ class GPTTrainer(BaseTTS):
         test_audios = {}
         print(" | > Synthesizing test sentences.")
         for idx, s_info in enumerate(self.config.test_sentences):
-            wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
+            wav = self.xtts.synthesize(
+                s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
+            )["wav"]
             test_audios["{}-audio".format(idx)] = wav

         # delete inference layers
@@ -300,6 +304,7 @@ class GPTTrainer(BaseTTS):
         # ignore similarities.pth on clearml save/upload
         if self.config.dashboard_logger.lower() == "clearml":
             from clearml.binding.frameworks import WeightsFileHandler
+
             WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

     @torch.no_grad()
@@ -367,16 +372,23 @@ class GPTTrainer(BaseTTS):
         return loader

     def get_optimizer(self) -> List:
-        """Initiate and return the optimizer based on the config parameters.
-        """
+        """Initiate and return the optimizer based on the config parameters."""
         # ToDo: deal with multi GPU training
         if self.config.optimizer_wd_only_on_weights:
             # parameters to only GPT model
             net = self.xtts.gpt

             # normalizations
-            norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
-                            nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
+            norm_modules = (
+                nn.BatchNorm2d,
+                nn.InstanceNorm2d,
+                nn.BatchNorm1d,
+                nn.InstanceNorm1d,
+                nn.BatchNorm3d,
+                nn.InstanceNorm3d,
+                nn.GroupNorm,
+                nn.LayerNorm,
+            )
             # nn.Embedding
             emb_modules = (nn.Embedding, nn.EmbeddingBag)
@@ -390,7 +402,7 @@ class GPTTrainer(BaseTTS):
                     v.is_norm = isinstance(m, norm_modules)
                     v.is_emb = isinstance(m, emb_modules)

-                    fpn = '%s.%s' % (mn, k) if mn else k  # full param name
+                    fpn = "%s.%s" % (mn, k) if mn else k  # full param name
                     all_param_names.add(fpn)
                     param_map[fpn] = v
                     if v.is_bias or v.is_norm or v.is_emb:
@@ -402,8 +414,8 @@ class GPTTrainer(BaseTTS):
             params_weights = [param_map[k] for k in params_names_weights]

             groups = [
-                { 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]},
-                { 'params': params_notweights, 'weight_decay': 0}
+                {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
+                {"params": params_notweights, "weight_decay": 0},
             ]
             # torch.optim.AdamW
             opt = get_optimizer(
@@ -461,4 +473,3 @@ class GPTTrainer(BaseTTS):
             Defaults to None.
         """
         return GPTTrainer(config)
-
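The one behavioral change in `__init__` above is ordering: `mel_stats` is now loaded after `load_checkpoint`, so the stats from `mel_norm_file` are applied on top of whatever a full XTTS checkpoint carries. A schematic of the new order (simplified, illustrative names; not the real class):

from TTS.utils.io import load_fsspec

def init_model(model, args, config):  # hypothetical condensation of GPTTrainer.__init__
    model.init_models()
    if args.xtts_checkpoint:
        # restore the full model first ...
        model.load_checkpoint(config, args.xtts_checkpoint, eval=False, strict=False)
    if args.mel_norm_file:
        # ... then (re)load the mel normalization stats so the file takes precedence
        model.mel_stats = load_fsspec(args.mel_norm_file)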

View File: TTS/tts/models/xtts.py

@@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to
 from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
 from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
 from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
+from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec

 init_stream_support()


 def load_audio(audiopath, sr=22050):
     """
     Load an audio file from disk and resample it to the specified sampling rate.
@@ -332,7 +333,6 @@ class Xtts(BaseTTS):
             stop_audio_token=self.args.gpt_stop_audio_token,
         )
-
         if self.args.use_hifigan:
             self.hifigan_decoder = HifiDecoder(
                 input_sample_rate=self.args.input_sample_rate,
@@ -414,14 +414,13 @@ class Xtts(BaseTTS):
         return diffusion_latent

     @torch.inference_mode()
-    def get_speaker_embedding(
-        self,
-        audio_path
-    ):
+    def get_speaker_embedding(self, audio_path):
         audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
-        speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
-            audio.to(self.device), l2_norm=True
-        ).unsqueeze(-1).to(self.device)
+        speaker_embedding = (
+            self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
+            .unsqueeze(-1)
+            .to(self.device)
+        )
         return speaker_embedding

     def get_conditioning_latents(
@@ -563,11 +562,9 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+        (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
+            audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
+        )
         return self.inference(
             text,
             language,
@@ -721,7 +718,9 @@ class Xtts(BaseTTS):
         decoder="hifigan",
         **hf_generate_kwargs,
     ):
-        assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
+        assert hasattr(
+            self, "hifigan_decoder"
+        ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@@ -827,6 +826,15 @@ class Xtts(BaseTTS):
         ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
         ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
         for key in list(checkpoint.keys()):
+            # check if it is from the coqui Trainer if so convert it
+            if key.startswith("xtts."):
+                coqui_trainer_checkpoint = True
+                new_key = key.replace("xtts.", "")
+                checkpoint[new_key] = checkpoint[key]
+                del checkpoint[key]
+                key = new_key
+
+            # remove unused keys
             if key.split(".")[0] in ignore_keys:
                 del checkpoint[key]
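This added block is the heart of the commit: GPTTrainer keeps the model in a `self.xtts` attribute, so checkpoints it exports carry every key prefixed with "xtts.", and the keys must be renamed before the state dict can be loaded. The renaming in isolation, with dummy integers standing in for tensors:

# Dummy checkpoint shaped like a coqui Trainer export (values stand in for tensors).
checkpoint = {"xtts.gpt.text_embedding.weight": 0, "hifigan_decoder.bias": 1}

for key in list(checkpoint.keys()):
    if key.startswith("xtts."):
        # same rename as the diff above
        checkpoint[key.replace("xtts.", "")] = checkpoint.pop(key)

print(sorted(checkpoint))  # ['gpt.text_embedding.weight', 'hifigan_decoder.bias']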

View File

@ -0,0 +1,145 @@
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
# Define here the dataset used
config_ljspeech = BaseDatasetConfig(
formatter="ljspeech",
dataset_name="ljspeech",
path="/raid/datasets/LJSpeech-1.1_24khz/",
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
language="en",
)
DATASETS_CONFIG_LIST = [config_ljspeech]
def freeze_layers(trainer):
pass
def main():
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=False,
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
# tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
# xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
xtts_checkpoint="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth", # checkpoint path of the model that you want to fine-tune
tokenizer_file="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json",
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR
)
config = GPTTrainerConfig(
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=1000,
save_step=10000,
save_n_checkpoints=1,
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
optimizer="AdamW",
optimizer_wd_only_on_weights=True, # turn this off for multi-GPU training
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
lr_scheduler="MultiStepLR",
# it was adjusted accordingly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
"language": "en",
},
{
"text": "This cake is great. It's so delicious and moist.",
"speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
"language": "en",
},
{
"text": "Levei muito tempo para desenvolver uma voz e agora que a tenho não vou ficar calado .",
"speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
"language": "pt",
},
],
)
# init the model from config
model = GPTTrainer.init_from_config(config)
# load training samples
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(
restore_path=RESTORE_PATH,
skip_train_epoch=SKIP_TRAIN_EPOCH,
start_with_eval=START_WITH_EVAL,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
callbacks={"on_epoch_start": freeze_layers},
)
trainer.fit()
if __name__ == "__main__":
RUN_NAME = "GPT_XTTS_LJSpeech_fixed"
PROJECT_NAME = "XTTS_trainer"
OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_v1_FT/"
# DASHBOARD_LOGGER = "clearml"
# LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_v1/"
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
RESTORE_PATH = None
SKIP_TRAIN_EPOCH = False
START_WITH_EVAL = True
BATCH_SIZE = 3
GRAD_ACUMM_STEPS = 28 * 3  # effective batch size = BATCH_SIZE * GRAD_ACUMM_STEPS = 3 * 84 = 252
# debug
# DASHBOARD_LOGGER = "tensorboard"
# LOGGER_URI = None
# RESTORE_PATH = None
# BATCH_SIZE = 2
# GRAD_ACUMM_STEPS = 1
main()
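
With the key conversion above in place, a checkpoint produced by this recipe should load directly for inference. A minimal sketch, assuming an XTTS-style `config.json`, a fine-tuned checkpoint, and the matching tokenizer file; all paths are placeholders, and the `load_checkpoint` keyword names follow the XTTS model API used in the tests below:

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Paths below are placeholders for the fine-tuning run outputs.
config = XttsConfig()
config.load_json("/path/to/run/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path="/path/to/run/checkpoint.pth",
    vocab_path="/path/to/tokenizer_merged_5.json",
    eval=True,
)
model.cuda()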

View File

@ -2,9 +2,7 @@ from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig
config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
formatter="coqui",
@ -252,11 +250,16 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
# DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
DATASETS_CONFIG_LIST = [
config_coqui_MLS_metadata_test_with_previous_audio_key_de,
config_coqui_mls_italian_metadata_with_previous_audio_key_it,
]
def freeze_layers(trainer):
pass
def main():
# init args and config
model_args = GPTArgs(
@ -274,10 +277,7 @@ def main():
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR
)
config = GPTTrainerConfig(
output_path=OUT_PATH,
@ -303,20 +303,26 @@ def main():
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values like Tortoise; the PyTorch implementation is modified so weight decay is not applied to non-weight parameters.
optimizer="AdamW", optimizer="AdamW",
optimizer_wd_only_on_weights=True, # turn this off for multi-GPU training
optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2}, optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
# lr=1e-4, # learning rate
# ToDo: implement 500 step warmup like tortoise and EMA weights replaces LR decay with rate: .999
lr_scheduler="MultiStepLR", lr_scheduler="MultiStepLR",
# it was adjusted accordly for the new step scheme # it was adjusted accordly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
test_sentences=[
{"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"}, {
{"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"}, "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
] "speaker_wav": "/raid/edresson/dev/ref.wav",
"language": "en",
},
{
"text": "This cake is great. It's so delicious and moist.",
"speaker_wav": "/raid/edresson/dev/ref.wav",
"language": "en",
},
],
)
# init the model from config
@ -332,13 +338,18 @@ def main():
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(
restore_path=RESTORE_PATH,
skip_train_epoch=SKIP_TRAIN_EPOCH,
start_with_eval=START_WITH_EVAL,
grad_accum_steps=GRAD_ACUMM_STEPS,
),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
callbacks={"on_epoch_start": freeze_layers},
)
trainer.fit()
@ -362,6 +373,4 @@ if __name__ == "__main__":
BATCH_SIZE = 2
GRAD_ACUMM_STEPS = 1
main()
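
`freeze_layers` is wired to `on_epoch_start` in both recipes but left as a no-op. A hypothetical implementation sketch; the parameter-name prefix below is an assumption, so inspect `trainer.model.named_parameters()` before relying on it:

def freeze_layers(trainer):
    # Hypothetical example: freeze the GPT text embedding table.
    # "xtts.gpt.text_embedding" is an assumed prefix; verify with named_parameters().
    for name, param in trainer.model.named_parameters():
        if name.startswith("xtts.gpt.text_embedding"):
            param.requires_grad = False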

View File

@ -99,6 +99,7 @@ def test_xtts_streaming():
"""Testing the new inference_stream method""" """Testing the new inference_stream method"""
from TTS.tts.configs.xtts_config import XttsConfig from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts from TTS.tts.models.xtts import Xtts
speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav") speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1") model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
config = XttsConfig() config = XttsConfig()
@ -115,7 +116,7 @@ def test_xtts_streaming():
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.", "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
"en", "en",
gpt_cond_latent, gpt_cond_latent,
speaker_embedding,
)
wav_chuncks = []
for i, chunk in enumerate(chunks):