Clean up old code

Edresson Casanova 2022-05-16 13:09:12 +00:00
parent 3a524b0597
commit dcd0d1f6a1
7 changed files with 47 additions and 53 deletions

View File

@@ -117,7 +117,9 @@ def load_tts_samples(
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
+                meta_data_eval = [
+                    {**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
+                ]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval
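The reflowed comprehension above only changes line length; functionally it still merges the dataset-level tags into every eval sample. A standalone sketch of that dict-merge pattern, with made-up samples and tag values:

```python
# Made-up eval samples; in load_tts_samples() these come from the dataset formatter.
meta_data_eval = [
    {"audio_file": "a.wav", "text": "hello"},
    {"audio_file": "b.wav", "text": "world"},
]
language, speech_style = "en", "neutral"

# Copy each sample dict and extend it with the dataset-level tags.
meta_data_eval = [
    {**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
]
print(meta_data_eval[0])  # {'audio_file': 'a.wav', 'text': 'hello', 'language': 'en', 'speech_style': 'neutral'}
```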

View File

@@ -1,6 +1,8 @@
 import torch
 from torch import nn

+
+# pylint: disable=W0223
 class GradientReversalFunction(torch.autograd.Function):
     """Revert gradient without any further input modification.
     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
@@ -30,16 +32,15 @@ class ReversalClassifier(nn.Module):
     """

     def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
-        super(ReversalClassifier, self).__init__()
+        super().__init__()
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
-            nn.Linear(in_channels, hidden_channels),
-            nn.ReLU(),
-            nn.Linear(hidden_channels, out_channels)
+            nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
         self.test = nn.Linear(in_channels, hidden_channels)
+
     def forward(self, x, labels, x_mask=None):
         x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)
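For readers new to this layer: GradientReversalFunction is an identity in the forward pass and flips the gradient on the way back (the repository's version also scales and clips it via the arguments passed above), which is what lets ReversalClassifier train adversarially against speaker identity. A self-contained sketch of the core trick, with illustrative names rather than the repository API:

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity forward; negated, scaled gradient backward (no clipping here)."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the gradient so the upstream encoder is pushed to *remove* the
        # information this branch's classifier can exploit.
        return -ctx.scale * grad_output, None


x = torch.ones(3, requires_grad=True)
GradReverse.apply(x, 1.0).sum().backward()
print(x.grad)  # tensor([-1., -1., -1.])
```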

View File

@@ -12,9 +12,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
+from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
-from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
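The move only restores alphabetical import order; get_speech_style_balancer_weights sits alongside the other balancer helpers that produce per-sample weights for a weighted sampler. A rough illustration of that sampling pattern, using a hand-rolled inverse-frequency weighting rather than the library helper:

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical samples; in training these would come from load_tts_samples().
samples = [{"speech_style": s} for s in ["neutral", "neutral", "neutral", "happy"]]

# Inverse-frequency weights so under-represented styles are drawn more often.
counts = Counter(item["speech_style"] for item in samples)
weights = torch.tensor([1.0 / counts[item["speech_style"]] for item in samples], dtype=torch.double)

sampler = WeightedRandomSampler(weights, num_samples=len(weights))
print(list(sampler))  # dataset indices, biased toward the lone "happy" sample
```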

View File

@@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
+from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
-from TTS.tts.layers.tacotron.gst_layers import GST
-from TTS.tts.layers.generic.classifier import ReversalClassifier
-

 ##############################
 # IO / Feature extraction
 ##############################
@@ -503,7 +502,6 @@ class VitsArgs(Coqpit):
     external_emotions_embs_file: str = None
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
-    emotion_just_encoder: bool = False

     # prosody encoder
     use_prosody_encoder: bool = False
@@ -615,7 +613,7 @@ class Vits(BaseTTS):
         dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0

-        if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+        if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
             dp_cond_embedding_dim += self.args.emotion_embedding_dim

         if self.args.use_prosody_encoder:
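With emotion_just_encoder gone, the duration-predictor conditioning width is simply the speaker conditioning width (when condition_dp_on_speaker is set) plus the emotion embedding width whenever emotion embeddings are in use. A tiny worked example with illustrative numbers, not the repository defaults:

```python
cond_embedding_dim = 512       # speaker conditioning width (illustrative)
emotion_embedding_dim = 64     # emotion embedding width (illustrative)
condition_dp_on_speaker = True
use_emotion_embedding = True

dp_cond_embedding_dim = cond_embedding_dim if condition_dp_on_speaker else 0
if use_emotion_embedding:
    dp_cond_embedding_dim += emotion_embedding_dim
print(dp_cond_embedding_dim)  # 576 -> width of the duration predictor's conditioning input
```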
@@ -796,12 +794,6 @@ class Vits(BaseTTS):
         if self.num_emotions > 0:
             print(" > initialization of emotion-embedding layers.")
             self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim
-
-        if self.args.use_external_emotions_embeddings:
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim

     def get_aux_input(self, aux_input: Dict):
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -983,13 +975,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]

-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
-
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:
@@ -1004,15 +989,15 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
-        # print("Encoder input", x.shape)
+
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
-        # print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
+
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
-        # print("Y mask:", y_mask.shape)
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:
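After this cleanup the emotion embedding eg no longer widens the global conditioning g; it only augments the duration-predictor conditioning g_dp by channel-wise concatenation. A sketch with dummy tensors, following the [batch, channels, 1] shape convention from the comments above (dimensions are made up):

```python
import torch

b, spk_dim, emo_dim = 2, 512, 64
g = torch.randn(b, spk_dim, 1)   # speaker conditioning, [b, h1, 1]
eg = torch.randn(b, emo_dim, 1)  # emotion conditioning, [b, h2, 1]

g_dp = g  # assuming condition_dp_on_speaker is enabled
g_dp = eg if g_dp is None else torch.cat([g_dp, eg], dim=1)  # [b, h1 + h2, 1]

print(g.shape, g_dp.shape)  # g stays [2, 512, 1]; g_dp becomes [2, 576, 1]
```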
@@ -1130,13 +1115,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]

-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
-
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:
@@ -1154,7 +1132,7 @@ class Vits(BaseTTS):

         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:
@@ -1176,9 +1154,7 @@ class Vits(BaseTTS):
                 lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g_dp, lang_emb=lang_emb
-            )
+            logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)

         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
@@ -1195,13 +1171,23 @@ class Vits(BaseTTS):
         z = self.flow(z_p, y_mask, g=g, reverse=True)
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)

-        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
+        outputs = {
+            "model_outputs": o,
+            "alignments": attn.squeeze(1),
+            "z": z,
+            "z_p": z_p,
+            "m_p": m_p,
+            "logs_p": logs_p,
+            "durations": w_ceil,
+        }
         return outputs

     def compute_style_feature(self, style_wav_path):
         style_wav, sr = torchaudio.load(style_wav_path)
         if sr != self.config.audio.sample_rate:
-            raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
+            raise RuntimeError(
+                " [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
+            )
         y = wav_to_spec(
             style_wav,
             self.config.audio.fft_size,
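Worth noting (the reformat above keeps it as-is): the RuntimeError message contains {self.config.audio.sample_rate} but the literal has no f-prefix, so the braces are printed verbatim instead of the expected rate. A standalone sketch of the presumably intended behaviour, with made-up sample rates:

```python
expected_sr, sr = 22050, 16000  # illustrative values
try:
    if sr != expected_sr:
        # With the f-prefix the expected rate is interpolated; without it, the
        # committed message prints the "{...}" placeholder literally.
        raise RuntimeError(f" [!] Style reference needs a sampling rate of {expected_sr}, got {sr}.")
except RuntimeError as err:
    print(err)
```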
@@ -1371,7 +1357,7 @@ class Vits(BaseTTS):
                 or self.args.use_emotion_encoder_as_loss,
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
-                loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
+                loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
             )

         return self.model_outputs_cache, loss_dict
@@ -1539,7 +1525,11 @@ class Vits(BaseTTS):
         emotion_ids = None

         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
+        if (
+            self.speaker_manager is not None
+            and self.speaker_manager.ids
+            and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
+        ):
             speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]

         if speaker_ids is not None:
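The reflowed condition only splits the line for readability; the guarded lookup itself is a plain name-to-id mapping applied over the batch. A minimal illustration with invented speaker names and ids:

```python
# Hypothetical name -> id mapping, of the kind a SpeakerManager keeps in `ids`.
speaker_ids_map = {"p225": 0, "p226": 1}
batch = {"speaker_names": ["p226", "p225", "p226"]}

speaker_ids = [speaker_ids_map[name] for name in batch["speaker_names"]]
print(speaker_ids)  # [1, 0, 1]
```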

View File

@@ -1,10 +1,10 @@
 import json
 import os
-import torch
-import numpy as np
 from typing import Any, List

 import fsspec
+import numpy as np
+import torch
 from coqpit import Coqpit

 from TTS.config import get_from_config_or_model_args_with_default

View File

@@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
             SpeakerEncoder: Speaker encoder object.
         """
         speaker_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
+        if get_from_config_or_model_args_with_default(
+            config, "use_speaker_embedding", False
+        ) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
             if samples:
                 speaker_manager = SpeakerManager(data_items=samples)
             if get_from_config_or_model_args_with_default(config, "speaker_file", None):
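For context, get_from_config_or_model_args_with_default resolves a flag from the config, preferring the nested model args and falling back to a default. A rough dict-based approximation of that lookup pattern (not the TTS.config implementation; real configs are Coqpit objects):

```python
# Rough approximation of the lookup pattern used above.
def get_with_default(config: dict, key: str, default):
    model_args = config.get("model_args") or {}
    if key in model_args:
        return model_args[key]
    return config.get(key, default)

config = {"model_args": {"use_prosody_encoder": True}}
print(get_with_default(config, "use_speaker_embedding", False))  # False (default)
print(get_with_default(config, "use_prosody_encoder", False))    # True (from model_args)
```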

View File

@@ -168,10 +168,9 @@ def synthesis(
             style_feature = style_wav
         else:
             style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
-
-    if hasattr(model, 'compute_style_feature'):
+    if hasattr(model, "compute_style_feature"):
         style_feature = model.compute_style_feature(style_wav)

     # convert text to sequence of token IDs
     text_inputs = np.asarray(
         model.tokenizer.text_to_ids(text, language=language_id),
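The hasattr guard above lets a model that defines its own compute_style_feature (as Vits now does) take over style extraction; a minimal stand-in showing the dispatch, with a dummy model and a made-up reference path:

```python
class DummyModel:
    """Stand-in for a model that provides its own style-feature extraction."""

    def compute_style_feature(self, style_wav):
        return f"style-features({style_wav})"


model, style_wav = DummyModel(), "reference.wav"  # hypothetical reference clip
style_feature = None
if hasattr(model, "compute_style_feature"):
    style_feature = model.compute_style_feature(style_wav)
print(style_feature)  # style-features(reference.wav)
```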