Clean up old code

Edresson Casanova 2022-05-16 13:09:12 +00:00
parent 3a524b0597
commit dcd0d1f6a1
7 changed files with 47 additions and 53 deletions

View File

@@ -117,7 +117,9 @@ def load_tts_samples(
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
meta_data_eval = [
{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval
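The reformatted comprehension above relies on Python's dict-merge idiom: each sample dict is extended with the dataset-level "language" and "speech_style" keys. A tiny standalone illustration (the sample keys and values below are made up, not taken from the dataset formatters):

# Hypothetical sample dict; real formatters return richer entries.
item = {"audio_file": "clip_0001.wav", "text": "hello world"}
language, speech_style = "en", "neutral"  # dataset-level values (illustrative)
merged = {**item, **{"language": language, "speech_style": speech_style}}
print(merged)
# {'audio_file': 'clip_0001.wav', 'text': 'hello world', 'language': 'en', 'speech_style': 'neutral'}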

View File

@@ -1,8 +1,10 @@
import torch
from torch import nn
# pylint: disable=W0223
class GradientReversalFunction(torch.autograd.Function):
"""Revert gradient without any further input modification.
"""Revert gradient without any further input modification.
Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
@staticmethod
@@ -30,17 +32,16 @@ class ReversalClassifier(nn.Module):
"""
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
super(ReversalClassifier, self).__init__()
super().__init__()
self._lambda = scale_factor
self._clipping = gradient_clipping_bounds
self._out_channels = out_channels
self._classifier = nn.Sequential(
nn.Linear(in_channels, hidden_channels),
nn.ReLU(),
nn.Linear(hidden_channels, out_channels)
nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
)
self.test = nn.Linear(in_channels, hidden_channels)
def forward(self, x, labels, x_mask=None):
def forward(self, x, labels, x_mask=None):
x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
x = self._classifier(x)
loss = self.loss(labels, x, x_mask)
@@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module):
ml = torch.max(x_mask)
input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
target = labels.repeat(ml.int().item(), 1).transpose(0,1)
target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
target[~input_mask] = ignore_index
return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index)
return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
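The hunks above show only fragments of classifier.py, so here is a minimal, self-contained sketch of the gradient-reversal idea they rely on. It is an assumption based on the visible GradientReversalFunction.apply(x, self._lambda, self._clipping) call and the linked Multilingual_Text_to_Speech repository, not the file's exact implementation: the forward pass is the identity, while the backward pass clips the incoming gradient and flips its sign, so upstream layers learn to remove the information (here, speaker identity) that the classifier predicts.

import torch

class GradReverse(torch.autograd.Function):  # sketch; not the repo's GradientReversalFunction
    @staticmethod
    def forward(ctx, x, scale, clip):
        ctx.scale, ctx.clip = scale, clip
        return x.clone()  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clamp(-ctx.clip, ctx.clip)  # bound the adversarial gradient
        return -ctx.scale * grad, None, None           # flip the sign; no grads for scale/clip

x = torch.randn(4, 16, requires_grad=True)
y = GradReverse.apply(x, 1.0, 0.25)
y.sum().backward()
print(x.grad[0, :3])  # every entry is -0.25: ones clamped to 0.25, then negated (scale=1.0)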

View File

@@ -12,9 +12,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

View File

@@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.generic.classifier import ReversalClassifier
##############################
# IO / Feature extraction
##############################
@@ -503,8 +502,7 @@ class VitsArgs(Coqpit):
external_emotions_embs_file: str = None
emotion_embedding_dim: int = 0
num_emotions: int = 0
emotion_just_encoder: bool = False
# prosody encoder
use_prosody_encoder: bool = False
prosody_embedding_dim: int = 0
@@ -615,7 +613,7 @@ class Vits(BaseTTS):
dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
dp_cond_embedding_dim += self.args.emotion_embedding_dim
if self.args.use_prosody_encoder:
@@ -796,12 +794,6 @@ class Vits(BaseTTS):
if self.num_emotions > 0:
print(" > initialization of emotion-embedding layers.")
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
if not self.args.emotion_just_encoder:
self.cond_embedding_dim += self.args.emotion_embedding_dim
if self.args.use_external_emotions_embeddings:
if not self.args.emotion_just_encoder:
self.cond_embedding_dim += self.args.emotion_embedding_dim
def get_aux_input(self, aux_input: Dict):
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -983,13 +975,6 @@ class Vits(BaseTTS):
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
if g is None:
g = eg
else:
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
@@ -1004,15 +989,15 @@ class Vits(BaseTTS):
if self.args.use_prosody_encoder:
pros_emb = self.prosody_encoder(z).transpose(1, 2)
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
# print("Encoder input", x.shape)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
# print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
# flow layers
z_p = self.flow(z, y_mask, g=g)
# print("Y mask:", y_mask.shape)
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dp is None:
g_dp = eg
else:
@@ -1130,13 +1115,6 @@ class Vits(BaseTTS):
if self.args.use_emotion_embedding and eid is not None and eg is None:
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
# concat the emotion embedding and speaker embedding
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
if g is None:
g = eg
else:
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
# language embedding
lang_emb = None
if self.args.use_language_embedding and lid is not None:
@@ -1154,7 +1132,7 @@ class Vits(BaseTTS):
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dp is None:
g_dp = eg
else:
@@ -1176,9 +1154,7 @@ class Vits(BaseTTS):
lang_emb=lang_emb,
)
else:
logw = self.duration_predictor(
x, x_mask, g=g_dp, lang_emb=lang_emb
)
logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
w = torch.exp(logw) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
@@ -1195,13 +1171,23 @@ class Vits(BaseTTS):
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
outputs = {
"model_outputs": o,
"alignments": attn.squeeze(1),
"z": z,
"z_p": z_p,
"m_p": m_p,
"logs_p": logs_p,
"durations": w_ceil,
}
return outputs
def compute_style_feature(self, style_wav_path):
style_wav, sr = torchaudio.load(style_wav_path)
if sr != self.config.audio.sample_rate:
raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
raise RuntimeError(
f" [!] Style reference needs to have a sampling rate equal to {self.config.audio.sample_rate} !!"
)
y = wav_to_spec(
style_wav,
self.config.audio.fft_size,
@@ -1371,7 +1357,7 @@ class Vits(BaseTTS):
or self.args.use_emotion_encoder_as_loss,
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
)
return self.model_outputs_cache, loss_dict
@@ -1539,7 +1525,11 @@ class Vits(BaseTTS):
emotion_ids = None
# get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
if (
self.speaker_manager is not None
and self.speaker_manager.ids
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
):
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
if speaker_ids is not None:
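With the emotion_just_encoder flag removed in the hunks above, the emotion embedding eg is no longer concatenated into the global conditioning g; it only extends the duration-predictor conditioning g_dp. A minimal sketch of that concatenation (batch and channel sizes below are illustrative, not taken from the config):

import torch

g = torch.randn(2, 256, 1)   # speaker conditioning [B, H_spk, 1] (example sizes)
eg = torch.randn(2, 64, 1)   # emotion embedding    [B, H_emo, 1]

g_dp = g  # start from the speaker conditioning, as when condition_dp_on_speaker is set
if eg is not None:
    g_dp = eg if g_dp is None else torch.cat([g_dp, eg], dim=1)  # -> [B, H_spk + H_emo, 1]
print(g_dp.shape)  # torch.Size([2, 320, 1])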

View File

@@ -1,10 +1,10 @@
import json
import os
import torch
import numpy as np
from typing import Any, List
import fsspec
import numpy as np
import torch
from coqpit import Coqpit
from TTS.config import get_from_config_or_model_args_with_default

View File

@@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
SpeakerEncoder: Speaker encoder object.
"""
speaker_manager = None
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
if get_from_config_or_model_args_with_default(
config, "use_speaker_embedding", False
) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
if samples:
speaker_manager = SpeakerManager(data_items=samples)
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
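For context on the helper called in the hunk above: get_from_config_or_model_args_with_default(config, key, default) reads a flag either from the top-level config or from its model args. The sketch below (with a hypothetical name) is an assumption about that behavior based on the name and the call sites, not the actual code in TTS.config:

from types import SimpleNamespace

def get_flag_sketch(config, key, default):
    # Assumed behavior: prefer a top-level config attribute, then config.model_args, else the default.
    value = getattr(config, key, None)
    if value is not None:
        return value
    model_args = getattr(config, "model_args", None)
    return getattr(model_args, key, default) if model_args is not None else default

cfg = SimpleNamespace(model_args=SimpleNamespace(use_prosody_encoder=True))
print(get_flag_sketch(cfg, "use_prosody_encoder", False))  # True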

View File

@@ -168,10 +168,9 @@ def synthesis(
style_feature = style_wav
else:
style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
if hasattr(model, 'compute_style_feature'):
if hasattr(model, "compute_style_feature"):
style_feature = model.compute_style_feature(style_wav)
# convert text to sequence of token IDs
text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id),