Clean up old code

Edresson Casanova 2022-05-16 13:09:12 +00:00
parent dbaa71c944
commit 024e567849
8 changed files with 45 additions and 70 deletions

View File

@@ -13,9 +13,8 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
     """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
     Args:
-<<<<<<< HEAD
         items (List[List]):
-            A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
+            A list of samples. Each sample is a dict containing the keys "text", "audio_file", and "speaker_name".
         eval_split_max_size (int):
             Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
@@ -23,9 +22,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
         eval_split_size (float):
             If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
             If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
-=======
-        items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
->>>>>>> Fix docstring
     """
     speakers = [item["speaker_name"] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
@@ -117,7 +113,9 @@ def load_tts_samples(
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
+                meta_data_eval = [
+                    {**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
+                ]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval
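
For reference, the reconciled docstring corresponds to samples shaped like the sketch below (a minimal illustration, assuming the helper is importable as TTS.tts.datasets.split_dataset; the texts, paths, and speaker names are made up):

from TTS.tts.datasets import split_dataset

# Each sample is a dict with "text", "audio_file" and "speaker_name", as the updated docstring states.
samples = [
    {"text": f"sentence {i}", "audio_file": f"wavs/utt_{i:04d}.wav", "speaker_name": f"spk_{i % 4}"}
    for i in range(400)
]

# With eval_split_size=0.01 about 1% of the samples go to the eval split, optionally capped
# by eval_split_max_size; as in load_tts_samples above, the eval split is returned first.
eval_samples, train_samples = split_dataset(samples, eval_split_max_size=None, eval_split_size=0.01)
print(len(eval_samples), len(train_samples))  # roughly 4 / 396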

View File

@@ -1,8 +1,10 @@
 import torch
 from torch import nn
+# pylint: disable=W0223
 class GradientReversalFunction(torch.autograd.Function):
     """Revert gradient without any further input modification.
     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
     @staticmethod
@@ -30,17 +32,16 @@ class ReversalClassifier(nn.Module):
     """
     def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
-        super(ReversalClassifier, self).__init__()
+        super().__init__()
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
-            nn.Linear(in_channels, hidden_channels),
-            nn.ReLU(),
-            nn.Linear(hidden_channels, out_channels)
+            nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
         self.test = nn.Linear(in_channels, hidden_channels)
+
     def forward(self, x, labels, x_mask=None):
         x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)
         loss = self.loss(labels, x, x_mask)
@@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module):
         ml = torch.max(x_mask)
         input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
-        target = labels.repeat(ml.int().item(), 1).transpose(0,1)
+        target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
         target[~input_mask] = ignore_index
-        return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index)
+        return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
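
ReversalClassifier is used adversarially (e.g. as the speaker_reversal_classifier applied to the prosody embedding in the VITS diff further down): the gradient-reversal function is an identity in the forward pass, while the backward pass flips and scales the gradient so the upstream encoder is pushed away from encoding the classified label. A minimal standalone sketch of that idea (illustrative only; the names and the absence of gradient clipping here are not the repo's implementation):

import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass, sign-flipped (scaled) gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale=1.0):
        ctx.scale = scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the gradient flowing back into whatever produced x; no gradient for `scale`.
        return -ctx.scale * grad_output, None


x = torch.randn(4, 8, requires_grad=True)
GradReverse.apply(x, 1.0).sum().backward()
print(x.grad[0, :3])  # all -1.0 instead of +1.0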

View File

@@ -13,9 +13,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
+from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
-from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

View File

@@ -223,8 +223,7 @@ class Tacotron(BaseTacotron):
         encoder_outputs = self.encoder(inputs)
         if self.gst and self.use_gst:
             # B x gst_dim
-<<<<<<< HEAD
-            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
+            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
         if self.capacitron_vae and self.use_capacitron_vae:
             if aux_input["style_text"] is not None:
                 style_text_embedding = self.embedding(aux_input["style_text"])
@@ -232,24 +231,21 @@ class Tacotron(BaseTacotron):
                     encoder_outputs.device
                 )  # pylint: disable=not-callable
             reference_mel_length = (
-                torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
-                if aux_input["style_mel"] is not None
+                torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
+                if aux_input["style_feature"] is not None
                 else None
             )  # pylint: disable=not-callable
             # B x capacitron_VAE_embedding_dim
             encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
                 encoder_outputs,
-                reference_mel_info=[aux_input["style_mel"], reference_mel_length]
-                if aux_input["style_mel"] is not None
+                reference_mel_info=[aux_input["style_feature"], reference_mel_length]
+                if aux_input["style_feature"] is not None
                 else None,
                 text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
                 speaker_embedding=aux_input["d_vectors"]
                 if self.capacitron_vae.capacitron_use_speaker_embedding
                 else None,
             )
-=======
-            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
->>>>>>> 3a524b05... Add prosody encoder params on config
         if self.num_speakers > 1:
             if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim

View File

@@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
+from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
-from TTS.tts.layers.tacotron.gst_layers import GST
-from TTS.tts.layers.generic.classifier import ReversalClassifier
 ##############################
 # IO / Feature extraction
 ##############################
@@ -541,8 +540,7 @@ class VitsArgs(Coqpit):
     external_emotions_embs_file: str = None
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
-    emotion_just_encoder: bool = False
     # prosody encoder
     use_prosody_encoder: bool = False
     prosody_embedding_dim: int = 0
@@ -658,7 +656,7 @@ class Vits(BaseTTS):
         dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
-        if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+        if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
             dp_cond_embedding_dim += self.args.emotion_embedding_dim
         if self.args.use_prosody_encoder:
@@ -872,12 +870,6 @@ class Vits(BaseTTS):
         if self.num_emotions > 0:
             print(" > initialization of emotion-embedding layers.")
             self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim
-        if self.args.use_external_emotions_embeddings:
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim

     def get_aux_input(self, aux_input: Dict):
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -1076,13 +1068,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:
@@ -1097,15 +1082,15 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
-        # print("Encoder input", x.shape)
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
-        # print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
-        # print("Y mask:", y_mask.shape)
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:
@@ -1228,13 +1213,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:
@@ -1252,7 +1230,7 @@ class Vits(BaseTTS):
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:
@@ -1274,9 +1252,7 @@ class Vits(BaseTTS):
                 lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g_dp, lang_emb=lang_emb
-            )
+            logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
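
With emotion_just_encoder gone, both the training and inference paths above condition the duration predictor the same way: start from the speaker conditioning (when condition_dp_on_speaker is set) and append the emotion embedding whenever one is available. A rough sketch of that conditioning logic (the [b, h, 1] shapes follow the comments in the removed code; the concatenation in the else branch is an assumption, since that line falls outside the hunks shown):

import torch

g = torch.randn(2, 256, 1)   # speaker conditioning [b, h1, 1], or None
eg = torch.randn(2, 64, 1)   # emotion embedding [b, h2, 1], or None

g_dp = g  # duration-predictor conditioning starts from the speaker embedding
if eg is not None:
    # assumed: append the emotion embedding along the channel dimension -> [b, h1 + h2, 1]
    g_dp = eg if g_dp is None else torch.cat([g_dp, eg], dim=1)
print(g_dp.shape)  # torch.Size([2, 320, 1])
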
@@ -1297,7 +1273,6 @@ class Vits(BaseTTS):
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
-<<<<<<< HEAD
         outputs = {
             "model_outputs": o,
             "alignments": attn.squeeze(1),
@@ -1308,15 +1283,14 @@
             "logs_p": logs_p,
             "y_mask": y_mask,
         }
-=======
-        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
->>>>>>> 3a524b05... Add prosody encoder params on config
         return outputs

     def compute_style_feature(self, style_wav_path):
         style_wav, sr = torchaudio.load(style_wav_path)
         if sr != self.config.audio.sample_rate:
-            raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
+            raise RuntimeError(
+                " [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
+            )
         y = wav_to_spec(
             style_wav,
             self.config.audio.fft_size,
@@ -1491,7 +1465,7 @@ class Vits(BaseTTS):
             or self.args.use_emotion_encoder_as_loss,
             gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
             syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
-            loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
+            loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
         )
         return self.model_outputs_cache, loss_dict
@@ -1659,7 +1633,11 @@ class Vits(BaseTTS):
             emotion_ids = None
         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
+        if (
+            self.speaker_manager is not None
+            and self.speaker_manager.ids
+            and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
+        ):
             speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
         if speaker_ids is not None:

View File

@@ -1,10 +1,10 @@
 import json
 import os
-import torch
-import numpy as np
 from typing import Any, List

 import fsspec
+import numpy as np
+import torch
 from coqpit import Coqpit

 from TTS.config import get_from_config_or_model_args_with_default

View File

@@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
         SpeakerEncoder: Speaker encoder object.
     """
     speaker_manager = None
-    if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
+    if get_from_config_or_model_args_with_default(
+        config, "use_speaker_embedding", False
+    ) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
         if samples:
             speaker_manager = SpeakerManager(data_items=samples)
         if get_from_config_or_model_args_with_default(config, "speaker_file", None):

View File

@@ -181,8 +181,8 @@ def synthesis(
         style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
         style_feature = style_feature.transpose(1, 2)  # [1, time, depth]
-    if hasattr(model, 'compute_style_feature'):
+    if hasattr(model, 'compute_style_feature') and style_wav is not None:
         style_feature = model.compute_style_feature(style_wav)
     # convert text to sequence of token IDs
     text_inputs = np.asarray(