diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
index bb2a823e..2eed947f 100644
--- a/TTS/tts/datasets/__init__.py
+++ b/TTS/tts/datasets/__init__.py
@@ -13,9 +13,8 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
     """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
 
     Args:
-        <<<<<<< HEAD
         items (List[List]):
-            A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`.
+            A list of samples. Each sample is a dict containing the keys "text", "audio_file", and "speaker_name".
 
         eval_split_max_size (int):
             Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
@@ -23,9 +22,6 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
         eval_split_size (float):
             If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
             If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
-        =======
-        items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`.
-        >>>>>>> Fix docstring
     """
     speakers = [item["speaker_name"] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
@@ -117,7 +113,9 @@ def load_tts_samples(
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
+                meta_data_eval = [
+                    {**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
+                ]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval
diff --git a/TTS/tts/layers/generic/classifier.py b/TTS/tts/layers/generic/classifier.py
index b0136625..938cbdb8 100644
--- a/TTS/tts/layers/generic/classifier.py
+++ b/TTS/tts/layers/generic/classifier.py
@@ -1,8 +1,10 @@
 import torch
 from torch import nn
 
+
+# pylint: disable=W0223
 class GradientReversalFunction(torch.autograd.Function):
-    """Revert gradient without any further input modification. 
+    """Revert gradient without any further input modification.
     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
 
     @staticmethod
@@ -30,17 +32,16 @@ class ReversalClassifier(nn.Module):
     """
 
     def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
-        super(ReversalClassifier, self).__init__()
+        super().__init__()
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
-            nn.Linear(in_channels, hidden_channels),
-            nn.ReLU(),
-            nn.Linear(hidden_channels, out_channels)
+            nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
         self.test = nn.Linear(in_channels, hidden_channels)
-    def forward(self, x, labels, x_mask=None):
+
+    def forward(self, x, labels, x_mask=None):
         x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)
         loss = self.loss(labels, x, x_mask)
@@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module):
 
         ml = torch.max(x_mask)
         input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
-        target = labels.repeat(ml.int().item(), 1).transpose(0,1)
+        target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
         target[~input_mask] = ignore_index
 
-        return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index)
+        return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 501d1c41..a8022f42 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -13,9 +13,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.data import get_length_balancer_weights
+from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
-from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index 1f3fd6d6..c7b5ed53 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -223,8 +223,7 @@ class Tacotron(BaseTacotron):
         encoder_outputs = self.encoder(inputs)
         if self.gst and self.use_gst:
             # B x gst_dim
-<<<<<<< HEAD
-            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
+            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
         if self.capacitron_vae and self.use_capacitron_vae:
             if aux_input["style_text"] is not None:
                 style_text_embedding = self.embedding(aux_input["style_text"])
@@ -232,24 +231,21 @@ class Tacotron(BaseTacotron):
                     encoder_outputs.device
                 )  # pylint: disable=not-callable
             reference_mel_length = (
-                torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
-                if aux_input["style_mel"] is not None
+                torch.tensor([aux_input["style_feature"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
+                if aux_input["style_feature"] is not None
                 else None
             )  # pylint: disable=not-callable
             # B x capacitron_VAE_embedding_dim
             encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
                 encoder_outputs,
-                reference_mel_info=[aux_input["style_mel"], reference_mel_length]
-                if aux_input["style_mel"] is not None
+                reference_mel_info=[aux_input["style_feature"], reference_mel_length]
+                if aux_input["style_feature"] is not None
                 else None,
                 text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
                 speaker_embedding=aux_input["d_vectors"]
                 if self.capacitron_vae.capacitron_use_speaker_embedding
                 else None,
             )
-=======
-            encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
->>>>>>> 3a524b05... Add prosody encoder params on config
         if self.num_speakers > 1:
             if not self.use_d_vector_file:
                 # B x 1 x speaker_embed_dim
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 0c4a0539..21333178 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
 
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
+from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
@@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 
-from TTS.tts.layers.tacotron.gst_layers import GST
-from TTS.tts.layers.generic.classifier import ReversalClassifier
-
 ##############################
 # IO / Feature extraction
 ##############################
@@ -541,8 +540,7 @@ class VitsArgs(Coqpit):
     external_emotions_embs_file: str = None
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
-    emotion_just_encoder: bool = False
-
+
     # prosody encoder
     use_prosody_encoder: bool = False
     prosody_embedding_dim: int = 0
@@ -658,7 +656,7 @@ class Vits(BaseTTS):
 
         dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
 
-        if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+        if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
             dp_cond_embedding_dim += self.args.emotion_embedding_dim
 
         if self.args.use_prosody_encoder:
@@ -872,12 +870,6 @@ class Vits(BaseTTS):
             if self.num_emotions > 0:
                 print(" > initialization of emotion-embedding layers.")
                 self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
-                if not self.args.emotion_just_encoder:
-                    self.cond_embedding_dim += self.args.emotion_embedding_dim
-
-        if self.args.use_external_emotions_embeddings:
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim
 
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)
@@ -1076,13 +1068,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
 
-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
-
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:
@@ -1097,15 +1082,15 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
-        # print("Encoder input", x.shape)
+
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
-        # print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
+
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
-        # print("Y mask:", y_mask.shape)
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:
@@ -1228,13 +1213,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
 
-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
-
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:
@@ -1252,7 +1230,7 @@ class Vits(BaseTTS):
 
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:
@@ -1274,9 +1252,7 @@ class Vits(BaseTTS):
                 lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g_dp, lang_emb=lang_emb
-            )
+            logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
 
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
@@ -1297,7 +1273,6 @@ class Vits(BaseTTS):
 
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
 
-<<<<<<< HEAD
         outputs = {
             "model_outputs": o,
             "alignments": attn.squeeze(1),
             "durations": w_ceil,
             "z": z,
             "z_p": z_p,
             "m_p": m_p,
             "logs_p": logs_p,
             "y_mask": y_mask,
         }
-=======
-        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
->>>>>>> 3a524b05... Add prosody encoder params on config
         return outputs
 
     def compute_style_feature(self, style_wav_path):
         style_wav, sr = torchaudio.load(style_wav_path)
         if sr != self.config.audio.sample_rate:
-            raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
+            raise RuntimeError(
+                f" [!] Style reference needs to have a sampling rate equal to {self.config.audio.sample_rate} !!"
+            )
         y = wav_to_spec(
             style_wav,
             self.config.audio.fft_size,
@@ -1491,7 +1465,7 @@ class Vits(BaseTTS):
                     or self.args.use_emotion_encoder_as_loss,
                     gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                     syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
-                    loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
+                    loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
                 )
 
             return self.model_outputs_cache, loss_dict
@@ -1659,7 +1633,11 @@ class Vits(BaseTTS):
             emotion_ids = None
 
         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
+        if (
+            self.speaker_manager is not None
+            and self.speaker_manager.ids
+            and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
+        ):
             speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
 
         if speaker_ids is not None:
diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py
index 1fea49ae..bf5646a9 100644
--- a/TTS/tts/utils/emotions.py
+++ b/TTS/tts/utils/emotions.py
@@ -1,10 +1,10 @@
 import json
 import os
-import torch
-import numpy as np
 from typing import Any, List
 
 import fsspec
+import numpy as np
+import torch
 from coqpit import Coqpit
 
 from TTS.config import get_from_config_or_model_args_with_default
diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py
index c16f71d0..cc84e4c7 100644
--- a/TTS/tts/utils/speakers.py
+++ b/TTS/tts/utils/speakers.py
@@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
         SpeakerEncoder: Speaker encoder object.
     """
     speaker_manager = None
-    if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
+    if get_from_config_or_model_args_with_default(
+        config, "use_speaker_embedding", False
+    ) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
         if samples:
             speaker_manager = SpeakerManager(data_items=samples)
         if get_from_config_or_model_args_with_default(config, "speaker_file", None):
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 86f9ed6e..e769648d 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -181,8 +181,8 @@ def synthesis(
             style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
             style_feature = style_feature.transpose(1, 2)  # [1, time, depth]
 
-    if hasattr(model, 'compute_style_feature'):
-        style_feature = model.compute_style_feature(style_wav)
+    if hasattr(model, 'compute_style_feature') and style_wav is not None:
+        style_feature = model.compute_style_feature(style_wav)
 
     # convert text to sequence of token IDs
     text_inputs = np.asarray(
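Below is a minimal usage sketch for the split_dataset() interface whose docstring is corrected in the first hunk above. The sample dicts, file paths, and split ratio are illustrative assumptions, not values taken from this patch.

    from TTS.tts.datasets import split_dataset

    # Each sample is a dict with the keys named in the corrected docstring:
    # "text", "audio_file", and "speaker_name".
    items = [
        {"text": f"sentence {i}", "audio_file": f"/data/wavs/{i}.wav", "speaker_name": f"speaker_{i % 2}"}
        for i in range(8)
    ]

    # An eval_split_size between 0 and 1 is treated as a proportion of the dataset (0.25 -> 2 eval samples here);
    # the evaluation split is returned first, then the training split.
    eval_samples, train_samples = split_dataset(items, eval_split_max_size=None, eval_split_size=0.25)
    assert len(eval_samples) + len(train_samples) == len(items)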