mirror of https://github.com/coqui-ai/TTS.git
Bug fix for inference using an XTTS trainer checkpoint
This commit is contained in:
parent c4ceaabe2c
commit 9e3598c3b7
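The diff below touches the GPT layers, the HiFi-GAN decoder, the streaming generator, the XTTS trainer dataset and trainer, the Xtts model, and the training recipes. Most hunks are formatting-only; the inference fix itself shows up in two places: test-sentence synthesis in GPTTrainer now passes gpt_cond_len=3, and checkpoint loading in the Xtts model now converts coqui-Trainer exports, whose keys are prefixed with "xtts.", before loading them. A minimal sketch of that key conversion, with a hypothetical checkpoint path and ignore list, looks like this:

    import torch

    checkpoint = torch.load("model.pth", map_location="cpu")["model"]  # hypothetical Trainer export
    ignore_keys = ["hifigan_decoder"]  # hypothetical: sub-modules the loaded config does not use

    for key in list(checkpoint.keys()):
        # Trainer-exported checkpoints prefix every key with "xtts."; strip it
        if key.startswith("xtts."):
            new_key = key.replace("xtts.", "")
            checkpoint[new_key] = checkpoint.pop(key)
            key = new_key
        # drop keys that belong to unused decoders
        if key.split(".")[0] in ignore_keys:
            del checkpoint[key]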
@@ -197,6 +197,7 @@ class GPT(nn.Module):
if use_deepspeed:
import deepspeed

self.ds_engine = deepspeed.init_inference(
model=self.gpt_inference.half(), # Transformers models
mp_size=1, # Number of GPU
@@ -451,7 +452,7 @@ class GPT(nn.Module):
if cond_idxs is not None:
for idx, r in enumerate(cond_idxs.squeeze()):
l = r[1] - r[0]
attn_mask_cond[idx, l : ] = 0.0
attn_mask_cond[idx, l:] = 0.0

for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
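The second hunk above only reformats the conditioning-mask update, but it is the core of how cond_idxs is consumed: for each batch item the valid conditioning length is r[1] - r[0], and every frame past it is zeroed in the mask. A self-contained sketch with dummy shapes (not the real GPT class) behaves like this:

    import torch

    cond_idxs = torch.tensor([[0, 5], [2, 6]])  # hypothetical (start, end) pairs, one per batch item
    attn_mask_cond = torch.ones(2, 8)           # dummy mask: batch of 2, 8 conditioning frames

    for idx, r in enumerate(cond_idxs):
        l = r[1] - r[0]                 # number of valid conditioning frames
        attn_mask_cond[idx, l:] = 0.0   # zero out everything past the valid span

    # attn_mask_cond is now:
    # tensor([[1., 1., 1., 1., 1., 0., 0., 0.],
    #         [1., 1., 1., 1., 0., 0., 0., 0.]])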
@@ -1,13 +1,12 @@
import torch
import torchaudio
from torch import nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
import torchaudio

from TTS.utils.io import load_fsspec


LRELU_SLOPE = 0.1
@@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module):
self.cond_in_each_up_layer = cond_in_each_up_layer

# initial upsampling layers
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == "1" else ResBlock2
# upsampling layers
self.ups = nn.ModuleList()
@@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
# post convolution layer
self.conv_post = weight_norm(
Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
)
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
if cond_channels > 0:
self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
@@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module):
Tensor: [B, 1, T]
"""
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), "replicate"
)
c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
return self.forward(c)

def remove_weight_norm(self):
@@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module):
assert not self.training
self.remove_weight_norm()


class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
@@ -425,10 +417,8 @@ class PreEmphasis(nn.Module):
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies.
"""
"""This is copied from 🐸TTS to remove it from the dependencies."""

# pylint: disable=W0102
def __init__(
@@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module):
return criterion, state["step"]
return criterion


class HifiDecoder(torch.nn.Module):
def __init__(
self,
@@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module):
"""
return self.forward(c, g=g)

def load_checkpoint(
self, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# remove unused keys
state = state["model"]
@@ -1,26 +1,27 @@
# Adapted from: https://github.com/LowinLi/transformers-stream-generator

import copy
import inspect
import random
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from transformers import (
BeamSearchScorer,
ConstrainedBeamSearchScorer,
DisjunctiveConstraint,
GenerationConfig,
GenerationMixin,
LogitsProcessorList,
StoppingCriteriaList,
DisjunctiveConstraint,
BeamSearchScorer,
PhrasalConstraint,
ConstrainedBeamSearchScorer,
PreTrainedModel,
StoppingCriteriaList,
)
import numpy as np
import random
import warnings
import inspect
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
import torch
from typing import Callable, List, Optional, Union
from torch import nn
import torch.distributed as dist
import copy


def setup_seed(seed):
@@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin):
generation_config: Optional[StreamGenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = False,
seed=0,
**kwargs,
@@ -125,7 +124,7 @@ class NewGenerationMixin(GenerationMixin):
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
#setup_seed(seed)
# setup_seed(seed)
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
@@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin):
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = StreamGenerationConfig.from_model_config(
self.config
)
new_generation_config = StreamGenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
@@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin):
generation_config = self.generation_config

generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(
**kwargs
) # All unused kwargs must be model kwargs
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# self._validate_model_kwargs(model_kwargs.copy())

# 2. Set generation parameters if not already defined
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

if (
generation_config.pad_token_id is None
and generation_config.eos_token_id is not None
):
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin):
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

# 3. Define model inputs
@@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin):
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache

accepts_attention_mask = "attention_mask" in set(
inspect.signature(self.forward).parameters.keys()
)
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs

if (
model_kwargs.get("attention_mask", None) is None
and requires_attention_mask
and accepts_attention_mask
):
model_kwargs[
"attention_mask"
] = self._prepare_attention_mask_for_generation(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
@@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin):
if not self.config.is_encoder_decoder:
if (
generation_config.pad_token_id is not None
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
> 0
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin):

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
@@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin):
UserWarning,
)
elif has_default_max_length and generation_config.max_new_tokens is not None:
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
elif (
not has_default_max_length and generation_config.max_new_tokens is not None
):
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
@@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin):
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)

if (
generation_config.min_length is not None
and generation_config.min_length > generation_config.max_length
):
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = (
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
)
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
@@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin):

# 7. determine generation mode
is_constraint_gen_mode = (
generation_config.constraints is not None
or generation_config.force_words_ids is not None
generation_config.constraints is not None or generation_config.force_words_ids is not None
)

is_contrastive_search_gen_mode = (
@@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin):
)

if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError(
"`num_beam_groups` has to be smaller or equal to `num_beams`"
)
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and generation_config.do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
@@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin):
)
elif is_beam_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")

# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin):
logits_warper = self._get_logits_warper(generation_config)

if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size * generation_config.num_return_sequences,
@@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin):
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams
* generation_config.num_return_sequences,
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
@@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin):

elif is_group_beam_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

if generation_config.num_beams % generation_config.num_beam_groups != 0:
raise ValueError(
"`num_beams` should be divisible by `num_beam_groups` for group beam search."
)
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")

has_default_typical_p = (
kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
)
has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
if not has_default_typical_p:
raise ValueError(
"Decoder argument `typical_p` is not supported with beam groups."
)
raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
@@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin):

elif is_constraint_gen_mode:
if generation_config.num_return_sequences > generation_config.num_beams:
raise ValueError(
"`num_return_sequences` has to be smaller or equal to `num_beams`."
)
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

if stopping_criteria.max_length is None:
raise ValueError(
"`max_length` needs to be a stopping_criteria for now."
)
raise ValueError("`max_length` needs to be a stopping_criteria for now.")

if generation_config.num_beams <= 1:
raise ValueError(
"`num_beams` needs to be greater than 1 for constrained generation."
)
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

if generation_config.do_sample:
raise ValueError(
"`do_sample` needs to be false for constrained generation."
)
raise ValueError("`do_sample` needs to be false for constrained generation.")

if (
generation_config.num_beam_groups is not None
and generation_config.num_beam_groups > 1
):
raise ValueError(
"`num_beam_groups` not supported yet for constrained generation."
)
if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

final_constraints = []
if generation_config.constraints is not None:
@@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin):
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(
not isinstance(token_ids, list) for token_ids in word_ids
):
if any(not isinstance(token_ids, list) for token_ids in word_ids):
typeerror()
if any(
any(
(not isinstance(token_id, int) or token_id < 0)
for token_id in token_ids
)
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
):
typeerror()
@@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin):
else:
if not isinstance(word_ids, list) or len(word_ids) == 0:
typeerror()
if any(
(not isinstance(token_id, int) or token_id < 0)
for token_id in word_ids
):
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
typeerror()

constraint = PhrasalConstraint(word_ids)
@@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin):
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
```"""
# init values
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
" `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(
stopping_criteria, max_length
)
logits_warper = (
logits_warper if logits_warper is not None else LogitsProcessorList()
)
pad_token_id = (
pad_token_id
if pad_token_id is not None
else self.generation_config.pad_token_id
)
eos_token_id = (
eos_token_id
if eos_token_id is not None
else self.generation_config.eos_token_id
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
output_scores = (
output_scores
if output_scores is not None
else self.generation_config.output_scores
)
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions
if output_attentions is not None
else self.generation_config.output_attentions
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.generation_config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate
@@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin):

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin):
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(
0.0 if this_peer_finished else 1.0
).to(input_ids.device)
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
@@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin):
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (outputs.attentions,)
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)

if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
(outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
)

# sample
@@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin):
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin):

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = unfinished_sequences.mul(
(sum(next_tokens != i for i in eos_token_id)).long()
)
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -1007,22 +885,17 @@ def init_stream_support():


if __name__ == "__main__":
from transformers import PreTrainedModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

PreTrainedModel.generate = NewGenerationMixin.generate
PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
model = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-560m", torch_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = model.to("cuda:0")
model = model.eval()
prompt_text = "hello? \n"
input_ids = tokenizer(
prompt_text, return_tensors="pt", add_special_tokens=False
).input_ids
input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = input_ids.to("cuda:0")

with torch.no_grad():
@@ -1,16 +1,18 @@
import os
import random
import sys
import numpy as np

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load

torch.set_num_threads(1)


def key_samples_by_col(samples, col):
"""Returns a dictionary of samples keyed by language."""
samples_by_col = {}
@@ -27,7 +29,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
rel_clip = load_audio(gt_path, sample_rate)
# if eval uses a middle size sample when it is possible to be more reproducible
if is_eval:
sample_length = int((min_sample_length + max_sample_length)/2)
sample_length = int((min_sample_length + max_sample_length) / 2)
else:
sample_length = random.randint(min_sample_length, max_sample_length)
gap = rel_clip.shape[-1] - sample_length
@@ -41,7 +43,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,
else:
rand_start = random.randint(0, gap)

rand_end = rand_start+sample_length
rand_end = rand_start + sample_length
rel_clip = rel_clip[:, rand_start:rand_end]
rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
cond_idxs = [rand_start, rand_end]
@@ -50,7 +52,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,

def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == '.mp3':
if audiopath[-4:] == ".mp3":
# it uses torchaudio with sox backend to load mp3
audio, lsr = torchaudio_sox_load(audiopath)
else:
@@ -72,6 +74,7 @@ def load_audio(audiopath, sampling_rate):
audio.clip_(-1, 1)
return audio


class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
self.config = config
@@ -103,16 +106,18 @@ class XTTSDataset(torch.utils.data.Dataset):
print(" > Filtering invalid eval samples!!")
new_samples = []
for sample in self.samples:
try:
tseq, _, wav, _, _, _ = self.load_item(sample)
except:
pass
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
continue
new_samples.append(sample)
try:
tseq, _, wav, _, _, _ = self.load_item(sample)
except:
pass
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
if (
wav is None
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
):
continue
new_samples.append(sample)
self.samples = new_samples
print(" > Total eval samples after filtering:", len(self.samples))
@@ -125,9 +130,9 @@ class XTTSDataset(torch.utils.data.Dataset):
return tokens

def load_item(self, sample):
text = str(sample['text'])
text = str(sample["text"])
tseq = self.get_text(text, sample["language"])
audiopath = sample['audio_file']
audiopath = sample["audio_file"]
wav = load_audio(audiopath, self.sample_rate)
if text is None or len(text.strip()) == 0:
raise ValueError
@@ -136,7 +141,9 @@ class XTTSDataset(torch.utils.data.Dataset):
raise ValueError

# get a slice from GT to condition the model
cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)
cond, cond_len, cond_idxs = get_prompt_slice(
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
)

return tseq, audiopath, wav, cond, cond_len, cond_idxs

@@ -151,7 +158,7 @@ class XTTSDataset(torch.utils.data.Dataset):
index = random.randint(0, len(self.samples[lang]) - 1)
sample = self.samples[lang][index]
# a unique id for each sampel to deal with fails
sample_id = lang+"_"+str(index)
sample_id = lang + "_" + str(index)

# ignore samples that we already know that is not valid ones
if sample_id in self.failed_samples:
@@ -170,26 +177,30 @@ class XTTSDataset(torch.utils.data.Dataset):
return self[1]

# check if the audio and text size limits and if it out of the limits, added it failed_samples
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
if (
wav is None
or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None:
print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
print(
f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
)
self.failed_samples.add(sample_id)
return self[1]

res = {
# 'real_text': text,
'text': tseq,
'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
'wav': wav,
'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
'filenames': audiopath,
'conditioning': cond.unsqueeze(1),
'cond_lens': torch.tensor(cond_len, dtype=torch.long),
'cond_idxs': torch.tensor(cond_idxs),
"text": tseq,
"text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long),
"wav": wav,
"wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
"filenames": audiopath,
"conditioning": cond.unsqueeze(1),
"cond_lens": torch.tensor(cond_len, dtype=torch.long),
"cond_idxs": torch.tensor(cond_idxs),
}
return res

@@ -223,8 +234,8 @@ class XTTSDataset(torch.utils.data.Dataset):
for i in range(B):
text = batch["text"][i]
text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
wav = batch['wav'][i]
wav_padded[i, :, :batch["wav_lengths"][i]] = torch.FloatTensor(wav)
wav = batch["wav"][i]
wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)

batch["wav"] = wav_padded
batch["padded_text"] = text_padded
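The dataset hunks above make the conditioning slice length deterministic at evaluation time: get_prompt_slice now uses the midpoint of the allowed length range when is_eval is set, instead of a random length, so eval batches are more reproducible. A standalone sketch of that selection, assuming the clip is longer than the requested slice (the short-clip branch is not shown in this excerpt), could look like:

    import random

    import torch
    import torch.nn.functional as F

    def prompt_slice_sketch(rel_clip, max_sample_length, min_sample_length, is_eval):
        # eval uses the middle of the allowed range so it is reproducible; training samples randomly
        if is_eval:
            sample_length = int((min_sample_length + max_sample_length) / 2)
        else:
            sample_length = random.randint(min_sample_length, max_sample_length)
        gap = rel_clip.shape[-1] - sample_length
        rand_start = random.randint(0, gap)  # assumes gap >= 0, i.e. the clip is long enough
        rand_end = rand_start + sample_length
        clip = rel_clip[:, rand_start:rand_end]
        clip = F.pad(clip, pad=(0, max_sample_length - clip.shape[-1]))
        return clip, [rand_start, rand_end]

    # 10 s of audio at 22050 Hz, conditioning between 3 s and 6 s (the values used in the recipes below)
    clip, cond_idxs = prompt_slice_sketch(torch.randn(1, 220500), 132300, 66150, is_eval=True)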
@@ -1,37 +1,30 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torchaudio
import torch.nn as nn
import torchaudio
from coqpit import Coqpit
from torch.nn import functional as F
from torch.utils.data import DataLoader
import sys


from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
from TTS.tts.configs.xtts_config import XttsConfig

from TTS.tts.models.base_tts import BaseTTS
from coqpit import Coqpit

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram

from TTS.tts.datasets.dataset import TTSDataset

from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec

from TTS.tts.layers.xtts.dvae import DiscreteVAE

from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder

@dataclass
class GPTTrainerConfig(XttsConfig):
@@ -42,6 +35,7 @@ class GPTTrainerConfig(XttsConfig):
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
test_sentences: List[dict] = field(default_factory=lambda: [])


@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@@ -55,27 +49,28 @@ class GPTArgs(XttsArgs):
gpt_loss_mel_ce_weight: float = 1.0
gpt_num_audio_tokens: int = 8194
debug_loading_failures: bool = False
max_wav_length: int = 255995 # ~11.6 seconds
max_wav_length: int = 255995 # ~11.6 seconds
max_text_length: int = 200
tokenizer_file: str = ""
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
dvae_checkpoint: str = ""
xtts_checkpoint: str = ""
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # overide vocoder key on the config to avoid json write issues


def callback_clearml_load_save(operation_type, model_info):
# return None means skip the file upload/log, returning model_info will continue with the log/upload
# you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ('load', 'save')
assert operation_type in ("load", "save")
# print(operation_type, model_info.__dict__)

if "similarities.pth" in model_info.__dict__['local_model_path']:
if "similarities.pth" in model_info.__dict__["local_model_path"]:
return None

return model_info


class GPTTrainer(BaseTTS):
def __init__(self, config: Coqpit):
"""
@@ -89,18 +84,17 @@ class GPTTrainer(BaseTTS):
self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
# init gpt encoder and hifigan decoder
self.xtts.init_models()
# set mel stats
if self.args.mel_norm_file:
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)

if self.args.xtts_checkpoint:
self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

# set mel stats
if self.args.mel_norm_file:
self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)

# load GPT if available
if self.args.gpt_checkpoint:
gpt_checkpoint = torch.load(
self.args.gpt_checkpoint, map_location=torch.device("cpu")
)
gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!")
@@ -115,8 +109,13 @@ class GPTTrainer(BaseTTS):
del gpt_checkpoint[key]

# edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
if (
"text_embedding.weight" in gpt_checkpoint
and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
):
num_new_tokens = (
self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
)
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

# add new tokens to a linear layer (text_head)
@@ -156,7 +155,7 @@ class GPTTrainer(BaseTTS):
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file
mel_norm_file=self.args.mel_norm_file,
)

# Load DVAE
@@ -175,17 +174,18 @@ class GPTTrainer(BaseTTS):

self.dvae.eval()
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(
self.args.dvae_checkpoint, map_location=torch.device("cpu")
)
dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
else:
raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!")
raise RuntimeError(
"You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
)

# Mel spectrogram extractor for DVAE
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)

self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
)

@property
def device(self):
@@ -203,7 +203,9 @@ class GPTTrainer(BaseTTS):
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
cond_idxs: cond start and end indexs, (b, 2)
"""
losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
losses = self.xtts.gpt(
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
)
return losses

@torch.no_grad()
@@ -215,7 +217,9 @@ class GPTTrainer(BaseTTS):
test_audios = {}
print(" | > Synthesizing test sentences.")
for idx, s_info in enumerate(self.config.test_sentences):
wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
wav = self.xtts.synthesize(
s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
)["wav"]
test_audios["{}-audio".format(idx)] = wav

# delete inference layers
@@ -231,7 +235,7 @@ class GPTTrainer(BaseTTS):
def format_batch(self, batch: Dict) -> Dict:
return batch

@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
def format_batch_on_device(self, batch):
"""Compute spectrograms on the device."""
batch["text_lengths"] = batch["text_lengths"]
@@ -241,10 +245,10 @@ class GPTTrainer(BaseTTS):
# compute conditioning mel specs
# transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor
B, num_cond_samples, C, T = batch["conditioning"].size()
conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T)
conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
# transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel])
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1)
T_mel = paired_conditioning_mel.size(2)
paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
# get the conditioning embeddings
@@ -300,6 +304,7 @@ class GPTTrainer(BaseTTS):
# ignore similarities.pth on clearml save/upload
if self.config.dashboard_logger.lower() == "clearml":
from clearml.binding.frameworks import WeightsFileHandler

WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

@torch.no_grad()
@@ -367,16 +372,23 @@ class GPTTrainer(BaseTTS):
return loader

def get_optimizer(self) -> List:
"""Initiate and return the optimizer based on the config parameters.
"""
"""Initiate and return the optimizer based on the config parameters."""
# ToDo: deal with multi GPU training
if self.config.optimizer_wd_only_on_weights:
# parameters to only GPT model
net = self.xtts.gpt

# normalizations
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
norm_modules = (
nn.BatchNorm2d,
nn.InstanceNorm2d,
nn.BatchNorm1d,
nn.InstanceNorm1d,
nn.BatchNorm3d,
nn.InstanceNorm3d,
nn.GroupNorm,
nn.LayerNorm,
)
# nn.Embedding
emb_modules = (nn.Embedding, nn.EmbeddingBag)

@@ -390,7 +402,7 @@ class GPTTrainer(BaseTTS):
v.is_norm = isinstance(m, norm_modules)
v.is_emb = isinstance(m, emb_modules)

fpn = '%s.%s' % (mn, k) if mn else k # full param name
fpn = "%s.%s" % (mn, k) if mn else k # full param name
all_param_names.add(fpn)
param_map[fpn] = v
if v.is_bias or v.is_norm or v.is_emb:
@@ -402,26 +414,26 @@ class GPTTrainer(BaseTTS):
params_weights = [param_map[k] for k in params_names_weights]

groups = [
{ 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]},
{ 'params': params_notweights, 'weight_decay': 0}
{"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
{"params": params_notweights, "weight_decay": 0},
]
# torch.optim.AdamW
opt = get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
parameters=groups,
)
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
parameters=groups,
)
opt._group_names = [params_names_weights, params_names_notweights]
return opt

return get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.xtts.gpt.parameters(),
)
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.xtts.gpt.parameters(),
)

def get_scheduler(self, optimizer) -> List:
"""Set the scheduler for the optimizer.
@@ -461,4 +473,3 @@ class GPTTrainer(BaseTTS):
Defaults to None.
"""
return GPTTrainer(config)
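The get_optimizer hunk above is the parameter grouping enabled by optimizer_wd_only_on_weights: weight decay is applied only to plain weight tensors, while biases, normalization parameters, and embeddings go into a second group with weight_decay=0. A standalone sketch of the same idea using torch.optim.AdamW directly (rather than the Trainer's get_optimizer helper) could look like:

    import torch
    from torch import nn

    def build_param_groups(model, weight_decay=1e-2):
        # a subset of the normalization/embedding module types listed in the trainer
        no_decay_modules = (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d, nn.Embedding)
        decay, no_decay = [], []
        for module in model.modules():
            for name, param in module.named_parameters(recurse=False):
                if name.endswith("bias") or isinstance(module, no_decay_modules):
                    no_decay.append(param)  # biases, norms, embeddings: no weight decay
                else:
                    decay.append(param)     # plain weights get weight decay
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    optimizer = torch.optim.AdamW(build_param_groups(model), lr=5e-6, betas=(0.9, 0.96), eps=1e-8)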
@@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec

init_stream_support()


def load_audio(audiopath, sr=22050):
"""
Load an audio file from disk and resample it to the specified sampling rate.
@@ -332,7 +333,6 @@ class Xtts(BaseTTS):
stop_audio_token=self.args.gpt_stop_audio_token,
)

if self.args.use_hifigan:
self.hifigan_decoder = HifiDecoder(
input_sample_rate=self.args.input_sample_rate,
@@ -414,14 +414,13 @@ class Xtts(BaseTTS):
return diffusion_latent

@torch.inference_mode()
def get_speaker_embedding(
self,
audio_path
):
def get_speaker_embedding(self, audio_path):
audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
audio.to(self.device), l2_norm=True
).unsqueeze(-1).to(self.device)
speaker_embedding = (
self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
.unsqueeze(-1)
.to(self.device)
)
return speaker_embedding

def get_conditioning_latents(
@@ -563,11 +562,9 @@ class Xtts(BaseTTS):
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
Sample rate is 24kHz.
"""
(
gpt_cond_latent,
diffusion_conditioning,
speaker_embedding
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
(gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
)
return self.inference(
text,
language,
@@ -721,7 +718,9 @@ class Xtts(BaseTTS):
decoder="hifigan",
**hf_generate_kwargs,
):
assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
assert hasattr(
self, "hifigan_decoder"
), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
@@ -827,6 +826,15 @@ class Xtts(BaseTTS):
ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer if so convert it
if key.startswith("xtts."):
coqui_trainer_checkpoint = True
new_key = key.replace("xtts.", "")
checkpoint[new_key] = checkpoint[key]
del checkpoint[key]
key = new_key

# remove unused keys
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
@ -0,0 +1,145 @@
|
|||
from trainer import Trainer, TrainerArgs
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
||||
|
||||
# Define here the dataset used
|
||||
config_ljspeech = BaseDatasetConfig(
|
||||
formatter="ljspeech",
|
||||
dataset_name="ljspeech",
|
||||
path="/raid/datasets/LJSpeech-1.1_24khz/",
|
||||
meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
|
||||
language="en",
|
||||
)
|
||||
|
||||
DATASETS_CONFIG_LIST = [config_ljspeech]
|
||||
|
||||
|
||||
def freeze_layers(trainer):
|
||||
pass
|
||||
|
||||
|
||||
def main():
|
||||
# init args and config
|
||||
model_args = GPTArgs(
|
||||
max_conditioning_length=132300, # 6 secs
|
||||
min_conditioning_length=66150, # 3 secs
|
||||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_text_length=200,
|
||||
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
|
||||
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
|
||||
# tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
|
||||
# xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
|
||||
xtts_checkpoint="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth", # checkpoint path of the model that you want to fine-tune
|
||||
tokenizer_file="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json",
|
||||
gpt_num_audio_tokens=8194,
|
||||
gpt_start_audio_token=8192,
|
||||
gpt_stop_audio_token=8193,
|
||||
)
|
||||
audio_config = XttsAudioConfig(
|
||||
sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 # GPT SR
|
||||
)
|
||||
config = GPTTrainerConfig(
|
||||
output_path=OUT_PATH,
|
||||
model_args=model_args,
|
||||
run_name=RUN_NAME,
|
||||
project_name=PROJECT_NAME,
|
||||
run_description="""
|
||||
GPT XTTS training
|
||||
""",
|
||||
dashboard_logger=DASHBOARD_LOGGER,
|
||||
logger_uri=LOGGER_URI,
|
||||
audio=audio_config,
|
||||
batch_size=BATCH_SIZE,
|
||||
batch_group_size=48,
|
||||
eval_batch_size=BATCH_SIZE,
|
||||
num_loader_workers=8,
|
||||
eval_split_max_size=256,
|
||||
print_step=50,
|
||||
plot_step=100,
|
||||
log_model_step=1000,
|
||||
save_step=10000,
|
||||
save_n_checkpoints=1,
|
||||
save_checkpoints=True,
|
||||
# target_loss="loss",
|
||||
print_eval=False,
|
||||
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
||||
optimizer="AdamW",
|
||||
optimizer_wd_only_on_weights=True, # for multi-gpu training turn it off
|
||||
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
||||
lr=5e-06, # learning rate
|
||||
lr_scheduler="MultiStepLR",
|
||||
# it was adjusted accordly for the new step scheme
|
||||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
||||
        test_sentences=[
            {
                "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
                "language": "en",
            },
            {
                "text": "This cake is great. It's so delicious and moist.",
                "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
                "language": "en",
            },
            {
                "text": "Levei muito tempo para desenvolver uma voz e agora que a tenho não vou ficar calado .",
                "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
                "language": "pt",
            },
        ],
    )

    # init the model from config
    model = GPTTrainer.init_from_config(config)

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(
            restore_path=RESTORE_PATH,
            skip_train_epoch=SKIP_TRAIN_EPOCH,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        callbacks={"on_epoch_start": freeze_layers},
    )
    trainer.fit()


if __name__ == "__main__":
    RUN_NAME = "GPT_XTTS_LJSpeech_fixed"
    PROJECT_NAME = "XTTS_trainer"
    OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_v1_FT/"
    # DASHBOARD_LOGGER = "clearml"
    # LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_v1/"
    DASHBOARD_LOGGER = "tensorboard"
    LOGGER_URI = None
    RESTORE_PATH = None
    SKIP_TRAIN_EPOCH = False
    START_WITH_EVAL = True
    BATCH_SIZE = 3
    GRAD_ACUMM_STEPS = 28 * 3
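    # effective batch per optimizer step: BATCH_SIZE * GRAD_ACUMM_STEPS = 3 * 84 = 252,
    # the same 252 reached with BATCH_SIZE = 9 and GRAD_ACUMM_STEPS = 28 in the multi-dataset recipe further down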

    # debug
    # DASHBOARD_LOGGER = "tensorboard"
    # LOGGER_URI = None
    # RESTORE_PATH = None
    # BATCH_SIZE = 2
    # GRAD_ACUMM_STEPS = 1

    main()
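Since this commit fixes inference from an XTTS trainer checkpoint, it may help to sketch how a fine-tuned run like the one above is typically loaded back for inference. The sketch follows the XttsConfig/Xtts pattern used by the streaming test at the bottom of this diff; the checkpoint folder is a hypothetical placeholder and the keyword arguments to load_checkpoint are assumptions based on the XTTS API of this release, not part of the commit:

import os

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# hypothetical run folder created under OUT_PATH by the recipe above
CHECKPOINT_DIR = "/raid/edresson/dev/Checkpoints/XTTS_v1_FT/GPT_XTTS_LJSpeech_fixed_run"

config = XttsConfig()
config.load_json(os.path.join(CHECKPOINT_DIR, "config.json"))  # coqpit configs reload from JSON

model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=CHECKPOINT_DIR, eval=True)  # assumed kwargs
model.cuda()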

@@ -2,9 +2,7 @@ from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig

config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
    formatter="coqui",

@@ -252,32 +250,34 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(

# DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]

DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
DATASETS_CONFIG_LIST = [
    config_coqui_MLS_metadata_test_with_previous_audio_key_de,
    config_coqui_mls_italian_metadata_with_previous_audio_key_it,
]
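The BaseDatasetConfig definitions referenced by DATASETS_CONFIG_LIST sit earlier in this file, mostly outside the hunks shown here. For orientation, a single "coqui"-formatter entry generally looks like the sketch below; the dataset name, paths, and metafile are illustrative placeholders, not values from this commit:

config_example_mls_german = BaseDatasetConfig(
    formatter="coqui",  # generic coqui CSV formatter
    dataset_name="mls_german",  # illustrative name
    path="/raid/datasets/MLS/mls_german/",  # illustrative dataset root containing the wavs
    meta_file_train="metadata_with_previous_audio_key.csv",  # illustrative metafile
    language="de",
)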


def freeze_layers(trainer):
    pass


def main():
    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs
        min_conditioning_length=66150,  # 3 secs
        debug_loading_failures=False,
        max_wav_length=255995,  # ~11.6 seconds
        max_text_length=200,
        mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
        dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
        tokenizer_file="/raid/datasets/xtts_models/vocab.json",  # vocab path of the model that you want to fine-tune
        xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",  # checkpoint path of the model that you want to fine-tune
        gpt_num_audio_tokens=8194,
        gpt_start_audio_token=8192,
        gpt_stop_audio_token=8193,
    )
    audio_config = XttsAudioConfig(
        sample_rate=22050,  # GPT SR
        dvae_sample_rate=22050,
        diffusion_sample_rate=24000,
        output_sample_rate=24000
        sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000  # GPT SR
    )
    config = GPTTrainerConfig(
        output_path=OUT_PATH,

@@ -303,20 +303,26 @@ def main():
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=False,
        # Optimizer values like tortoise. However, they used pytorch implementation with modifications to not apply WD to non-weight parameters. We are using default Pytorch
        # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=True,
        optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        # lr=1e-4,  # learning rate
        # ToDo: implement 500 step warmup like tortoise and EMA weights replaces LR decay with rate: .999
        optimizer_wd_only_on_weights=True,  # for multi-gpu training turn it off
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        lr_scheduler="MultiStepLR",
        # it was adjusted accordingly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
        test_sentences=[
            {"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
            {"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
        ]
            {
                "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "speaker_wav": "/raid/edresson/dev/ref.wav",
                "language": "en",
            },
            {
                "text": "This cake is great. It's so delicious and moist.",
                "speaker_wav": "/raid/edresson/dev/ref.wav",
                "language": "en",
            },
        ],
    )

    # init the model from config

@@ -332,13 +338,18 @@ def main():

    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS),
        TrainerArgs(
            restore_path=RESTORE_PATH,
            skip_train_epoch=SKIP_TRAIN_EPOCH,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        callbacks={"on_epoch_start": freeze_layers}
        callbacks={"on_epoch_start": freeze_layers},
    )
    trainer.fit()


@@ -351,7 +362,7 @@ if __name__ == "__main__":
    LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
    RESTORE_PATH = None
    SKIP_TRAIN_EPOCH = False
    START_WITH_EVAL = True
    BATCH_SIZE = 9
    GRAD_ACUMM_STEPS = 28

@ -362,6 +373,4 @@ if __name__ == "__main__":
|
|||
BATCH_SIZE = 2
|
||||
GRAD_ACUMM_STEPS = 1
|
||||
|
||||
|
||||
|
||||
main()
|
||||
|
|
|

@@ -99,6 +99,7 @@ def test_xtts_streaming():
    """Testing the new inference_stream method"""
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
    model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
    config = XttsConfig()

@@ -115,7 +116,7 @@ def test_xtts_streaming():
        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
        "en",
        gpt_cond_latent,
        speaker_embedding
        speaker_embedding,
    )
    wav_chuncks = []
    for i, chunk in enumerate(chunks):
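The hunk is cut off here. For context, the generator returned by inference_stream yields audio chunks as torch tensors; a minimal, self-contained way to consume such a generator and write the result to disk might look like the sketch below. The function name, output path, and the assumption that chunks can be squeezed to 1-D and sit on the GPU are illustrative, and 24000 matches the output_sample_rate configured in the recipe above:

import torch
import torchaudio

def collect_stream(chunks, out_path="xtts_streamed.wav", sample_rate=24000):
    # gather every streamed chunk, move it to CPU, and concatenate into one waveform
    pieces = [chunk.squeeze().float().cpu() for chunk in chunks]
    wav = torch.cat(pieces, dim=0)
    # torchaudio expects a [channels, samples] tensor
    torchaudio.save(out_path, wav.unsqueeze(0), sample_rate)
    return wav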