mirror of https://github.com/coqui-ai/TTS.git
Bug Fix on inference using XTTS trainer checkpoint
parent c4ceaabe2c
commit 9e3598c3b7
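For context, the path this commit fixes is the one where a checkpoint written by the XTTS GPT trainer is loaded back into the standalone Xtts model and used for synthesis. A rough, illustrative sketch only: the file paths below are placeholders, and the exact load_checkpoint signature is an assumption rather than something taken from this diff; the synthesize call mirrors the one used in the trainer's test run further down.

    import torch

    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts

    # Hypothetical config and checkpoint produced by the GPT trainer.
    config = XttsConfig()
    config.load_json("/path/to/config.json")
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_path="/path/to/trainer_checkpoint.pth", eval=True)
    model.cuda()

    with torch.no_grad():
        out = model.synthesize(
            "Hello world.",
            config,
            speaker_wav="/path/to/reference.wav",  # hypothetical reference clip
            language="en",
            gpt_cond_len=3,  # same value the trainer's test run passes below
        )
    wav = out["wav"]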
@@ -197,6 +197,7 @@ class GPT(nn.Module):

         if use_deepspeed:
             import deepspeed

             self.ds_engine = deepspeed.init_inference(
                 model=self.gpt_inference.half(),  # Transformers models
                 mp_size=1,  # Number of GPU
@@ -1,13 +1,12 @@
 import torch
+import torchaudio
 from torch import nn
 from torch.nn import Conv1d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm
-import torchaudio

 from TTS.utils.io import load_fsspec


 LRELU_SLOPE = 0.1

@@ -224,9 +223,7 @@ class HifiganGenerator(torch.nn.Module):
         self.cond_in_each_up_layer = cond_in_each_up_layer

         # initial upsampling layers
-        self.conv_pre = weight_norm(
-            Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
-        )
+        self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
         resblock = ResBlock1 if resblock_type == "1" else ResBlock2
         # upsampling layers
         self.ups = nn.ModuleList()
@@ -246,14 +243,10 @@ class HifiganGenerator(torch.nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))
         # post convolution layer
-        self.conv_post = weight_norm(
-            Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
-        )
+        self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias))
         if cond_channels > 0:
             self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)

@@ -318,9 +311,7 @@ class HifiganGenerator(torch.nn.Module):
             Tensor: [B, 1, T]
         """
         c = c.to(self.conv_pre.weight.device)
-        c = torch.nn.functional.pad(
-            c, (self.inference_padding, self.inference_padding), "replicate"
-        )
+        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
         return self.forward(c)

     def remove_weight_norm(self):
@@ -342,6 +333,7 @@ class HifiganGenerator(torch.nn.Module):
         assert not self.training
         self.remove_weight_norm()


 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -425,10 +417,8 @@ class PreEmphasis(nn.Module):
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


 class ResNetSpeakerEncoder(nn.Module):
-    """This is copied from 🐸TTS to remove it from the dependencies.
-    """
+    """This is copied from 🐸TTS to remove it from the dependencies."""

     # pylint: disable=W0102
     def __init__(
@@ -620,6 +610,7 @@ class ResNetSpeakerEncoder(nn.Module):
             return criterion, state["step"]
         return criterion


 class HifiDecoder(torch.nn.Module):
     def __init__(
         self,
@@ -724,9 +715,7 @@ class HifiDecoder(torch.nn.Module):
         """
         return self.forward(c, g=g)

-    def load_checkpoint(
-        self, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
+    def load_checkpoint(self, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # remove unused keys
         state = state["model"]
@@ -1,26 +1,27 @@
 # Adapted from: https://github.com/LowinLi/transformers-stream-generator

+import copy
+import inspect
+import random
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import nn
 from transformers import (
+    BeamSearchScorer,
+    ConstrainedBeamSearchScorer,
+    DisjunctiveConstraint,
     GenerationConfig,
     GenerationMixin,
     LogitsProcessorList,
-    StoppingCriteriaList,
-    DisjunctiveConstraint,
-    BeamSearchScorer,
     PhrasalConstraint,
-    ConstrainedBeamSearchScorer,
     PreTrainedModel,
+    StoppingCriteriaList,
 )
-import numpy as np
-import random
-import warnings
-import inspect
 from transformers.generation.utils import GenerateOutput, SampleOutput, logger
-import torch
-from typing import Callable, List, Optional, Union
-from torch import nn
-import torch.distributed as dist
-import copy


 def setup_seed(seed):
@@ -48,9 +49,7 @@ class NewGenerationMixin(GenerationMixin):
         generation_config: Optional[StreamGenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[
-            Callable[[int, torch.Tensor], List[int]]
-        ] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
         synced_gpus: Optional[bool] = False,
         seed=0,
         **kwargs,
@@ -134,9 +133,7 @@ class NewGenerationMixin(GenerationMixin):
             # legacy: users may modify the model configuration to control generation -- update the generation config
             # model attribute accordingly, if it was created from the model config
             if self.generation_config._from_model_config:
-                new_generation_config = StreamGenerationConfig.from_model_config(
-                    self.config
-                )
+                new_generation_config = StreamGenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
                     warnings.warn(
                         "You have modified the pretrained model configuration to control generation. This is a"
@@ -148,25 +145,14 @@ class NewGenerationMixin(GenerationMixin):
             generation_config = self.generation_config

         generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(
-            **kwargs
-        )  # All unused kwargs must be model kwargs
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         # self._validate_model_kwargs(model_kwargs.copy())

         # 2. Set generation parameters if not already defined
-        logits_processor = (
-            logits_processor if logits_processor is not None else LogitsProcessorList()
-        )
-        stopping_criteria = (
-            stopping_criteria
-            if stopping_criteria is not None
-            else StoppingCriteriaList()
-        )
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

-        if (
-            generation_config.pad_token_id is None
-            and generation_config.eos_token_id is not None
-        ):
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -175,9 +161,7 @@ class NewGenerationMixin(GenerationMixin):
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(
-                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
-            )
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id

         # 3. Define model inputs
@@ -195,19 +179,11 @@ class NewGenerationMixin(GenerationMixin):
         model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache

-        accepts_attention_mask = "attention_mask" in set(
-            inspect.signature(self.forward).parameters.keys()
-        )
+        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs

-        if (
-            model_kwargs.get("attention_mask", None) is None
-            and requires_attention_mask
-            and accepts_attention_mask
-        ):
-            model_kwargs[
-                "attention_mask"
-            ] = self._prepare_attention_mask_for_generation(
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                 inputs_tensor,
                 generation_config.pad_token_id,
                 generation_config.eos_token_id,
@@ -217,8 +193,7 @@ class NewGenerationMixin(GenerationMixin):
         if not self.config.is_encoder_decoder:
             if (
                 generation_config.pad_token_id is not None
-                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
-                > 0
+                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
             ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -247,10 +222,7 @@ class NewGenerationMixin(GenerationMixin):

         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
-        has_default_max_length = (
-            kwargs.get("max_length") is None
-            and generation_config.max_length is not None
-        )
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
                 "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
@@ -260,12 +232,8 @@ class NewGenerationMixin(GenerationMixin):
                 UserWarning,
             )
         elif has_default_max_length and generation_config.max_new_tokens is not None:
-            generation_config.max_length = (
-                generation_config.max_new_tokens + input_ids_seq_length
-            )
-        elif (
-            not has_default_max_length and generation_config.max_new_tokens is not None
-        ):
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        elif not has_default_max_length and generation_config.max_new_tokens is not None:
             raise ValueError(
                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
@@ -273,18 +241,13 @@ class NewGenerationMixin(GenerationMixin):
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
             )

-        if (
-            generation_config.min_length is not None
-            and generation_config.min_length > generation_config.max_length
-        ):
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
                 f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                 f" the maximum length ({generation_config.max_length})"
             )
         if input_ids_seq_length >= generation_config.max_length:
-            input_ids_string = (
-                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
-            )
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
@@ -293,8 +256,7 @@ class NewGenerationMixin(GenerationMixin):

         # 7. determine generation mode
         is_constraint_gen_mode = (
-            generation_config.constraints is not None
-            or generation_config.force_words_ids is not None
+            generation_config.constraints is not None or generation_config.force_words_ids is not None
         )

         is_contrastive_search_gen_mode = (
@@ -349,9 +311,7 @@ class NewGenerationMixin(GenerationMixin):
         )

         if generation_config.num_beam_groups > generation_config.num_beams:
-            raise ValueError(
-                "`num_beam_groups` has to be smaller or equal to `num_beams`"
-            )
+            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
         if is_group_beam_gen_mode and generation_config.do_sample is True:
             raise ValueError(
                 "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
@@ -474,14 +434,10 @@ class NewGenerationMixin(GenerationMixin):
             )
         elif is_beam_gen_mode:
             if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError(
-                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
-                )
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -518,9 +474,7 @@ class NewGenerationMixin(GenerationMixin):
             logits_warper = self._get_logits_warper(generation_config)

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size * generation_config.num_return_sequences,
@@ -533,8 +487,7 @@ class NewGenerationMixin(GenerationMixin):
             # 13. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
-                expand_size=generation_config.num_beams
-                * generation_config.num_return_sequences,
+                expand_size=generation_config.num_beams * generation_config.num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
@@ -556,27 +509,17 @@ class NewGenerationMixin(GenerationMixin):

         elif is_group_beam_gen_mode:
             if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError(
-                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
-                )
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if generation_config.num_beams % generation_config.num_beam_groups != 0:
-                raise ValueError(
-                    "`num_beams` should be divisible by `num_beam_groups` for group beam search."
-                )
+                raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

-            has_default_typical_p = (
-                kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
-            )
+            has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
             if not has_default_typical_p:
-                raise ValueError(
-                    "Decoder argument `typical_p` is not supported with beam groups."
-                )
+                raise ValueError("Decoder argument `typical_p` is not supported with beam groups.")

             # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -612,32 +555,19 @@ class NewGenerationMixin(GenerationMixin):

         elif is_constraint_gen_mode:
             if generation_config.num_return_sequences > generation_config.num_beams:
-                raise ValueError(
-                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
-                )
+                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if stopping_criteria.max_length is None:
-                raise ValueError(
-                    "`max_length` needs to be a stopping_criteria for now."
-                )
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

             if generation_config.num_beams <= 1:
-                raise ValueError(
-                    "`num_beams` needs to be greater than 1 for constrained generation."
-                )
+                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

             if generation_config.do_sample:
-                raise ValueError(
-                    "`do_sample` needs to be false for constrained generation."
-                )
+                raise ValueError("`do_sample` needs to be false for constrained generation.")

-            if (
-                generation_config.num_beam_groups is not None
-                and generation_config.num_beam_groups > 1
-            ):
-                raise ValueError(
-                    "`num_beam_groups` not supported yet for constrained generation."
-                )
+            if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1:
+                raise ValueError("`num_beam_groups` not supported yet for constrained generation.")

             final_constraints = []
             if generation_config.constraints is not None:
@@ -661,15 +591,10 @@ class NewGenerationMixin(GenerationMixin):
                 if isinstance(word_ids[0], list):
                     if not isinstance(word_ids, list) or len(word_ids) == 0:
                         typeerror()
-                    if any(
-                        not isinstance(token_ids, list) for token_ids in word_ids
-                    ):
+                    if any(not isinstance(token_ids, list) for token_ids in word_ids):
                         typeerror()
                     if any(
-                        any(
-                            (not isinstance(token_id, int) or token_id < 0)
-                            for token_id in token_ids
-                        )
+                        any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                         for token_ids in word_ids
                     ):
                         typeerror()
@@ -678,10 +603,7 @@ class NewGenerationMixin(GenerationMixin):
                 else:
                     if not isinstance(word_ids, list) or len(word_ids) == 0:
                         typeerror()
-                    if any(
-                        (not isinstance(token_id, int) or token_id < 0)
-                        for token_id in word_ids
-                    ):
+                    if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                         typeerror()

                     constraint = PhrasalConstraint(word_ids)
@@ -843,52 +765,26 @@ class NewGenerationMixin(GenerationMixin):
         ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
         ```"""
         # init values
-        logits_processor = (
-            logits_processor if logits_processor is not None else LogitsProcessorList()
-        )
-        stopping_criteria = (
-            stopping_criteria
-            if stopping_criteria is not None
-            else StoppingCriteriaList()
-        )
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
             warnings.warn(
                 "`max_length` is deprecated in this function, use"
                 " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                 UserWarning,
             )
-            stopping_criteria = validate_stopping_criteria(
-                stopping_criteria, max_length
-            )
-        logits_warper = (
-            logits_warper if logits_warper is not None else LogitsProcessorList()
-        )
-        pad_token_id = (
-            pad_token_id
-            if pad_token_id is not None
-            else self.generation_config.pad_token_id
-        )
-        eos_token_id = (
-            eos_token_id
-            if eos_token_id is not None
-            else self.generation_config.eos_token_id
-        )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        output_scores = (
-            output_scores
-            if output_scores is not None
-            else self.generation_config.output_scores
-        )
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.generation_config.output_attentions
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.generation_config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
             return_dict_in_generate
@@ -898,15 +794,9 @@ class NewGenerationMixin(GenerationMixin):

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        cross_attentions = (
-            () if (return_dict_in_generate and output_attentions) else None
-        )
-        decoder_hidden_states = (
-            () if (return_dict_in_generate and output_hidden_states) else None
-        )
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

         # keep track of which sequences are already finished
         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
@@ -917,9 +807,7 @@ class NewGenerationMixin(GenerationMixin):
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                 # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(
-                    0.0 if this_peer_finished else 1.0
-                ).to(input_ids.device)
+                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                 # send 0.0 if we finished, 1.0 otherwise
                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                 # did all peers finish? the reduced sum will be 0.0 then
@@ -952,18 +840,14 @@ class NewGenerationMixin(GenerationMixin):
                     scores += (next_token_scores,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.attentions,)
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)

                 if output_hidden_states:
                     decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
+                        (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
                     )

             # sample
@@ -973,12 +857,8 @@ class NewGenerationMixin(GenerationMixin):
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
-                    raise ValueError(
-                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
-                    )
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                    1 - unfinished_sequences
-                )
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
             yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -988,9 +868,7 @@ class NewGenerationMixin(GenerationMixin):

             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    (sum(next_tokens != i for i in eos_token_id)).long()
-                )
+                unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

             # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
@@ -1007,22 +885,17 @@ def init_stream_support():


 if __name__ == "__main__":
-    from transformers import PreTrainedModel
-    from transformers import AutoTokenizer, AutoModelForCausalLM
+    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

     PreTrainedModel.generate = NewGenerationMixin.generate
     PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
-    model = AutoModelForCausalLM.from_pretrained(
-        "bigscience/bloom-560m", torch_dtype=torch.float16
-    )
+    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
     model = model.to("cuda:0")
     model = model.eval()
     prompt_text = "hello? \n"
-    input_ids = tokenizer(
-        prompt_text, return_tensors="pt", add_special_tokens=False
-    ).input_ids
+    input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids
     input_ids = input_ids.to("cuda:0")

     with torch.no_grad():
@@ -1,16 +1,18 @@
 import os
 import random
 import sys
-import numpy as np

+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.data
 import torchaudio
-from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
 from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
+from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load

 torch.set_num_threads(1)


 def key_samples_by_col(samples, col):
     """Returns a dictionary of samples keyed by language."""
     samples_by_col = {}
@@ -50,7 +52,7 @@ def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate,

 def load_audio(audiopath, sampling_rate):
     # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
-    if audiopath[-4:] == '.mp3':
+    if audiopath[-4:] == ".mp3":
         # it uses torchaudio with sox backend to load mp3
         audio, lsr = torchaudio_sox_load(audiopath)
     else:
@@ -72,6 +74,7 @@ def load_audio(audiopath, sampling_rate):
     audio.clip_(-1, 1)
     return audio


 class XTTSDataset(torch.utils.data.Dataset):
     def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False):
         self.config = config
@@ -108,9 +111,11 @@ class XTTSDataset(torch.utils.data.Dataset):
                 except:
                     pass
                 # Basically, this audio file is nonexistent or too long to be supported by the dataset.
-                if wav is None or \
-                    (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
-                    (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+                if (
+                    wav is None
+                    or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
+                    or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
+                ):
                     continue
                 new_samples.append(sample)
             self.samples = new_samples
@@ -125,9 +130,9 @@ class XTTSDataset(torch.utils.data.Dataset):
         return tokens

     def load_item(self, sample):
-        text = str(sample['text'])
+        text = str(sample["text"])
         tseq = self.get_text(text, sample["language"])
-        audiopath = sample['audio_file']
+        audiopath = sample["audio_file"]
         wav = load_audio(audiopath, self.sample_rate)
         if text is None or len(text.strip()) == 0:
             raise ValueError
@@ -136,7 +141,9 @@ class XTTSDataset(torch.utils.data.Dataset):
             raise ValueError

         # get a slice from GT to condition the model
-        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)
+        cond, cond_len, cond_idxs = get_prompt_slice(
+            audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+        )

         return tseq, audiopath, wav, cond, cond_len, cond_idxs

@@ -170,26 +177,30 @@ class XTTSDataset(torch.utils.data.Dataset):
             return self[1]

         # check if the audio and text size limits and if it out of the limits, added it failed_samples
-        if wav is None or \
-            (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
-            (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
+        if (
+            wav is None
+            or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
+            or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len)
+        ):
             # Basically, this audio file is nonexistent or too long to be supported by the dataset.
             # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
             if self.debug_failures and wav is not None and tseq is not None:
-                print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
+                print(
+                    f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}"
+                )
             self.failed_samples.add(sample_id)
             return self[1]

         res = {
             # 'real_text': text,
-            'text': tseq,
-            'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
-            'wav': wav,
-            'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
-            'filenames': audiopath,
-            'conditioning': cond.unsqueeze(1),
-            'cond_lens': torch.tensor(cond_len, dtype=torch.long),
-            'cond_idxs': torch.tensor(cond_idxs),
+            "text": tseq,
+            "text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long),
+            "wav": wav,
+            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
+            "filenames": audiopath,
+            "conditioning": cond.unsqueeze(1),
+            "cond_lens": torch.tensor(cond_len, dtype=torch.long),
+            "cond_idxs": torch.tensor(cond_idxs),
         }
         return res

@@ -223,7 +234,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         for i in range(B):
             text = batch["text"][i]
             text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
-            wav = batch['wav'][i]
+            wav = batch["wav"][i]
             wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)

         batch["wav"] = wav_padded
@@ -1,37 +1,30 @@
 import os
+import sys
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-import torchaudio
 import torch.nn as nn
+import torchaudio
+from coqpit import Coqpit
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
-import sys

-
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
-from TTS.tts.configs.xtts_config import XttsConfig
-
-from TTS.tts.models.base_tts import BaseTTS
-from coqpit import Coqpit
-
-from TTS.tts.configs.tortoise_config import TortoiseConfig
-from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
-
-from TTS.tts.datasets.dataset import TTSDataset
-
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler

+from TTS.tts.configs.tortoise_config import TortoiseConfig
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.datasets.dataset import TTSDataset
+from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
+from TTS.tts.layers.xtts.dvae import DiscreteVAE
+from TTS.tts.layers.xtts.gpt import GPT
+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
 from TTS.utils.io import load_fsspec

-from TTS.tts.layers.xtts.dvae import DiscreteVAE
-
-from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
-

 @dataclass
 class GPTTrainerConfig(XttsConfig):
@@ -42,6 +35,7 @@ class GPTTrainerConfig(XttsConfig):
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
     test_sentences: List[dict] = field(default_factory=lambda: [])


 @dataclass
 class XttsAudioConfig(XttsAudioConfig):
     dvae_sample_rate: int = 22050
@@ -68,14 +62,15 @@ class GPTArgs(XttsArgs):
 def callback_clearml_load_save(operation_type, model_info):
     # return None means skip the file upload/log, returning model_info will continue with the log/upload
     # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size
-    assert operation_type in ('load', 'save')
+    assert operation_type in ("load", "save")
     # print(operation_type, model_info.__dict__)

-    if "similarities.pth" in model_info.__dict__['local_model_path']:
+    if "similarities.pth" in model_info.__dict__["local_model_path"]:
         return None

     return model_info


 class GPTTrainer(BaseTTS):
     def __init__(self, config: Coqpit):
         """
@@ -89,18 +84,17 @@ class GPTTrainer(BaseTTS):
         self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
         # init gpt encoder and hifigan decoder
         self.xtts.init_models()
-        # set mel stats
-        if self.args.mel_norm_file:
-            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)

         if self.args.xtts_checkpoint:
             self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

+        # set mel stats
+        if self.args.mel_norm_file:
+            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)
+
         # load GPT if available
         if self.args.gpt_checkpoint:
-            gpt_checkpoint = torch.load(
-                self.args.gpt_checkpoint, map_location=torch.device("cpu")
-            )
+            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
             # deal with coqui Trainer exported model
             if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                 print("Coqui Trainer checkpoint detected! Converting it!")
@@ -115,8 +109,13 @@ class GPTTrainer(BaseTTS):
                 del gpt_checkpoint[key]

             # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
-            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
-                num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+            if (
+                "text_embedding.weight" in gpt_checkpoint
+                and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
+            ):
+                num_new_tokens = (
+                    self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+                )
                 print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                 # add new tokens to a linear layer (text_head)
@@ -156,7 +155,7 @@ class GPTTrainer(BaseTTS):
             mel_fmin=0,
             mel_fmax=8000,
             n_mel_channels=80,
-            mel_norm_file=self.args.mel_norm_file
+            mel_norm_file=self.args.mel_norm_file,
         )

         # Load DVAE
@@ -175,17 +174,18 @@ class GPTTrainer(BaseTTS):

         self.dvae.eval()
         if self.args.dvae_checkpoint:
-            dvae_checkpoint = torch.load(
-                self.args.dvae_checkpoint, map_location=torch.device("cpu")
-            )
+            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
             self.dvae.load_state_dict(dvae_checkpoint, strict=False)
             print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
         else:
-            raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!")
+            raise RuntimeError(
+                "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
+            )

         # Mel spectrogram extractor for DVAE
-        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
+        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
+            mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
+        )

     @property
     def device(self):
@@ -203,7 +203,9 @@ class GPTTrainer(BaseTTS):
             cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
             cond_idxs: cond start and end indexs, (b, 2)
         """
-        losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
+        losses = self.xtts.gpt(
+            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
+        )
         return losses

     @torch.no_grad()
@@ -215,7 +217,9 @@ class GPTTrainer(BaseTTS):
         test_audios = {}
         print(" | > Synthesizing test sentences.")
         for idx, s_info in enumerate(self.config.test_sentences):
-            wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
+            wav = self.xtts.synthesize(
+                s_info["text"], self.config, s_info["speaker_wav"], s_info["language"], gpt_cond_len=3
+            )["wav"]
             test_audios["{}-audio".format(idx)] = wav

         # delete inference layers
@@ -300,6 +304,7 @@ class GPTTrainer(BaseTTS):
         # ignore similarities.pth on clearml save/upload
         if self.config.dashboard_logger.lower() == "clearml":
            from clearml.binding.frameworks import WeightsFileHandler

            WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

     @torch.no_grad()
@@ -367,16 +372,23 @@ class GPTTrainer(BaseTTS):
         return loader

     def get_optimizer(self) -> List:
-        """Initiate and return the optimizer based on the config parameters.
-        """
+        """Initiate and return the optimizer based on the config parameters."""
         # ToDo: deal with multi GPU training
         if self.config.optimizer_wd_only_on_weights:
             # parameters to only GPT model
             net = self.xtts.gpt

             # normalizations
-            norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
-                            nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
+            norm_modules = (
+                nn.BatchNorm2d,
+                nn.InstanceNorm2d,
+                nn.BatchNorm1d,
+                nn.InstanceNorm1d,
+                nn.BatchNorm3d,
+                nn.InstanceNorm3d,
+                nn.GroupNorm,
+                nn.LayerNorm,
+            )
             # nn.Embedding
             emb_modules = (nn.Embedding, nn.EmbeddingBag)

@@ -390,7 +402,7 @@ class GPTTrainer(BaseTTS):
                     v.is_norm = isinstance(m, norm_modules)
                     v.is_emb = isinstance(m, emb_modules)

-                    fpn = '%s.%s' % (mn, k) if mn else k  # full param name
+                    fpn = "%s.%s" % (mn, k) if mn else k  # full param name
                     all_param_names.add(fpn)
                     param_map[fpn] = v
                     if v.is_bias or v.is_norm or v.is_emb:
@@ -402,8 +414,8 @@ class GPTTrainer(BaseTTS):
             params_weights = [param_map[k] for k in params_names_weights]

             groups = [
-                { 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]},
-                { 'params': params_notweights, 'weight_decay': 0}
+                {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
+                {"params": params_notweights, "weight_decay": 0},
             ]
             # torch.optim.AdamW
             opt = get_optimizer(
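These hunks only restyle get_optimizer, but they are a good place to spell out the pattern it implements: parameters flagged as bias, normalization, or embedding go into a zero-weight-decay group, everything else keeps the configured decay. A minimal generic sketch of that idea (hypothetical helper and hyperparameters, not the trainer's exact code path):

import torch
from torch import nn


def build_adamw(model: nn.Module, lr=5e-6, weight_decay=1e-2):
    """Sketch: AdamW with weight decay applied only to true weight matrices."""
    norm_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.InstanceNorm1d,
                    nn.InstanceNorm2d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
    emb_modules = (nn.Embedding, nn.EmbeddingBag)
    decay, no_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            # biases, norm parameters, and embedding tables are excluded from decay
            if name.endswith("bias") or isinstance(module, norm_modules + emb_modules):
                no_decay.append(param)
            else:
                decay.append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.96), eps=1e-8)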
@@ -461,4 +473,3 @@ class GPTTrainer(BaseTTS):
                 Defaults to None.
         """
         return GPTTrainer(config)
-
@@ -11,15 +11,16 @@ from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to
 from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
 from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
 from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
-from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
+from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec

 init_stream_support()


 def load_audio(audiopath, sr=22050):
     """
     Load an audio file from disk and resample it to the specified sampling rate.
@@ -332,7 +333,6 @@ class Xtts(BaseTTS):
             stop_audio_token=self.args.gpt_stop_audio_token,
         )
-

         if self.args.use_hifigan:
             self.hifigan_decoder = HifiDecoder(
                 input_sample_rate=self.args.input_sample_rate,
@@ -414,14 +414,13 @@ class Xtts(BaseTTS):
         return diffusion_latent

     @torch.inference_mode()
-    def get_speaker_embedding(
-        self,
-        audio_path
-    ):
+    def get_speaker_embedding(self, audio_path):
         audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
-        speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
-            audio.to(self.device), l2_norm=True
-        ).unsqueeze(-1).to(self.device)
+        speaker_embedding = (
+            self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
+            .unsqueeze(-1)
+            .to(self.device)
+        )
         return speaker_embedding

     def get_conditioning_latents(
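The method's behaviour is unchanged: it loads the reference clip at the speaker encoder's sample rate, computes an L2-normalised embedding, and returns it on the model's device. A hedged usage sketch (assumes an Xtts instance named model loaded with use_hifigan=True and a local reference.wav):

# Hypothetical call; the result is the HiFi-GAN decoder's speaker conditioning vector.
speaker_embedding = model.get_speaker_embedding("reference.wav")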
@@ -563,11 +562,9 @@ class Xtts(BaseTTS):
             Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
             Sample rate is 24kHz.
         """
-        (
-            gpt_cond_latent,
-            diffusion_conditioning,
-            speaker_embedding
-        ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
+        (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents(
+            audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len
+        )
         return self.inference(
             text,
             language,
@@ -721,7 +718,9 @@ class Xtts(BaseTTS):
         decoder="hifigan",
         **hf_generate_kwargs,
     ):
-        assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
+        assert hasattr(
+            self, "hifigan_decoder"
+        ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

@@ -827,6 +826,15 @@ class Xtts(BaseTTS):
         ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
         ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
         for key in list(checkpoint.keys()):
+            # check if it is from the coqui Trainer if so convert it
+            if key.startswith("xtts."):
+                coqui_trainer_checkpoint = True
+                new_key = key.replace("xtts.", "")
+                checkpoint[new_key] = checkpoint[key]
+                del checkpoint[key]
+                key = new_key
+
+            # remove unused keys
             if key.split(".")[0] in ignore_keys:
                 del checkpoint[key]

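This hunk is the fix the commit title refers to: checkpoints written by the coqui Trainer hold the XTTS model as an attribute named xtts, so every tensor is saved under an "xtts." key prefix that Xtts.load_checkpoint previously did not recognise. Remapping the keys in place makes trainer checkpoints loadable for plain inference. A standalone sketch of the same remapping (hypothetical file names; the in-tree logic above is what actually runs):

import torch

# Hedged sketch: strip the trainer's "xtts." prefix so keys line up with Xtts' own state_dict,
# then load non-strictly because the trainer checkpoint also carries modules Xtts does not own.
checkpoint = torch.load("gpt_trainer_checkpoint.pth", map_location=torch.device("cpu"))
state_dict = checkpoint.get("model", checkpoint)  # coqui Trainer checkpoints usually nest weights under "model"
state_dict = {k.replace("xtts.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=False)  # model is assumed to be an initialised Xtts instance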
@@ -0,0 +1,145 @@
+from trainer import Trainer, TrainerArgs
+
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
+
+# Define here the dataset used
+config_ljspeech = BaseDatasetConfig(
+    formatter="ljspeech",
+    dataset_name="ljspeech",
+    path="/raid/datasets/LJSpeech-1.1_24khz/",
+    meta_file_train="/raid/datasets/LJSpeech-1.1_24khz/metadata.csv",
+    language="en",
+)
+
+DATASETS_CONFIG_LIST = [config_ljspeech]
+
+
+def freeze_layers(trainer):
+    pass
+
+
+def main():
+    # init args and config
+    model_args = GPTArgs(
+        max_conditioning_length=132300,  # 6 secs
+        min_conditioning_length=66150,  # 3 secs
+        debug_loading_failures=False,
+        max_wav_length=255995,  # ~11.6 seconds
+        max_text_length=200,
+        mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
+        dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
+        # tokenizer_file="/raid/datasets/xtts_models/vocab.json", # vocab path of the model that you want to fine-tune
+        # xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",
+        xtts_checkpoint="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/132500_gpt_ema_coqui_tts_with_enhanced_hifigan.pth",  # checkpoint path of the model that you want to fine-tune
+        tokenizer_file="/raid/edresson/dev/Checkpoints/XTTS_evaluation/xtts_style_emb_repetition_fix_gt/tokenizer_merged_5.json",
+        gpt_num_audio_tokens=8194,
+        gpt_start_audio_token=8192,
+        gpt_stop_audio_token=8193,
+    )
+    audio_config = XttsAudioConfig(
+        sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000  # GPT SR
+    )
+    config = GPTTrainerConfig(
+        output_path=OUT_PATH,
+        model_args=model_args,
+        run_name=RUN_NAME,
+        project_name=PROJECT_NAME,
+        run_description="""
+            GPT XTTS training
+            """,
+        dashboard_logger=DASHBOARD_LOGGER,
+        logger_uri=LOGGER_URI,
+        audio=audio_config,
+        batch_size=BATCH_SIZE,
+        batch_group_size=48,
+        eval_batch_size=BATCH_SIZE,
+        num_loader_workers=8,
+        eval_split_max_size=256,
+        print_step=50,
+        plot_step=100,
+        log_model_step=1000,
+        save_step=10000,
+        save_n_checkpoints=1,
+        save_checkpoints=True,
+        # target_loss="loss",
+        print_eval=False,
+        # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
+        optimizer="AdamW",
+        optimizer_wd_only_on_weights=True,  # for multi-gpu training turn it off
+        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
+        lr=5e-06,  # learning rate
+        lr_scheduler="MultiStepLR",
+        # it was adjusted accordly for the new step scheme
+        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
+        test_sentences=[
+            {
+                "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+                "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
+                "language": "en",
+            },
+            {
+                "text": "This cake is great. It's so delicious and moist.",
+                "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
+                "language": "en",
+            },
+            {
+                "text": "Levei muito tempo para desenvolver uma voz e agora que a tenho não vou ficar calado .",
+                "speaker_wav": "/raid/edresson/dev/ref-ljspeech.wav",
+                "language": "pt",
+            },
+        ],
+    )
+
+    # init the model from config
+    model = GPTTrainer.init_from_config(config)
+
+    # load training samples
+    train_samples, eval_samples = load_tts_samples(
+        DATASETS_CONFIG_LIST,
+        eval_split=True,
+        eval_split_max_size=config.eval_split_max_size,
+        eval_split_size=config.eval_split_size,
+    )
+
+    # init the trainer and 🚀
+    trainer = Trainer(
+        TrainerArgs(
+            restore_path=RESTORE_PATH,
+            skip_train_epoch=SKIP_TRAIN_EPOCH,
+            start_with_eval=START_WITH_EVAL,
+            grad_accum_steps=GRAD_ACUMM_STEPS,
+        ),
+        config,
+        output_path=OUT_PATH,
+        model=model,
+        train_samples=train_samples,
+        eval_samples=eval_samples,
+        callbacks={"on_epoch_start": freeze_layers},
+    )
+    trainer.fit()
+
+
+if __name__ == "__main__":
+    RUN_NAME = "GPT_XTTS_LJSpeech_fixed"
+    PROJECT_NAME = "XTTS_trainer"
+    OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_v1_FT/"
+    # DASHBOARD_LOGGER = "clearml"
+    # LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_v1/"
+    DASHBOARD_LOGGER = "tensorboard"
+    LOGGER_URI = None
+    RESTORE_PATH = None
+    SKIP_TRAIN_EPOCH = False
+    START_WITH_EVAL = True
+    BATCH_SIZE = 3
+    GRAD_ACUMM_STEPS = 28 * 3
+
+    # debug
+    # DASHBOARD_LOGGER = "tensorboard"
+    # LOGGER_URI = None
+    # RESTORE_PATH = None
+    # BATCH_SIZE = 2
+    # GRAD_ACUMM_STEPS = 1
+
+    main()
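A note on the numbers in this new LJSpeech recipe: with BATCH_SIZE = 3 and GRAD_ACUMM_STEPS = 28 * 3 = 84, each optimizer step accumulates 3 x 84 = 252 samples, which is how the config keeps a large effective batch on a single GPU; the /raid/... paths are the author's cluster paths and have to be replaced locally. Once training has produced a checkpoint, the key-remapping fix above lets it be used for inference directly. A hedged sketch of that round trip (placeholder paths; the exact load_checkpoint arguments should be checked against the installed version):

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Sketch: run inference from a checkpoint produced by this recipe.
config = XttsConfig()
config.load_json("/path/to/run/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_path="/path/to/run/checkpoint.pth", eval=True)
wav = model.synthesize(
    "It took me quite a long time to develop a voice.", config, speaker_wav="reference.wav", language="en", gpt_cond_len=3
)["wav"]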
@@ -2,9 +2,7 @@ from trainer import Trainer, TrainerArgs

 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig
-
-
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig

 config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
     formatter="coqui",
@@ -252,11 +250,16 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(

 # DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
-DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
+DATASETS_CONFIG_LIST = [
+    config_coqui_MLS_metadata_test_with_previous_audio_key_de,
+    config_coqui_mls_italian_metadata_with_previous_audio_key_it,
+]


 def freeze_layers(trainer):
     pass


 def main():
     # init args and config
     model_args = GPTArgs(
@@ -274,10 +277,7 @@ def main():
         gpt_stop_audio_token=8193,
     )
     audio_config = XttsAudioConfig(
-        sample_rate=22050,  # GPT SR
-        dvae_sample_rate=22050,
-        diffusion_sample_rate=24000,
-        output_sample_rate=24000
+        sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000  # GPT SR
     )
     config = GPTTrainerConfig(
         output_path=OUT_PATH,
@@ -303,20 +303,26 @@ def main():
         save_checkpoints=True,
         # target_loss="loss",
         print_eval=False,
-        # Optimizer values like tortoise. However, they used pytorch implementation with modifications to not apply WD to non-weight parameters. We are using default Pytorch
+        # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
         optimizer="AdamW",
-        optimizer_wd_only_on_weights=True,
-        optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2},
+        optimizer_wd_only_on_weights=True,  # for multi-gpu training turn it off
+        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
         lr=5e-06,  # learning rate
-        # lr=1e-4, # learning rate
-        # ToDo: implement 500 step warmup like tortoise and EMA weights replaces LR decay with rate: .999
         lr_scheduler="MultiStepLR",
         # it was adjusted accordly for the new step scheme
         lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
         test_sentences=[
-            {"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
-            {"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
-        ]
+            {
+                "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+                "speaker_wav": "/raid/edresson/dev/ref.wav",
+                "language": "en",
+            },
+            {
+                "text": "This cake is great. It's so delicious and moist.",
+                "speaker_wav": "/raid/edresson/dev/ref.wav",
+                "language": "en",
+            },
+        ],
     )

     # init the model from config
@@ -332,13 +338,18 @@ def main():

     # init the trainer and 🚀
     trainer = Trainer(
-        TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS),
+        TrainerArgs(
+            restore_path=RESTORE_PATH,
+            skip_train_epoch=SKIP_TRAIN_EPOCH,
+            start_with_eval=START_WITH_EVAL,
+            grad_accum_steps=GRAD_ACUMM_STEPS,
+        ),
         config,
         output_path=OUT_PATH,
         model=model,
         train_samples=train_samples,
         eval_samples=eval_samples,
-        callbacks={"on_epoch_start": freeze_layers}
+        callbacks={"on_epoch_start": freeze_layers},
     )
     trainer.fit()

@@ -362,6 +373,4 @@ if __name__ == "__main__":
     BATCH_SIZE = 2
     GRAD_ACUMM_STEPS = 1
-
-

     main()
@@ -99,6 +99,7 @@ def test_xtts_streaming():
     """Testing the new inference_stream method"""
     from TTS.tts.configs.xtts_config import XttsConfig
     from TTS.tts.models.xtts import Xtts
+
     speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
     model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
     config = XttsConfig()
@@ -115,7 +116,7 @@ def test_xtts_streaming():
         "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
         "en",
         gpt_cond_latent,
-        speaker_embedding
+        speaker_embedding,
     )
     wav_chuncks = []
     for i, chunk in enumerate(chunks):
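The hunk ends mid-loop; the test goes on to collect the streamed chunks. For completeness, a hedged sketch of the usual consumption pattern (not the test's verbatim continuation):

import torch

# inference_stream yields 24 kHz waveform chunks as they are decoded; collect and join them.
wav_chuncks = []
for i, chunk in enumerate(chunks):
    wav_chuncks.append(chunk)
wav = torch.cat(wav_chuncks, dim=0)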