mirror of https://github.com/coqui-ai/TTS.git
Clean up old code
This commit is contained in:
parent 3a524b0597
commit dcd0d1f6a1
@@ -117,7 +117,9 @@ def load_tts_samples(
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
+                meta_data_eval = [
+                    {**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval
+                ]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
             meta_data_eval_all += meta_data_eval

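The list comprehension above tags every eval sample with the dataset's language and speech style through a dict merge. A minimal sketch of that idiom, with a made-up sample dict:

item = {"audio_file": "clip_0001.wav", "text": "hello"}  # invented sample entry
language, speech_style = "en", "neutral"

# {**a, **b} builds a new dict; keys from b win on collisions
tagged = {**item, **{"language": language, "speech_style": speech_style}}
# -> {"audio_file": "clip_0001.wav", "text": "hello", "language": "en", "speech_style": "neutral"}
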
@@ -1,6 +1,8 @@
 import torch
 from torch import nn
 
+
+# pylint: disable=W0223
 class GradientReversalFunction(torch.autograd.Function):
     """Revert gradient without any further input modification.
     Adapted from: https://github.com/Tomiinek/Multilingual_Text_to_Speech/"""

@@ -30,16 +32,15 @@ class ReversalClassifier(nn.Module):
     """
 
     def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0):
-        super(ReversalClassifier, self).__init__()
+        super().__init__()
         self._lambda = scale_factor
         self._clipping = gradient_clipping_bounds
         self._out_channels = out_channels
         self._classifier = nn.Sequential(
-            nn.Linear(in_channels, hidden_channels),
-            nn.ReLU(),
-            nn.Linear(hidden_channels, out_channels)
+            nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Linear(hidden_channels, out_channels)
         )
         self.test = nn.Linear(in_channels, hidden_channels)
 
     def forward(self, x, labels, x_mask=None):
         x = GradientReversalFunction.apply(x, self._lambda, self._clipping)
         x = self._classifier(x)

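The hunk only touches the constructor and the classifier head; for context, a gradient reversal function of this kind is typically an identity in the forward pass and returns a sign-flipped, scaled, clipped gradient in the backward pass. A minimal sketch under that assumption, not the exact body of this file:

import torch


class GradReversal(torch.autograd.Function):
    """Identity in forward; scaled, clipped, negated gradient in backward."""

    @staticmethod
    def forward(ctx, x, scale, clipping):
        ctx.scale = scale
        ctx.clipping = clipping
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        grad = -ctx.scale * grad_output.clamp(-ctx.clipping, ctx.clipping)
        return grad, None, None  # no gradients for scale / clipping


# usage, mirroring the call in ReversalClassifier.forward:
# x = GradReversal.apply(x, 1.0, 0.25)
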
@@ -12,9 +12,9 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 
 from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
+from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager
-from TTS.tts.utils.emotions import get_speech_style_balancer_weights
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 

@@ -17,7 +17,9 @@ from trainer.trainer_utils import get_optimizer, get_scheduler
 
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
+from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
+from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor

@@ -33,9 +35,6 @@ from TTS.tts.utils.visual import plot_alignment
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
 
-from TTS.tts.layers.tacotron.gst_layers import GST
-from TTS.tts.layers.generic.classifier import ReversalClassifier
-
 ##############################
 # IO / Feature extraction
 ##############################

@@ -503,7 +502,6 @@ class VitsArgs(Coqpit):
     external_emotions_embs_file: str = None
     emotion_embedding_dim: int = 0
     num_emotions: int = 0
-    emotion_just_encoder: bool = False
 
     # prosody encoder
     use_prosody_encoder: bool = False

@@ -615,7 +613,7 @@ class Vits(BaseTTS):
 
         dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
 
-        if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+        if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
             dp_cond_embedding_dim += self.args.emotion_embedding_dim
 
         if self.args.use_prosody_encoder:

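Net effect: the duration predictor's conditioning width is the speaker conditioning width plus the emotion embedding size whenever emotion embeddings (internal or external) are enabled. With hypothetical sizes:

cond_embedding_dim = 512     # speaker conditioning channels (example value)
emotion_embedding_dim = 64   # emotion embedding channels (example value)

dp_cond_embedding_dim = cond_embedding_dim       # condition_dp_on_speaker is True
dp_cond_embedding_dim += emotion_embedding_dim   # emotion embeddings enabled
assert dp_cond_embedding_dim == 576
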
@@ -796,12 +794,6 @@ class Vits(BaseTTS):
         if self.num_emotions > 0:
             print(" > initialization of emotion-embedding layers.")
             self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim
-
-        if self.args.use_external_emotions_embeddings:
-            if not self.args.emotion_just_encoder:
-                self.cond_embedding_dim += self.args.emotion_embedding_dim
 
     def get_aux_input(self, aux_input: Dict):
         sid, g, lid, eid, eg = self._set_cond_input(aux_input)

@@ -983,13 +975,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
 
-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
-
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:

@@ -1004,15 +989,15 @@ class Vits(BaseTTS):
         if self.args.use_prosody_encoder:
             pros_emb = self.prosody_encoder(z).transpose(1, 2)
             _, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
-        # print("Encoder input", x.shape)
+
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
-        # print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
+
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
-        # print("Y mask:", y_mask.shape)
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:

@@ -1130,13 +1115,6 @@ class Vits(BaseTTS):
         if self.args.use_emotion_embedding and eid is not None and eg is None:
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
 
-        # concat the emotion embedding and speaker embedding
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
-            if g is None:
-                g = eg
-            else:
-                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
-
         # language embedding
         lang_emb = None
         if self.args.use_language_embedding and lid is not None:

@@ -1154,7 +1132,7 @@ class Vits(BaseTTS):
 
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
-        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
             if g_dp is None:
                 g_dp = eg
             else:

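When both a speaker vector g and an emotion vector eg are available, the duration-predictor conditioning is their channel-wise concatenation, which is what the dp_cond_embedding_dim computed in the constructor accounts for. A shape-only sketch with invented sizes:

import torch

B = 2
g = torch.randn(B, 512, 1)   # speaker conditioning, [b, h1, 1] (example size)
eg = torch.randn(B, 64, 1)   # emotion embedding, [b, h2, 1] (example size)

g_dp = torch.cat([g, eg], dim=1)  # [b, h1 + h2, 1]
assert g_dp.shape == (B, 576, 1)
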
@@ -1176,9 +1154,7 @@ class Vits(BaseTTS):
                 lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(
-                x, x_mask, g=g_dp, lang_emb=lang_emb
-            )
+            logw = self.duration_predictor(x, x_mask, g=g_dp, lang_emb=lang_emb)
 
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)

@@ -1195,13 +1171,23 @@ class Vits(BaseTTS):
         z = self.flow(z_p, y_mask, g=g, reverse=True)
         o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
 
-        outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
+        outputs = {
+            "model_outputs": o,
+            "alignments": attn.squeeze(1),
+            "z": z,
+            "z_p": z_p,
+            "m_p": m_p,
+            "logs_p": logs_p,
+            "durations": w_ceil,
+        }
         return outputs
 
     def compute_style_feature(self, style_wav_path):
         style_wav, sr = torchaudio.load(style_wav_path)
         if sr != self.config.audio.sample_rate:
-            raise RuntimeError(" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!")
+            raise RuntimeError(
+                f" [!] The style reference needs a sampling rate equal to {self.config.audio.sample_rate} !!"
+            )
         y = wav_to_spec(
             style_wav,
             self.config.audio.fft_size,

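compute_style_feature rejects references whose sampling rate differs from the model's, so a caller has to resample beforehand. A hedged sketch using torchaudio (not part of this commit; target_sr stands in for config.audio.sample_rate):

import torchaudio

style_wav, sr = torchaudio.load("reference.wav")  # hypothetical path
target_sr = 22050  # assumed model sampling rate
if sr != target_sr:
    style_wav = torchaudio.functional.resample(style_wav, orig_freq=sr, new_freq=target_sr)
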
@@ -1371,7 +1357,7 @@ class Vits(BaseTTS):
                 or self.args.use_emotion_encoder_as_loss,
                 gt_cons_emb=self.model_outputs_cache["gt_cons_emb"],
                 syn_cons_emb=self.model_outputs_cache["syn_cons_emb"],
-                loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"]
+                loss_spk_reversal_classifier=self.model_outputs_cache["loss_spk_reversal_classifier"],
             )
 
             return self.model_outputs_cache, loss_dict

@@ -1539,7 +1525,11 @@ class Vits(BaseTTS):
             emotion_ids = None
 
         # get numerical speaker ids from speaker names
-        if self.speaker_manager is not None and self.speaker_manager.ids and (self.args.use_speaker_embedding or self.args.use_prosody_encoder):
+        if (
+            self.speaker_manager is not None
+            and self.speaker_manager.ids
+            and (self.args.use_speaker_embedding or self.args.use_prosody_encoder)
+        ):
             speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]
 
         if speaker_ids is not None:

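The reformatted condition gates the usual name-to-index lookup in format_batch; downstream, the ids become a tensor for the speaker embedding layer. A small sketch with invented names:

import torch

speaker_ids_map = {"spk_a": 0, "spk_b": 1}              # stand-in for speaker_manager.ids
batch = {"speaker_names": ["spk_b", "spk_a", "spk_b"]}  # invented batch

speaker_ids = [speaker_ids_map[sn] for sn in batch["speaker_names"]]
speaker_ids = torch.LongTensor(speaker_ids)  # tensor([1, 0, 1])
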
@@ -1,10 +1,10 @@
 import json
 import os
-import torch
-import numpy as np
 from typing import Any, List
 
 import fsspec
+import numpy as np
+import torch
 from coqpit import Coqpit
 
 from TTS.config import get_from_config_or_model_args_with_default

@@ -95,7 +95,9 @@ class SpeakerManager(EmbeddingManager):
             SpeakerEncoder: Speaker encoder object.
         """
         speaker_manager = None
-        if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
+        if get_from_config_or_model_args_with_default(
+            config, "use_speaker_embedding", False
+        ) or get_from_config_or_model_args_with_default(config, "use_prosody_encoder", False):
             if samples:
                 speaker_manager = SpeakerManager(data_items=samples)
             if get_from_config_or_model_args_with_default(config, "speaker_file", None):

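The manager is only built when one of the two flags is set, with each flag read either from the config or from its model_args through the helper shown above. A rough sketch of that lookup pattern (a simplified stand-in for illustration, not the library code):

def config_or_model_args(config, arg_name, def_val):
    # Prefer config.model_args when present, fall back to the top-level config, then the default.
    model_args = getattr(config, "model_args", None)
    if model_args is not None and hasattr(model_args, arg_name):
        return getattr(model_args, arg_name)
    if hasattr(config, arg_name):
        return getattr(config, arg_name)
    return def_val
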
@@ -168,10 +168,9 @@ def synthesis(
             style_feature = style_wav
         else:
             style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
-            if hasattr(model, 'compute_style_feature'):
+            if hasattr(model, "compute_style_feature"):
                 style_feature = model.compute_style_feature(style_wav)
 
-
     # convert text to sequence of token IDs
     text_inputs = np.asarray(
         model.tokenizer.text_to_ids(text, language=language_id),