mirror of https://github.com/coqui-ai/TTS.git
Add prosody encoder training support
This commit is contained in:
parent
f31ba25233
commit
8a3396d9c1
|
@ -373,17 +373,17 @@ def esd(root_path, meta_files, ignored_speakers=None):
|
||||||
if speaker_id in ignored_speakers:
|
if speaker_id in ignored_speakers:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with open(meta_file, "r", encoding="latin-1") as file_text:
|
with open(meta_file, "r", encoding="utf-8") as file_text:
|
||||||
try:
|
try:
|
||||||
metadata = file_text.readlines()
|
metadata = file_text.readlines()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"The file {meta_file} break the import with the following error: ")
|
print(f"The file {meta_file} break the import with the following error: ")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
for data in metadata:
|
for data in metadata:
|
||||||
# this dataset have problems with csv separator, some files use just space others \t
|
# this dataset have problems with csv separator, some files use just space others \t
|
||||||
data = data.replace("\n", "").replace("\t", " ")
|
data = data.replace("\n", "").replace("\t", " ")
|
||||||
if not data:
|
if not data:
|
||||||
|
print(meta_file, data)
|
||||||
continue
|
continue
|
||||||
splits = data.split(" ")
|
splits = data.split(" ")
|
||||||
|
|
||||||
|
@ -391,10 +391,12 @@ def esd(root_path, meta_files, ignored_speakers=None):
|
||||||
emotion_id = splits[-1]
|
emotion_id = splits[-1]
|
||||||
# all except the first and last position is the sentence
|
# all except the first and last position is the sentence
|
||||||
text = " ".join(splits[1:-1])
|
text = " ".join(splits[1:-1])
|
||||||
|
|
||||||
for split in meta_files:
|
for split in meta_files:
|
||||||
wav_file = os.path.join(root_path, speaker_id, emotion_id, split, file_id + ".wav")
|
wav_file = os.path.join(root_path, speaker_id, emotion_id, split, file_id + ".wav")
|
||||||
if os.path.exists(wav_file):
|
if os.path.exists(wav_file):
|
||||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "ESD_" + speaker_id})
|
items.append({"text": text, "audio_file": wav_file, "speaker_name": "ESD_" + speaker_id})
|
||||||
|
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ class TextEncoder(nn.Module):
|
||||||
dropout_p: float,
|
dropout_p: float,
|
||||||
language_emb_dim: int = None,
|
language_emb_dim: int = None,
|
||||||
emotion_emb_dim: int = None,
|
emotion_emb_dim: int = None,
|
||||||
|
prosody_emb_dim: int = None,
|
||||||
):
|
):
|
||||||
"""Text Encoder for VITS model.
|
"""Text Encoder for VITS model.
|
||||||
|
|
||||||
|
@ -66,6 +67,9 @@ class TextEncoder(nn.Module):
|
||||||
if emotion_emb_dim:
|
if emotion_emb_dim:
|
||||||
hidden_channels += emotion_emb_dim
|
hidden_channels += emotion_emb_dim
|
||||||
|
|
||||||
|
if prosody_emb_dim:
|
||||||
|
hidden_channels += prosody_emb_dim
|
||||||
|
|
||||||
self.encoder = RelativePositionTransformer(
|
self.encoder = RelativePositionTransformer(
|
||||||
in_channels=hidden_channels,
|
in_channels=hidden_channels,
|
||||||
out_channels=hidden_channels,
|
out_channels=hidden_channels,
|
||||||
|
@ -81,7 +85,7 @@ class TextEncoder(nn.Module):
|
||||||
|
|
||||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
def forward(self, x, x_lengths, lang_emb=None, emo_emb=None):
|
def forward(self, x, x_lengths, lang_emb=None, emo_emb=None, pros_emb=None):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- x: :math:`[B, T]`
|
- x: :math:`[B, T]`
|
||||||
|
@ -98,6 +102,9 @@ class TextEncoder(nn.Module):
|
||||||
if emo_emb is not None:
|
if emo_emb is not None:
|
||||||
x = torch.cat((x, emo_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
|
x = torch.cat((x, emo_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
|
||||||
|
|
||||||
|
if pros_emb is not None:
|
||||||
|
x = torch.cat((x, pros_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
|
||||||
|
|
||||||
x = torch.transpose(x, 1, -1) # [b, h, t]
|
x = torch.transpose(x, 1, -1) # [b, h, t]
|
||||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t]
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t]
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,8 @@ from TTS.tts.utils.visual import plot_alignment
|
||||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||||
from TTS.vocoder.utils.generic_utils import plot_results
|
from TTS.vocoder.utils.generic_utils import plot_results
|
||||||
|
|
||||||
|
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
# IO / Feature extraction
|
# IO / Feature extraction
|
||||||
##############################
|
##############################
|
||||||
|
@ -500,6 +502,11 @@ class VitsArgs(Coqpit):
|
||||||
external_emotions_embs_file: str = None
|
external_emotions_embs_file: str = None
|
||||||
emotion_embedding_dim: int = 0
|
emotion_embedding_dim: int = 0
|
||||||
num_emotions: int = 0
|
num_emotions: int = 0
|
||||||
|
emotion_just_encoder: bool = False
|
||||||
|
|
||||||
|
# prosody encoder
|
||||||
|
use_prosody_encoder: bool = False
|
||||||
|
prosody_embedding_dim: int = 0
|
||||||
|
|
||||||
detach_dp_input: bool = True
|
detach_dp_input: bool = True
|
||||||
use_language_embedding: bool = False
|
use_language_embedding: bool = False
|
||||||
|
@ -581,6 +588,7 @@ class Vits(BaseTTS):
|
||||||
self.args.dropout_p_text_encoder,
|
self.args.dropout_p_text_encoder,
|
||||||
language_emb_dim=self.embedded_language_dim,
|
language_emb_dim=self.embedded_language_dim,
|
||||||
emotion_emb_dim=self.args.emotion_embedding_dim,
|
emotion_emb_dim=self.args.emotion_embedding_dim,
|
||||||
|
prosody_emb_dim=self.args.prosody_embedding_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.posterior_encoder = PosteriorEncoder(
|
self.posterior_encoder = PosteriorEncoder(
|
||||||
|
@ -602,26 +610,42 @@ class Vits(BaseTTS):
|
||||||
cond_channels=self.cond_embedding_dim,
|
cond_channels=self.cond_embedding_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dp_cond_embedding_dim = self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0
|
||||||
|
|
||||||
|
if self.args.emotion_just_encoder and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
||||||
|
dp_cond_embedding_dim += self.args.emotion_embedding_dim
|
||||||
|
|
||||||
|
if self.args.use_prosody_encoder:
|
||||||
|
dp_cond_embedding_dim += self.args.prosody_embedding_dim
|
||||||
|
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
self.duration_predictor = StochasticDurationPredictor(
|
self.duration_predictor = StochasticDurationPredictor(
|
||||||
self.args.hidden_channels + self.args.emotion_embedding_dim,
|
self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
|
||||||
192,
|
192,
|
||||||
3,
|
3,
|
||||||
self.args.dropout_p_duration_predictor,
|
self.args.dropout_p_duration_predictor,
|
||||||
4,
|
4,
|
||||||
cond_channels=self.cond_embedding_dim if self.args.condition_dp_on_speaker else 0,
|
cond_channels=dp_cond_embedding_dim,
|
||||||
language_emb_dim=self.embedded_language_dim,
|
language_emb_dim=self.embedded_language_dim,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.duration_predictor = DurationPredictor(
|
self.duration_predictor = DurationPredictor(
|
||||||
self.args.hidden_channels + self.args.emotion_embedding_dim,
|
self.args.hidden_channels + self.args.emotion_embedding_dim + self.args.prosody_embedding_dim,
|
||||||
256,
|
256,
|
||||||
3,
|
3,
|
||||||
self.args.dropout_p_duration_predictor,
|
self.args.dropout_p_duration_predictor,
|
||||||
cond_channels=self.cond_embedding_dim,
|
cond_channels=dp_cond_embedding_dim,
|
||||||
language_emb_dim=self.embedded_language_dim,
|
language_emb_dim=self.embedded_language_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.args.use_prosody_encoder:
|
||||||
|
self.prosody_encoder = GST(
|
||||||
|
num_mel=self.args.hidden_channels,
|
||||||
|
num_heads=1,
|
||||||
|
num_style_tokens=5,
|
||||||
|
gst_embedding_dim=self.args.prosody_embedding_dim,
|
||||||
|
)
|
||||||
|
|
||||||
self.waveform_decoder = HifiganGenerator(
|
self.waveform_decoder = HifiganGenerator(
|
||||||
self.args.hidden_channels,
|
self.args.hidden_channels,
|
||||||
1,
|
1,
|
||||||
|
@ -764,10 +788,12 @@ class Vits(BaseTTS):
|
||||||
if self.num_emotions > 0:
|
if self.num_emotions > 0:
|
||||||
print(" > initialization of emotion-embedding layers.")
|
print(" > initialization of emotion-embedding layers.")
|
||||||
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
|
self.emb_emotion = nn.Embedding(self.num_emotions, self.args.emotion_embedding_dim)
|
||||||
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
if not self.args.emotion_just_encoder:
|
||||||
|
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
||||||
|
|
||||||
if self.args.use_external_emotions_embeddings:
|
if self.args.use_external_emotions_embeddings:
|
||||||
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
if not self.args.emotion_just_encoder:
|
||||||
|
self.cond_embedding_dim += self.args.emotion_embedding_dim
|
||||||
|
|
||||||
def get_aux_input(self, aux_input: Dict):
|
def get_aux_input(self, aux_input: Dict):
|
||||||
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
|
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
|
||||||
|
@ -946,7 +972,7 @@ class Vits(BaseTTS):
|
||||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# concat the emotion embedding and speaker embedding
|
# concat the emotion embedding and speaker embedding
|
||||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
|
||||||
if g is None:
|
if g is None:
|
||||||
g = eg
|
g = eg
|
||||||
else:
|
else:
|
||||||
|
@ -957,16 +983,34 @@ class Vits(BaseTTS):
|
||||||
if self.args.use_language_embedding and lid is not None:
|
if self.args.use_language_embedding and lid is not None:
|
||||||
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
lang_emb = self.emb_l(lid).unsqueeze(-1)
|
||||||
|
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg)
|
|
||||||
|
|
||||||
# posterior encoder
|
# posterior encoder
|
||||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||||
|
|
||||||
|
# prosody embedding
|
||||||
|
pros_emb = None
|
||||||
|
if self.args.use_prosody_encoder:
|
||||||
|
pros_emb = self.prosody_encoder(z).transpose(1, 2)
|
||||||
|
|
||||||
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
|
||||||
|
|
||||||
# flow layers
|
# flow layers
|
||||||
z_p = self.flow(z, y_mask, g=g)
|
z_p = self.flow(z, y_mask, g=g)
|
||||||
|
|
||||||
# duration predictor
|
# duration predictor
|
||||||
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)
|
g_dp = g
|
||||||
|
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
|
||||||
|
if g_dp is None:
|
||||||
|
g_dp = eg
|
||||||
|
else:
|
||||||
|
g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1]
|
||||||
|
|
||||||
|
if self.args.use_prosody_encoder:
|
||||||
|
if g_dp is None:
|
||||||
|
g_dp = pros_emb
|
||||||
|
else:
|
||||||
|
g_dp = torch.cat([g_dp, pros_emb], dim=1) # [b, h1+h2, 1]
|
||||||
|
|
||||||
|
outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)
|
||||||
|
|
||||||
# expand prior
|
# expand prior
|
||||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||||
|
@ -1071,7 +1115,7 @@ class Vits(BaseTTS):
|
||||||
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
eg = self.emb_emotion(eid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# concat the emotion embedding and speaker embedding
|
# concat the emotion embedding and speaker embedding
|
||||||
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
|
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and not self.args.emotion_just_encoder:
|
||||||
if g is None:
|
if g is None:
|
||||||
g = eg
|
g = eg
|
||||||
else:
|
else:
|
||||||
|
@ -1084,18 +1128,27 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg)
|
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg)
|
||||||
|
|
||||||
|
# duration predictor
|
||||||
|
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
|
||||||
|
if g is None:
|
||||||
|
g_dp = eg
|
||||||
|
else:
|
||||||
|
g_dp = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
|
||||||
|
else:
|
||||||
|
g_dp = g
|
||||||
|
|
||||||
if self.args.use_sdp:
|
if self.args.use_sdp:
|
||||||
logw = self.duration_predictor(
|
logw = self.duration_predictor(
|
||||||
x,
|
x,
|
||||||
x_mask,
|
x_mask,
|
||||||
g=g if self.args.condition_dp_on_speaker else None,
|
g=g_dp if self.args.condition_dp_on_speaker else None,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
noise_scale=self.inference_noise_scale_dp,
|
noise_scale=self.inference_noise_scale_dp,
|
||||||
lang_emb=lang_emb,
|
lang_emb=lang_emb,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logw = self.duration_predictor(
|
logw = self.duration_predictor(
|
||||||
x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
x, x_mask, g=g_dp if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
|
||||||
)
|
)
|
||||||
|
|
||||||
w = torch.exp(logw) * x_mask * self.length_scale
|
w = torch.exp(logw) * x_mask * self.length_scale
|
||||||
|
|
|
@ -43,6 +43,7 @@ config.model_args.d_vector_dim = 256
|
||||||
# emotion
|
# emotion
|
||||||
config.model_args.use_external_emotions_embeddings = False
|
config.model_args.use_external_emotions_embeddings = False
|
||||||
config.model_args.use_emotion_embedding = True
|
config.model_args.use_emotion_embedding = True
|
||||||
|
config.model_args.emotion_just_encoder = False
|
||||||
config.model_args.emotion_embedding_dim = 256
|
config.model_args.emotion_embedding_dim = 256
|
||||||
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from trainer import get_last_checkpoint
|
||||||
|
|
||||||
|
from tests import get_device_id, get_tests_output_path, run_cli
|
||||||
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
|
|
||||||
|
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
|
||||||
|
output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
||||||
|
|
||||||
|
|
||||||
|
config = VitsConfig(
|
||||||
|
batch_size=2,
|
||||||
|
eval_batch_size=2,
|
||||||
|
num_loader_workers=0,
|
||||||
|
num_eval_loader_workers=0,
|
||||||
|
text_cleaner="english_cleaners",
|
||||||
|
use_phonemes=True,
|
||||||
|
phoneme_language="en-us",
|
||||||
|
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
|
||||||
|
run_eval=True,
|
||||||
|
test_delay_epochs=-1,
|
||||||
|
epochs=1,
|
||||||
|
print_step=1,
|
||||||
|
print_eval=True,
|
||||||
|
test_sentences=[
|
||||||
|
["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# set audio config
|
||||||
|
config.audio.do_trim_silence = True
|
||||||
|
config.audio.trim_db = 60
|
||||||
|
|
||||||
|
# active multispeaker d-vec mode
|
||||||
|
config.model_args.use_speaker_embedding = True
|
||||||
|
config.model_args.use_d_vector_file = False
|
||||||
|
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
|
||||||
|
config.model_args.speaker_embedding_channels = 128
|
||||||
|
config.model_args.d_vector_dim = 256
|
||||||
|
|
||||||
|
# prosody embedding
|
||||||
|
config.model_args.use_prosody_encoder = True
|
||||||
|
config.model_args.prosody_embedding_dim = 256
|
||||||
|
|
||||||
|
config.save_json(config_path)
|
||||||
|
|
||||||
|
# train the model for one epoch
|
||||||
|
command_train = (
|
||||||
|
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
|
||||||
|
f"--coqpit.output_path {output_path} "
|
||||||
|
"--coqpit.datasets.0.name ljspeech_test "
|
||||||
|
"--coqpit.datasets.0.meta_file_train metadata.csv "
|
||||||
|
"--coqpit.datasets.0.meta_file_val metadata.csv "
|
||||||
|
"--coqpit.datasets.0.path tests/data/ljspeech "
|
||||||
|
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
|
||||||
|
"--coqpit.test_delay_epochs 0"
|
||||||
|
)
|
||||||
|
run_cli(command_train)
|
||||||
|
|
||||||
|
# Find latest folder
|
||||||
|
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
|
||||||
|
|
||||||
|
# Inference using TTS API
|
||||||
|
continue_config_path = os.path.join(continue_path, "config.json")
|
||||||
|
continue_restore_path, _ = get_last_checkpoint(continue_path)
|
||||||
|
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||||
|
speaker_id = "ljspeech-1"
|
||||||
|
emotion_id = "ljspeech-3"
|
||||||
|
continue_speakers_path = os.path.join(continue_path, "speakers.json")
|
||||||
|
continue_emotion_path = os.path.join(continue_path, "speakers.json")
|
||||||
|
|
||||||
|
|
||||||
|
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
|
||||||
|
run_cli(inference_command)
|
||||||
|
|
||||||
|
# restore the model and continue training for one more epoch
|
||||||
|
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
|
||||||
|
run_cli(command_train)
|
||||||
|
shutil.rmtree(continue_path)
|
Loading…
Reference in New Issue