From dcd0d1f6a139c484e23220f915eba463d828cf01 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 16 May 2022 13:09:12 +0000 Subject: [PATCH] Clean up old code --- TTS/tts/datasets/__init__.py | 4 +- TTS/tts/layers/generic/classifier.py | 17 +++---- TTS/tts/models/base_tts.py | 2 +- TTS/tts/models/vits.py | 66 ++++++++++++---------------- TTS/tts/utils/emotions.py | 4 +- TTS/tts/utils/speakers.py | 4 +- TTS/tts/utils/synthesis.py | 3 +- 7 files changed, 47 insertions(+), 53 deletions(-) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index bb2a823e..8e75be4c 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -117,7 +117,9 @@ def load_tts_samples( if eval_split: if meta_file_val: meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) - meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval] + meta_data_eval = [ + {**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval + ] else: meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size) meta_data_eval_all += meta_data_eval diff --git a/TTS/tts/layers/generic/classifier.py b/TTS/tts/layers/generic/classifier.py index b0136625..938cbdb8 100644 --- a/TTS/tts/layers/generic/classifier.py +++ b/TTS/tts/layers/generic/classifier.py @@ -1,8 +1,10 @@ import torch from torch import nn + +# pylint: disable=W0223 class GradientReversalFunction(torch.autograd.Function): - """Revert gradient without any further input modification. + """Revert gradient without any further input modification. Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/""" @staticmethod @@ -30,17 +32,16 @@ class ReversalClassifier(nn.Module): """ def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0): - super(ReversalClassifier, self).__init__() + super().__init__() self._lambda = scale_factor self._clipping = gradient_clipping_bounds self._out_channels = out_channels self._classifier = nn.Sequential( - nn.Linear(in_channels, hidden_channels), - nn.ReLU(), - nn.Linear(hidden_channels, out_channels) + nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels) ) self.test = nn.Linear(in_channels, hidden_channels) - def forward(self, x, labels, x_mask=None): + + def forward(self, x, labels, x_mask=None): x = GradientReversalFunction.apply(x, self._lambda, self._clipping) x = self._classifier(x) loss = self.loss(labels, x, x_mask) @@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module): ml = torch.max(x_mask) input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None] - target = labels.repeat(ml.int().item(), 1).transpose(0,1) + target = labels.repeat(ml.int().item(), 1).transpose(0, 1) target[~input_mask] = ignore_index - return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index) + return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 2ccc6143..fc77d682 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -12,9 +12,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.emotions import get_speech_style_balancer_weights from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager -from TTS.tts.utils.emotions import get_speech_style_balancer_weights from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 8a177f91..2a9a4bbb 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.layers.generic.classifier import ReversalClassifier from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.tacotron.gst_layers import GST from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor @@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results -from TTS.tts.layers.tacotron.gst_layers import GST -from TTS.tts.layers.generic.classifier import ReversalClassifier - ############################## # IO / Feature extraction ############################## @@ -503,8 +502,7 @@ class VitsArgs(Coqpit): external_emotions_embs_file: str = None emotion_embedding_dim: int = 0 num_emotions: int = 0 - emotion_just_encoder: bool = False - + # prosody encoder use_prosody_encoder: bool = False prosody_embedding_dim: int = 0 @@ -615,7 +613,7 @@ class Vits(BaseTTS): dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0 - if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings): + if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings: dp_cond_embedding_dim += self.args.emotion_embedding_dim if self.args.use_prosody_encoder: @@ -796,12 +794,6 @@ class Vits(BaseTTS): if self.num_emotions > 0: print(" > initialization of emotion-embedding layers.") self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim) - if not self.args.emotion_just_encoder: - self.cond_embedding_dim += self.args.emotion_embedding_dim - - if self.args.use_external_emotions_embeddings: - if not self.args.emotion_just_encoder: - self.cond_embedding_dim += self.args.emotion_embedding_dim def get_aux_input(self, aux_input: Dict): sid, g, lid, eid, eg = self._set_cond_input(aux_input) @@ -983,13 +975,6 @@ class Vits(BaseTTS): if self.args.use_emotion_embedding and eid is not None and eg is None: eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1] - # concat the emotion embedding and speaker embedding - if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder: - if g is None: - g = eg - else: - g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1] - # language embedding lang_emb = None if self.args.use_language_embedding and lid is not None: @@ -1004,15 +989,15 @@ class Vits(BaseTTS): if self.args.use_prosody_encoder: pros_emb = self.prosody_encoder(z).transpose(1, 2) _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None) - # print("Encoder input", x.shape) + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb) - # print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape) + # flow layers z_p = self.flow(z, y_mask, g=g) - # print("Y mask:", y_mask.shape) + # duration predictor g_dp = g if self.args.condition_dp_on_speaker else None - if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder: + if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings): if g_dp is None: g_dp = eg else: @@ -1130,13 +1115,6 @@ class Vits(BaseTTS): if self.args.use_emotion_embedding and eid is not None and eg is None: eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1] - # concat the emotion embedding and speaker embedding - if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder: - if g is None: - g = eg - else: - g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1] - # language embedding lang_emb = None if self.args.use_language_embedding and lid is not None: @@ -1154,7 +1132,7 @@ class Vits(BaseTTS): # duration predictor g_dp = g if self.args.condition_dp_on_speaker else None - if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder: + if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings): if g_dp is None: g_dp = eg else: @@ -1176,9 +1154,7 @@ class Vits(BaseTTS): lang_emb=lang_emb, ) else: - logw = self.duration_predictor( - x, x_mask, g=g_dp, lang_emb=lang_emb - ) + logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb) w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) @@ -1195,13 +1171,23 @@ class Vits(BaseTTS): z = self.flow(z_p, y_mask, g=g, reverse=True) o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) - outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil} + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "durations": w_ceil, + } return outputs def compute_style_feature(self, style_wav_path): style_wav, sr = torchaudio.load(style_wav_path) if sr != self.config.audio.sample_rate: - raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!") + raise RuntimeError( + " [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!" + ) y = wav_to_spec( style_wav, self.config.audio.fft_size, @@ -1371,7 +1357,7 @@ class Vits(BaseTTS): or self.args.use_emotion_encoder_as_loss, gt_cons_emb=self.model_outputs_cache["gt_cons_emb"], syn_cons_emb=self.model_outputs_cache["syn_cons_emb"], - loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"] + loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"], ) return self.model_outputs_cache, loss_dict @@ -1539,7 +1525,11 @@ class Vits(BaseTTS): emotion_ids = None # get numerical speaker ids from speaker names - if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder): + if ( + self.speaker_manager is not None + and self.speaker_manager.ids + and (self.args.use_speaker_embedding or self.args.use_prosody_encoder) + ): speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]] if speaker_ids is not None: diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py index 1fea49ae..bf5646a9 100644 --- a/TTS/tts/utils/emotions.py +++ b/TTS/tts/utils/emotions.py @@ -1,10 +1,10 @@ import json import os -import torch -import numpy as np from typing import Any, List import fsspec +import numpy as np +import torch from coqpit import Coqpit from TTS.config import get_from_config_or_model_args_with_default diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 099fe029..c9ba9db4 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager): SpeakerEncoder: Speaker encoder object. """ speaker_manager = None - if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False): + if get_from_config_or_model_args_with_default( + config, "use_speaker_embedding", False + ) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False): if samples: speaker_manager = SpeakerManager(data_items=samples) if get_from_config_or_model_args_with_default(config, "speaker_file", None): diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index ac92d345..487fdc65 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -168,10 +168,9 @@ def synthesis( style_feature = style_wav else: style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda) - if hasattr(model, 'compute_style_feature'): + if hasattr(model, "compute_style_feature"): style_feature = model.compute_style_feature(style_wav) - # convert text to sequence of token IDs text_inputs = np.asarray( model.tokenizer.text_to_ids(text, language=language_id),