Clean up old code

Edresson Casanova 2022-05-16 13:09:12 +00:00
parent 66e3f5388e
commit a543d71352
8 changed files with 45 additions and 70 deletions

View File

@@ -13,9 +13,8 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
"""Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
Args:
<<<<<<< HEAD
items (List[List]):
A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
A list of samples. Each sample is a dict containing the keys "text", "audio_file", and "speaker_name".
eval_split_max_size (int):
Maximum number of samples to be used for evaluation in the proportional split. Defaults to None (disabled).
@@ -23,9 +22,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
eval_split_size (float):
If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
=======
items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
>>>>>>> Fix docstring
"""
speakers = [item["speaker_name"] for item in items]
is_multi_speaker = len(set(speakers)) > 1
@@ -117,7 +113,9 @@ def load_tts_samples(
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
meta_data_eval = [
{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval
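For orientation, a minimal sketch of the sample format the resolved docstring describes and of the (eval, train) return order used in load_tts_samples above. The file paths and speaker names are hypothetical, and the import location of split_dataset is assumed, not stated in this diff:

from TTS.tts.datasets import split_dataset  # assumed import path for the function edited above

# Each sample is a dict with the keys named in the updated docstring.
items = [
    {"text": f"sentence {i}", "audio_file": f"/data/wavs/{i}.wav", "speaker_name": f"spk_{i % 2}"}
    for i in range(10)
]
# eval_split_size <= 1.0 is a proportion of the dataset, > 1 an absolute sample count;
# eval_split_max_size=None leaves the proportional split uncapped.
eval_samples, train_samples = split_dataset(items, eval_split_max_size=None, eval_split_size=0.2)
# -> 2 eval samples, 8 train samples, with speaker distribution taken into account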

View File

@@ -1,8 +1,10 @@
import torch
from torch import nn
# pylint: disable=W0223
class GradientReversalFunction(torch.autograd.Function):
"""Revert gradient without any further input modification.
"""Revert gradient without any further input modification.
Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
@staticmethod
@@ -30,17 +32,16 @@ class ReversalClassifier(nn.Module):
"""
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
super(ReversalClassifier, self).__init__()
super().__init__()
self._lambda = scale_factor
self._clipping = gradient_clipping_bounds
self._out_channels = out_channels
self._classifier = nn.Sequential(
nn.Linear(in_channels, hidden_channels),
nn.ReLU(),
nn.Linear(hidden_channels, out_channels)
nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
)
self.test = nn.Linear(in_channels, hidden_channels)
def forward(self, x, labels, x_mask=None):
def forward(self, x, labels, x_mask=None):
x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
x = self._classifier(x)
loss = self.loss(labels, x, x_mask)
@@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module):
ml = torch.max(x_mask)
input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
target = labels.repeat(ml.int().item(), 1).transpose(0,1)
target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
target[~input_mask] = ignore_index
return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index)
return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
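A hedged usage sketch for the classifier above: the constructor and forward signatures are taken from this hunk, x_mask is a per-utterance length tensor as used in the loss, and the two-value return (predictions, loss) is inferred from the vits.py call site further down (_, l_pros_speaker = self.speaker_reversal_classifier(...)); all shapes are hypothetical:

import torch
from TTS.tts.layers.generic.classifier import ReversalClassifier  # import path as it appears in the vits.py hunk below

B, T, C, n_speakers = 4, 50, 192, 10                 # hypothetical batch / feature sizes
clf = ReversalClassifier(in_channels=C, out_channels=n_speakers, hidden_channels=256)
x = torch.randn(B, T, C)                             # frame-level features entering the gradient-reversal layer
labels = torch.randint(0, n_speakers, (B,))          # one speaker id per utterance
lengths = torch.full((B,), T)                        # per-utterance lengths consumed by the masked cross-entropy above
predictions, loss = clf(x, labels, x_mask=lengths)   # return pair inferred from the vits.py call site
loss.backward()                                      # gradients flowing back through x are reversed and clipped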

View File

@@ -13,9 +13,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

View File

@@ -223,8 +223,7 @@ class Tacotron(BaseTacotron):
encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst:
# B x gst_dim
<<<<<<< HEAD
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
if self.capacitron_vae and self.use_capacitron_vae:
if aux_input["style_text"] is not None:
style_text_embedding = self.embedding(aux_input["style_text"])
@@ -232,24 +231,21 @@ class Tacotron(BaseTacotron):
encoder_outputs.device
) # pylint: disable=not-callable
reference_mel_length = (
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
if aux_input["style_mel"] is not None
torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
if aux_input["style_feature"] is not None
else None
) # pylint: disable=not-callable
# B x capacitron_VAE_embedding_dim
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
encoder_outputs,
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
if aux_input["style_mel"] is not None
reference_mel_info=[aux_input["style_feature"], reference_mel_length]
if aux_input["style_feature"] is not None
else None,
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
speaker_embedding=aux_input["d_vectors"]
if self.capacitron_vae.capacitron_use_speaker_embedding
else None,
)
=======
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
>>>>>>> 3a524b05... Add prosody encoder params on config
if self.num_speakers > 1:
if not self.use_d_vector_file:
# B x 1 x speaker_embed_dim

View File

@@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.generic.classifier import ReversalClassifier
##############################
# IO / Feature extraction
##############################
@@ -541,8 +540,7 @@ class VitsArgs(Coqpit):
external_emotions_embs_file: str = None
emotion_embedding_dim: int = 0
num_emotions: int = 0
emotion_just_encoder: bool = False
# prosody encoder
use_prosody_encoder: bool = False
prosody_embedding_dim: int = 0
@@ -658,7 +656,7 @@ class Vits(BaseTTS):
dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
dp_cond_embedding_dim += self.args.emotion_embedding_dim
if self.args.use_prosody_encoder:
@@ -872,12 +870,6 @@ class Vits(BaseTTS):
if self.num_emotions > 0:
print(" > initialization of emotion-embedding layers.")
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
if not self.args.emotion_just_encoder:
self.cond_embedding_dim += self.args.emotion_embedding_dim
if self.args.use_external_emotions_embeddings:
if not self.args.emotion_just_encoder:
self.cond_embedding_dim += self.args.emotion_embedding_dim
def get_aux_input(self, aux_input: Dict):
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -1076,13 +1068,6 @@ class Vits(BaseTTS):
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
if g is None:
g = eg
else:
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
@@ -1097,15 +1082,15 @@ class Vits(BaseTTS):
if self.args.use_prosody_encoder:
pros_emb = self.prosody_encoder(z).transpose(1, 2)
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
# print("Encoder input", x.shape)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
# print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
# flow layers
z_p = self.flow(z, y_mask, g=g)
# print("Y mask:", y_mask.shape)
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dp is None:
g_dp = eg
else:
@@ -1228,13 +1213,6 @@ class Vits(BaseTTS):
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
if g is None:
g = eg
else:
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
@@ -1252,7 +1230,7 @@ class Vits(BaseTTS):
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dp is None:
g_dp = eg
else:
@@ -1274,9 +1252,7 @@ class Vits(BaseTTS):
lang_emb=lang_emb,
)
else:
logw = self.duration_predictor(
x, x_mask, g=g_dp, lang_emb=lang_emb
)
logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
@@ -1297,7 +1273,6 @@ class Vits(BaseTTS):
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
<<<<<<< HEAD
outputs = {
"model_outputs": o,
"alignments": attn.squeeze(1),
@@ -1308,15 +1283,14 @@ class Vits(BaseTTS):
"logs_p": logs_p,
"y_mask": y_mask,
}
=======
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
>>>>>>> 3a524b05... Add prosody encoder params on config
return outputs
def compute_style_feature(self, style_wav_path):
style_wav, sr = torchaudio.load(style_wav_path)
if sr != self.config.audio.sample_rate:
raise RuntimeError(f" [!] Style reference needs to have a sampling rate equal to {self.config.audio.sample_rate} !!")
raise RuntimeError(
f" [!] Style reference needs to have a sampling rate equal to {self.config.audio.sample_rate} !!"
)
y = wav_to_spec(
style_wav,
self.config.audio.fft_size,
@@ -1491,7 +1465,7 @@ class Vits(BaseTTS):
or self.args.use_emotion_encoder_as_loss,
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
)
return self.model_outputs_cache, loss_dict
@@ -1659,7 +1633,11 @@ class Vits(BaseTTS):
emotion_ids = None
# get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
if (
self.speaker_manager is not None
and self.speaker_manager.ids
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
):
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
if speaker_ids is not None:

View File

@@ -1,10 +1,10 @@
import json
import os
import torch
import numpy as np
from typing import Any, List
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default

View File

@@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
SpeakerEncoder: Speaker encoder object.
"""
speaker_manager = None
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
if get_from_config_or_model_args_with_default(
config, "use_speaker_embedding", False
) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
if samples:
speaker_manager = SpeakerManager(data_items=samples)
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
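A minimal sketch tying this to the dict sample format documented in the first file of this commit; the samples are hypothetical, the import path is taken from an earlier hunk (from TTS.tts.utils.speakers import SpeakerManager ...), and it assumes SpeakerManager builds its ids mapping from data_items at construction:

from TTS.tts.utils.speakers import SpeakerManager

samples = [
    {"text": "hello there", "audio_file": "/data/wavs/a.wav", "speaker_name": "spk_1"},
    {"text": "good morning", "audio_file": "/data/wavs/b.wav", "speaker_name": "spk_2"},
]
manager = SpeakerManager(data_items=samples)   # same call as in the hunk above
print(manager.ids)                             # name-to-id mapping later indexed as speaker_manager.ids[speaker_name]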

View File

@@ -181,8 +181,8 @@ def synthesis(
style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
style_feature = style_feature.transpose(1, 2) # [1, time, depth]
if hasattr(model, 'compute_style_feature'):
style_feature = model.compute_style_feature(style_wav)
if hasattr(model, 'compute_style_feature') and style_wav is not None:
style_feature = model.compute_style_feature(style_wav)
# convert text to sequence of token IDs
text_inputs = np.asarray(