mirror of https://github.com/coqui-ai/TTS.git
Clean up old code
This commit is contained in:
parent
66e3f5388e
commit
a543d71352
|
@ -13,9 +13,8 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
|
|||
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
|
||||
|
||||
Args:
|
||||
<<<<<<< HEAD
|
||||
items (List[List]):
|
||||
A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
|
||||
A list of samples. Each sample is a dict containing the keys "text", "audio_file", and "speaker_name".
|
||||
|
||||
eval_split_max_size (int):
|
||||
Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
|
||||
|
@ -23,9 +22,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
|
|||
eval_split_size (float):
|
||||
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
|
||||
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
|
||||
=======
|
||||
items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
|
||||
>>>>>>> Fix docstring
|
||||
"""
|
||||
speakers = [item["speaker_name"] for item in items]
|
||||
is_multi_speaker = len(set(speakers)) > 1
|
||||
|
@ -117,7 +113,9 @@ def load_tts_samples(
|
|||
if eval_split:
|
||||
if meta_file_val:
|
||||
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
|
||||
meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
|
||||
meta_data_eval = [
|
||||
{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
|
||||
]
|
||||
else:
|
||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# pylint: disable=W0223
|
||||
class GradientReversalFunction(torch.autograd.Function):
|
||||
"""Revert gradient without any further input modification.
|
||||
"""Revert gradient without any further input modification.
|
||||
Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
|
||||
|
||||
@staticmethod
|
||||
|
@ -30,17 +32,16 @@ class ReversalClassifier(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
|
||||
super(ReversalClassifier, self).__init__()
|
||||
super().__init__()
|
||||
self._lambda = scale_factor
|
||||
self._clipping = gradient_clipping_bounds
|
||||
self._out_channels = out_channels
|
||||
self._classifier = nn.Sequential(
|
||||
nn.Linear(in_channels, hidden_channels),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels, out_channels)
|
||||
nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
|
||||
)
|
||||
self.test = nn.Linear(in_channels, hidden_channels)
|
||||
def forward(self, x, labels, x_mask=None):
|
||||
|
||||
def forward(self, x, labels, x_mask=None):
|
||||
x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
|
||||
x = self._classifier(x)
|
||||
loss = self.loss(labels, x, x_mask)
|
||||
|
@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module):
|
|||
ml = torch.max(x_mask)
|
||||
input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
|
||||
|
||||
target = labels.repeat(ml.int().item(), 1).transpose(0,1)
|
||||
target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
|
||||
target[~input_mask] = ignore_index
|
||||
|
||||
return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index)
|
||||
return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
|
||||
|
|
|
@ -13,9 +13,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
|||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.data import get_length_balancer_weights
|
||||
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
|
||||
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
|
|
|
@ -223,8 +223,7 @@ class Tacotron(BaseTacotron):
|
|||
encoder_outputs = self.encoder(inputs)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
<<<<<<< HEAD
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
if aux_input["style_text"] is not None:
|
||||
style_text_embedding = self.embedding(aux_input["style_text"])
|
||||
|
@ -232,24 +231,21 @@ class Tacotron(BaseTacotron):
|
|||
encoder_outputs.device
|
||||
) # pylint: disable=not-callable
|
||||
reference_mel_length = (
|
||||
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
|
||||
if aux_input["style_mel"] is not None
|
||||
torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
|
||||
if aux_input["style_feature"] is not None
|
||||
else None
|
||||
) # pylint: disable=not-callable
|
||||
# B x capacitron_VAE_embedding_dim
|
||||
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
|
||||
encoder_outputs,
|
||||
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
|
||||
if aux_input["style_mel"] is not None
|
||||
reference_mel_info=[aux_input["style_feature"], reference_mel_length]
|
||||
if aux_input["style_feature"] is not None
|
||||
else None,
|
||||
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
|
||||
speaker_embedding=aux_input["d_vectors"]
|
||||
if self.capacitron_vae.capacitron_use_speaker_embedding
|
||||
else None,
|
||||
)
|
||||
=======
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
|
||||
>>>>>>> 3a524b05... Add prosody encoder params on config
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vector_file:
|
||||
# B x 1 x speaker_embed_dim
|
||||
|
|
|
@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
|
|||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
||||
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
|
@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
|
|||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
||||
|
||||
##############################
|
||||
# IO / Feature extraction
|
||||
##############################
|
||||
|
@ -541,8 +540,7 @@ class VitsArgs(Coqpit):
|
|||
external_emotions_embs_file: str = None
|
||||
emotion_embedding_dim: int = 0
|
||||
num_emotions: int = 0
|
||||
emotion_just_encoder: bool = False
|
||||
|
||||
|
||||
# prosody encoder
|
||||
use_prosody_encoder: bool = False
|
||||
prosody_embedding_dim: int = 0
|
||||
|
@ -658,7 +656,7 @@ class Vits(BaseTTS):
|
|||
|
||||
dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
|
||||
|
||||
if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
|
||||
dp_cond_embedding_dim += self.args.emotion_embedding_dim
|
||||
|
||||
if self.args.use_prosody_encoder:
|
||||
|
@ -872,12 +870,6 @@ class Vits(BaseTTS):
|
|||
if self.num_emotions > 0:
|
||||
print(" > initialization of emotion-embedding layers.")
|
||||
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
|
||||
if not self.args.emotion_just_encoder:
|
||||
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
||||
|
||||
if self.args.use_external_emotions_embeddings:
|
||||
if not self.args.emotion_just_encoder:
|
||||
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
|
||||
|
@ -1076,13 +1068,6 @@ class Vits(BaseTTS):
|
|||
if self.args.use_emotion_embedding and eid is not None and eg is None:
|
||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# concat the emotion embedding and speaker embedding
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
|
||||
if g is None:
|
||||
g = eg
|
||||
else:
|
||||
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
|
||||
|
||||
# language embedding
|
||||
lang_emb = None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
|
@ -1097,15 +1082,15 @@ class Vits(BaseTTS):
|
|||
if self.args.use_prosody_encoder:
|
||||
pros_emb = self.prosody_encoder(z).transpose(1, 2)
|
||||
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
|
||||
# print("Encoder input", x.shape)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
|
||||
# print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
|
||||
|
||||
# flow layers
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
# print("Y mask:", y_mask.shape)
|
||||
|
||||
# duration predictor
|
||||
g_dp = g if self.args.condition_dp_on_speaker else None
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||
if g_dp is None:
|
||||
g_dp = eg
|
||||
else:
|
||||
|
@ -1228,13 +1213,6 @@ class Vits(BaseTTS):
|
|||
if self.args.use_emotion_embedding and eid is not None and eg is None:
|
||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# concat the emotion embedding and speaker embedding
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
|
||||
if g is None:
|
||||
g = eg
|
||||
else:
|
||||
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
|
||||
|
||||
# language embedding
|
||||
lang_emb = None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
|
@ -1252,7 +1230,7 @@ class Vits(BaseTTS):
|
|||
|
||||
# duration predictor
|
||||
g_dp = g if self.args.condition_dp_on_speaker else None
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||
if g_dp is None:
|
||||
g_dp = eg
|
||||
else:
|
||||
|
@ -1274,9 +1252,7 @@ class Vits(BaseTTS):
|
|||
lang_emb=lang_emb,
|
||||
)
|
||||
else:
|
||||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g_dp, lang_emb=lang_emb
|
||||
)
|
||||
logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
|
||||
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
|
@ -1297,7 +1273,6 @@ class Vits(BaseTTS):
|
|||
|
||||
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
|
||||
|
||||
<<<<<<< HEAD
|
||||
outputs = {
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
|
@ -1308,15 +1283,14 @@ class Vits(BaseTTS):
|
|||
"logs_p": logs_p,
|
||||
"y_mask": y_mask,
|
||||
}
|
||||
=======
|
||||
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
|
||||
>>>>>>> 3a524b05... Add prosody encoder params on config
|
||||
return outputs
|
||||
|
||||
def compute_style_feature(self, style_wav_path):
|
||||
style_wav, sr = torchaudio.load(style_wav_path)
|
||||
if sr != self.config.audio.sample_rate:
|
||||
raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
|
||||
raise RuntimeError(
|
||||
" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
|
||||
)
|
||||
y = wav_to_spec(
|
||||
style_wav,
|
||||
self.config.audio.fft_size,
|
||||
|
@ -1491,7 +1465,7 @@ class Vits(BaseTTS):
|
|||
or self.args.use_emotion_encoder_as_loss,
|
||||
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
|
||||
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
|
||||
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
|
||||
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
|
||||
)
|
||||
|
||||
return self.model_outputs_cache, loss_dict
|
||||
|
@ -1659,7 +1633,11 @@ class Vits(BaseTTS):
|
|||
emotion_ids = None
|
||||
|
||||
# get numerical speaker ids from speaker names
|
||||
if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
|
||||
if (
|
||||
self.speaker_manager is not None
|
||||
and self.speaker_manager.ids
|
||||
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
|
||||
):
|
||||
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
|
||||
|
||||
if speaker_ids is not None:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import json
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, List
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import get_from_config_or_model_args_with_default
|
||||
|
|
|
@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
|
|||
SpeakerEncoder: Speaker encoder object.
|
||||
"""
|
||||
speaker_manager = None
|
||||
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
|
||||
if get_from_config_or_model_args_with_default(
|
||||
config, "use_speaker_embedding", False
|
||||
) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
|
||||
if samples:
|
||||
speaker_manager = SpeakerManager(data_items=samples)
|
||||
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
|
||||
|
|
|
@ -181,8 +181,8 @@ def synthesis(
|
|||
style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
|
||||
style_feature = style_feature.transpose(1, 2) # [1, time, depth]
|
||||
|
||||
if hasattr(model, 'compute_style_feature'):
|
||||
style_feature = model.compute_style_feature(style_wav)
|
||||
if hasattr(model, 'compute_style_feature') and style_wav is not None:
|
||||
style_feature = model.compute_style_feature(style_wav)
|
||||
|
||||
# convert text to sequence of token IDs
|
||||
text_inputs = np.asarray(
|
||||
|
|
Loading…
Reference in New Issue