mirror of https://github.com/coqui-ai/TTS.git
Clean up old code
This commit is contained in:
parent
3a524b0597
commit
dcd0d1f6a1
|
@ -117,7 +117,9 @@ def load_tts_samples(
|
|||
if eval_split:
|
||||
if meta_file_val:
|
||||
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
|
||||
meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
|
||||
meta_data_eval = [
|
||||
{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
|
||||
]
|
||||
else:
|
||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# pylint: disable=W0223
|
||||
class GradientReversalFunction(torch.autograd.Function):
|
||||
"""Revert gradient without any further input modification.
|
||||
Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""
|
||||
|
@ -30,16 +32,15 @@ class ReversalClassifier(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
|
||||
super(ReversalClassifier, self).__init__()
|
||||
super().__init__()
|
||||
self._lambda = scale_factor
|
||||
self._clipping = gradient_clipping_bounds
|
||||
self._out_channels = out_channels
|
||||
self._classifier = nn.Sequential(
|
||||
nn.Linear(in_channels, hidden_channels),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels, out_channels)
|
||||
nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
|
||||
)
|
||||
self.test = nn.Linear(in_channels, hidden_channels)
|
||||
|
||||
def forward(self, x, labels, x_mask=None):
|
||||
x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
|
||||
x = self._classifier(x)
|
||||
|
@ -55,7 +56,7 @@ class ReversalClassifier(nn.Module):
|
|||
ml = torch.max(x_mask)
|
||||
input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
|
||||
|
||||
target = labels.repeat(ml.int().item(), 1).transpose(0,1)
|
||||
target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
|
||||
target[~input_mask] = ignore_index
|
||||
|
||||
return nn.functional.cross_entropy(predictions.transpose(1,2), target, ignore_index=ignore_index)
|
||||
return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
|
||||
|
|
|
@ -12,9 +12,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
|||
|
||||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
|
||||
from TTS.tts.utils.emotions import get_speech_style_balancer_weights
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
||||
|
|
|
@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
|
|||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
|
||||
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
||||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
|
@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
|
|||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.generic.classifier import ReversalClassifier
|
||||
|
||||
##############################
|
||||
# IO / Feature extraction
|
||||
##############################
|
||||
|
@ -503,7 +502,6 @@ class VitsArgs(Coqpit):
|
|||
external_emotions_embs_file: str = None
|
||||
emotion_embedding_dim: int = 0
|
||||
num_emotions: int = 0
|
||||
emotion_just_encoder: bool = False
|
||||
|
||||
# prosody encoder
|
||||
use_prosody_encoder: bool = False
|
||||
|
@ -615,7 +613,7 @@ class Vits(BaseTTS):
|
|||
|
||||
dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
|
||||
|
||||
if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
|
||||
dp_cond_embedding_dim += self.args.emotion_embedding_dim
|
||||
|
||||
if self.args.use_prosody_encoder:
|
||||
|
@ -796,12 +794,6 @@ class Vits(BaseTTS):
|
|||
if self.num_emotions > 0:
|
||||
print(" > initialization of emotion-embedding layers.")
|
||||
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
|
||||
if not self.args.emotion_just_encoder:
|
||||
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
||||
|
||||
if self.args.use_external_emotions_embeddings:
|
||||
if not self.args.emotion_just_encoder:
|
||||
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
|
||||
|
@ -983,13 +975,6 @@ class Vits(BaseTTS):
|
|||
if self.args.use_emotion_embedding and eid is not None and eg is None:
|
||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# concat the emotion embedding and speaker embedding
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
|
||||
if g is None:
|
||||
g = eg
|
||||
else:
|
||||
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
|
||||
|
||||
# language embedding
|
||||
lang_emb = None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
|
@ -1004,15 +989,15 @@ class Vits(BaseTTS):
|
|||
if self.args.use_prosody_encoder:
|
||||
pros_emb = self.prosody_encoder(z).transpose(1, 2)
|
||||
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
|
||||
# print("Encoder input", x.shape)
|
||||
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
|
||||
# print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
|
||||
|
||||
# flow layers
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
# print("Y mask:", y_mask.shape)
|
||||
|
||||
# duration predictor
|
||||
g_dp = g if self.args.condition_dp_on_speaker else None
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||
if g_dp is None:
|
||||
g_dp = eg
|
||||
else:
|
||||
|
@ -1130,13 +1115,6 @@ class Vits(BaseTTS):
|
|||
if self.args.use_emotion_embedding and eid is not None and eg is None:
|
||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# concat the emotion embedding and speaker embedding
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
|
||||
if g is None:
|
||||
g = eg
|
||||
else:
|
||||
g = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
|
||||
|
||||
# language embedding
|
||||
lang_emb = None
|
||||
if self.args.use_language_embedding and lid is not None:
|
||||
|
@ -1154,7 +1132,7 @@ class Vits(BaseTTS):
|
|||
|
||||
# duration predictor
|
||||
g_dp = g if self.args.condition_dp_on_speaker else None
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
|
||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||
if g_dp is None:
|
||||
g_dp = eg
|
||||
else:
|
||||
|
@ -1176,9 +1154,7 @@ class Vits(BaseTTS):
|
|||
lang_emb=lang_emb,
|
||||
)
|
||||
else:
|
||||
logw = self.duration_predictor(
|
||||
x, x_mask, g=g_dp, lang_emb=lang_emb
|
||||
)
|
||||
logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
|
||||
|
||||
w = torch.exp(logw) * x_mask * self.length_scale
|
||||
w_ceil = torch.ceil(w)
|
||||
|
@ -1195,13 +1171,23 @@ class Vits(BaseTTS):
|
|||
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
||||
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
|
||||
|
||||
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
|
||||
outputs = {
|
||||
"model_outputs": o,
|
||||
"alignments": attn.squeeze(1),
|
||||
"z": z,
|
||||
"z_p": z_p,
|
||||
"m_p": m_p,
|
||||
"logs_p": logs_p,
|
||||
"durations": w_ceil,
|
||||
}
|
||||
return outputs
|
||||
|
||||
def compute_style_feature(self, style_wav_path):
|
||||
style_wav, sr = torchaudio.load(style_wav_path)
|
||||
if sr != self.config.audio.sample_rate:
|
||||
raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
|
||||
raise RuntimeError(
|
||||
" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
|
||||
)
|
||||
y = wav_to_spec(
|
||||
style_wav,
|
||||
self.config.audio.fft_size,
|
||||
|
@ -1371,7 +1357,7 @@ class Vits(BaseTTS):
|
|||
or self.args.use_emotion_encoder_as_loss,
|
||||
gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
|
||||
syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
|
||||
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
|
||||
loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
|
||||
)
|
||||
|
||||
return self.model_outputs_cache, loss_dict
|
||||
|
@ -1539,7 +1525,11 @@ class Vits(BaseTTS):
|
|||
emotion_ids = None
|
||||
|
||||
# get numerical speaker ids from speaker names
|
||||
if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
|
||||
if (
|
||||
self.speaker_manager is not None
|
||||
and self.speaker_manager.ids
|
||||
and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
|
||||
):
|
||||
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
|
||||
|
||||
if speaker_ids is not None:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import json
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, List
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import get_from_config_or_model_args_with_default
|
||||
|
|
|
@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
|
|||
SpeakerEncoder: Speaker encoder object.
|
||||
"""
|
||||
speaker_manager = None
|
||||
if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
|
||||
if get_from_config_or_model_args_with_default(
|
||||
config, "use_speaker_embedding", False
|
||||
) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
|
||||
if samples:
|
||||
speaker_manager = SpeakerManager(data_items=samples)
|
||||
if get_from_config_or_model_args_with_default(config, "speaker_file", None):
|
||||
|
|
|
@ -168,10 +168,9 @@ def synthesis(
|
|||
style_feature = style_wav
|
||||
else:
|
||||
style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
|
||||
if hasattr(model, 'compute_style_feature'):
|
||||
if hasattr(model, "compute_style_feature"):
|
||||
style_feature = model.compute_style_feature(style_wav)
|
||||
|
||||
|
||||
# convert text to sequence of token IDs
|
||||
text_inputs = np.asarray(
|
||||
model.tokenizer.text_to_ids(text, language=language_id),
|
||||
|
|
Loading…
Reference in New Issue