Add prosody encoder params to config

Edresson Casanova 2022-05-16 09:45:28 -03:00
parent 5271846d9c
commit 3a524b0597
8 changed files with 90 additions and 45 deletions

View File

@@ -178,7 +178,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.",
default=None,
)
parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None)
parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
parser.add_argument(
"--list_speaker_idxs",
help="List available speaker ids for the defined multi-speaker model.",
@@ -317,6 +317,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
args.speaker_idx,
args.language_idx,
args.speaker_wav,
style_wav=args.gst_style,
reference_wav=args.reference_wav,
reference_speaker_name=args.reference_speaker_idx,
emotion_name=args.emotion_idx,

View File

@@ -117,7 +117,7 @@ def load_tts_samples(
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size)
meta_data_eval_all += meta_data_eval
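For reference, a minimal self-contained sketch of what the merge above does to each loaded sample; the formatter fields and values are only illustrative:

language = "en"
speech_style = "style1"
meta_data_eval = [
    {"text": "Hello.", "audio_file": "wavs/LJ001-0001.wav", "speaker_name": "ljspeech"},
]

# same dict merge as in load_tts_samples: every sample is tagged with the
# dataset-level language and the new speech_style label
meta_data_eval = [{**item, **{"language": language, "speech_style": speech_style}} for item in meta_data_eval]

print(meta_data_eval[0]["speech_style"])  # "style1"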

View File

@@ -189,7 +189,7 @@ class Tacotron(BaseTacotron):
encoder_outputs = self.encoder(inputs)
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
if self.num_speakers > 1:
if not self.use_d_vector_file:
# B x 1 x speaker_embed_dim

View File

@@ -215,7 +215,7 @@ class Tacotron2(BaseTacotron):
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_feature"], aux_input["d_vectors"])
if self.num_speakers > 1:
if not self.use_d_vector_file:

View File

@@ -508,6 +508,8 @@ class VitsArgs(Coqpit):
# prosody encoder
use_prosody_encoder: bool = False
prosody_embedding_dim: int = 0
prosody_encoder_num_heads: int = 1
prosody_encoder_num_tokens: int = 5
detach_dp_input: bool = True
use_language_embedding: bool = False
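A minimal sketch of setting the new fields from a training config; the VitsConfig import path follows the usual Coqui TTS layout, and the values are only illustrative for this branch's new VitsArgs fields:

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 64
# new knobs exposed by this commit; they are forwarded to the GST-style
# prosody encoder built in Vits.__init__ (see the next hunk)
config.model_args.prosody_encoder_num_heads = 1
config.model_args.prosody_encoder_num_tokens = 5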
@@ -642,8 +644,8 @@ class Vits(BaseTTS):
if self.args.use_prosody_encoder:
self.prosody_encoder = GST(
num_mel=self.args.hidden_channels,
num_heads=1,
num_style_tokens=5,
num_heads=self.args.prosody_encoder_num_heads,
num_style_tokens=self.args.prosody_encoder_num_tokens,
gst_embedding_dim=self.args.prosody_embedding_dim,
)
self.speaker_reversal_classifier = ReversalClassifier(
@@ -840,7 +842,7 @@ class Vits(BaseTTS):
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid, eid, eg = None, None, None, None, None
sid, g, lid, eid, eg, pf = None, None, None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
@@ -865,7 +867,11 @@ class Vits(BaseTTS):
if eg.ndim == 2:
eg = eg.unsqueeze_(0)
return sid, g, lid, eid, eg
if "style_feature" in aux_input and aux_input["style_feature"] is not None:
pf = aux_input["style_feature"]
if pf.ndim == 2:
pf = pf.unsqueeze_(0)
return sid, g, lid, eid, eg, pf
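A self-contained sketch of the new style_feature branch, assuming the reference arrives as a 2D [C, T] spectrogram that needs a batch dimension:

import torch

aux_input = {"style_feature": torch.randn(513, 120)}  # e.g. a linear spectrogram [C, T]

pf = None
if "style_feature" in aux_input and aux_input["style_feature"] is not None:
    pf = aux_input["style_feature"]
    if pf.ndim == 2:
        pf = pf.unsqueeze(0)  # -> [1, C, T], a batch with a single reference

print(pf.shape)  # torch.Size([1, 513, 120])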
def _set_speaker_input(self, aux_input: Dict):
d_vectors = aux_input.get("d_vectors", None)
@@ -968,7 +974,7 @@ class Vits(BaseTTS):
- syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
"""
outputs = {}
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input)
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@@ -998,14 +1004,14 @@ class Vits(BaseTTS):
if self.args.use_prosody_encoder:
pros_emb = self.prosody_encoder(z).transpose(1, 2)
_, l_pros_speaker = self.speaker_reversal_classifier(pros_emb.transpose(1, 2), sid, x_mask=None)
# print("Encoder input", x.shape)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
# print("X shape:", x.shape, "m_p shape:", m_p.shape, "x_mask:", x_mask.shape, "x_lengths:", x_lengths.shape)
# flow layers
z_p = self.flow(z, y_mask, g=g)
# print("Y mask:", y_mask.shape)
# duration predictor
g_dp = g
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
if g_dp is None:
g_dp = eg
@@ -1092,6 +1098,7 @@ class Vits(BaseTTS):
"language_ids": None,
"emotion_embeddings": None,
"emotion_ids": None,
"style_feature": None,
},
): # pylint: disable=dangerous-default-value
"""
@@ -1112,7 +1119,7 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]`
"""
sid, g, lid, eid, eg = self._set_cond_input(aux_input)
sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input)
# speaker embedding
@@ -1135,29 +1142,42 @@ class Vits(BaseTTS):
if self.args.use_language_embedding and lid is not None:
lang_emb = self.emb_l(lid).unsqueeze(-1)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg)
# prosody embedding
pros_emb = None
if self.args.use_prosody_encoder:
# extract posterior encoder feature
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
z_pro, _, _, _ = self.posterior_encoder(pf, pf_lengths, g=g)
pros_emb = self.prosody_encoder(z_pro).transpose(1, 2)
x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, emo_emb=eg, pros_emb=pros_emb)
# duration predictor
g_dp = g if self.args.condition_dp_on_speaker else None
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and self.args.emotion_just_encoder:
if g is None:
if g_dp is None:
g_dp = eg
else:
g_dp = torch.cat([g, eg], dim=1) # [b, h1+h2, 1]
else:
g_dp = g
g_dp = torch.cat([g_dp, eg], dim=1) # [b, h1+h2, 1]
if self.args.use_prosody_encoder:
if g_dp is None:
g_dp = pros_emb
else:
g_dp = torch.cat([g_dp, pros_emb], dim=1) # [b, h1+h2, 1]
if self.args.use_sdp:
logw = self.duration_predictor(
x,
x_mask,
g=g_dp if self.args.condition_dp_on_speaker else None,
g=g_dp,
reverse=True,
noise_scale=self.inference_noise_scale_dp,
lang_emb=lang_emb,
)
else:
logw = self.duration_predictor(
x, x_mask, g=g_dp if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
x, x_mask, g=g_dp, lang_emb=lang_emb
)
w = torch.exp(logw) * x_mask * self.length_scale
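To make the conditioning stacking above concrete, here is a minimal sketch of how the duration-predictor input g_dp grows as speaker, emotion, and prosody vectors are concatenated along the channel axis; the dimensions are illustrative:

import torch

g = torch.randn(1, 256, 1)        # speaker embedding [B, H, 1] (condition_dp_on_speaker)
eg = torch.randn(1, 128, 1)       # emotion embedding
pros_emb = torch.randn(1, 64, 1)  # prosody embedding from the GST-style encoder

g_dp = g
for extra in (eg, pros_emb):
    g_dp = extra if g_dp is None else torch.cat([g_dp, extra], dim=1)

print(g_dp.shape)  # torch.Size([1, 448, 1]) -> the duration predictor sees all three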
@@ -1175,9 +1195,22 @@ class Vits(BaseTTS):
z = self.flow(z_p, y_mask, g=g, reverse=True)
o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p}
outputs = {"model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, "durations": w_ceil}
return outputs
def compute_style_feature(self, style_wav_path):
style_wav, sr = torchaudio.load(style_wav_path)
if sr != self.config.audio.sample_rate:
raise RuntimeError(f" [!] Style reference needs to have a sampling rate equal to {self.config.audio.sample_rate} !!")
y = wav_to_spec(
style_wav,
self.config.audio.fft_size,
self.config.audio.hop_length,
self.config.audio.win_length,
center=False,
)
return y
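The method above relies on torchaudio and the project's wav_to_spec helper; a minimal standalone sketch of the same idea, where the helper is approximated with torch.stft (its exact signature in the codebase is an assumption):

import torch
import torchaudio

def wav_to_spec_sketch(wav, fft_size, hop_length, win_length, center=False):
    # linear magnitude spectrogram, standing in for the project's wav_to_spec
    spec = torch.stft(
        wav,
        n_fft=fft_size,
        hop_length=hop_length,
        win_length=win_length,
        window=torch.hann_window(win_length),
        center=center,
        return_complex=True,
    )
    return spec.abs()

style_wav, sr = torchaudio.load("tests/data/ljspeech/wavs/LJ001-0001.wav")
if sr != 22050:  # must match config.audio.sample_rate
    raise RuntimeError(f" [!] Style reference needs a sampling rate of 22050, got {sr}")
style_feature = wav_to_spec_sketch(style_wav, fft_size=1024, hop_length=256, win_length=1024)
print(style_feature.shape)  # [1, fft_size // 2 + 1, num_frames]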
@torch.no_grad()
def inference_voice_conversion(
self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None

View File

@@ -14,18 +14,18 @@ def numpy_to_torch(np_array, dtype, cuda=False):
return tensor
def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
def compute_style_feature(style_wav, ap, cuda=False):
style_feature = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
if cuda:
return style_mel.cuda()
return style_mel
return style_feature.cuda()
return style_feature
def run_model_torch(
model: nn.Module,
inputs: torch.Tensor,
speaker_id: int = None,
style_mel: torch.Tensor = None,
style_feature: torch.Tensor = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
emotion_id: torch.Tensor = None,
@@ -37,7 +37,7 @@ def run_model_torch(
model (nn.Module): The model to run inference.
inputs (torch.Tensor): Input tensor with character ids.
speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None.
style_mel (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None.
style_feature (torch.Tensor, optional): Spectrogram features used for voice styling. Defaults to None.
d_vector (torch.Tensor, optional): d-vector for multi-speaker models . Defaults to None.
Returns:
@@ -54,7 +54,7 @@ def run_model_torch(
"x_lengths": input_lengths,
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
"style_feature": style_feature,
"language_ids": language_id,
"emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding,
@@ -161,12 +161,17 @@ def synthesis(
Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
"""
# GST processing
style_mel = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_feature = None
if style_wav is not None:
if CONFIG.has("gst") and CONFIG.gst:
if isinstance(style_wav, dict):
style_feature = style_wav
else:
style_feature = compute_style_feature(style_wav, model.ap, cuda=use_cuda)
if hasattr(model, 'compute_style_feature'):
style_feature = model.compute_style_feature(style_wav)
# convert text to sequence of token IDs
text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id),
@@ -188,8 +193,8 @@ def synthesis(
if emotion_embedding is not None:
emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
if not isinstance(style_mel, dict):
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
if not isinstance(style_feature, dict):
style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
@@ -197,7 +202,7 @@ def synthesis(
model,
text_inputs,
speaker_id,
style_mel,
style_feature,
d_vector=d_vector,
language_id=language_id,
emotion_id=emotion_id,
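Taken together, the new branching in synthesis() can be summarized with this self-contained sketch; resolve_style_feature and the Fake* classes are hypothetical stand-ins, and the point is that a model-level compute_style_feature (the VITS prosody-encoder path) takes precedence over the ap-based GST feature:

class FakeConfig:
    gst = True
    def has(self, key):
        return hasattr(self, key)

class FakeVits:
    def compute_style_feature(self, style_wav):
        return f"linear-spec({style_wav})"

def compute_style_feature(style_wav, ap, cuda=False):
    # ap-based GST feature, as in the helper renamed above
    return f"mel-spec({style_wav})"

def resolve_style_feature(model, config, style_wav, ap=None, use_cuda=False):
    style_feature = None
    if style_wav is not None:
        if config.has("gst") and config.gst:
            # a dict is treated as explicit GST token weights and passed through
            style_feature = style_wav if isinstance(style_wav, dict) else compute_style_feature(style_wav, ap, cuda=use_cuda)
        if hasattr(model, "compute_style_feature"):
            # model-specific feature extraction wins when available
            style_feature = model.compute_style_feature(style_wav)
    return style_feature

print(resolve_style_feature(FakeVits(), FakeConfig(), "ref.wav"))  # linear-spec(ref.wav)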

View File

@@ -47,7 +47,7 @@ config.model_args.use_emotion_embedding = False
config.model_args.emotion_embedding_dim = 256
config.model_args.emotion_just_encoder = True
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.use_style_weighted_sampler = True
# consistency loss
# config.model_args.use_emotion_encoder_as_loss = True
# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"
@@ -64,6 +64,13 @@ command_train = (
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.datasets.0.speech_style style1 "
"--coqpit.datasets.1.name ljspeech_test "
"--coqpit.datasets.1.meta_file_train metadata.csv "
"--coqpit.datasets.1.meta_file_val metadata.csv "
"--coqpit.datasets.1.path tests/data/ljspeech "
"--coqpit.datasets.1.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.datasets.1.speech_style style2 "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
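As a rough illustration of what the two dataset overrides above produce once samples are loaded, a sketch in plain dicts (standing in for the coqpit dataset configs and the formatter output):

datasets = [
    {"name": "ljspeech_test", "path": "tests/data/ljspeech", "speech_style": "style1"},
    {"name": "ljspeech_test", "path": "tests/data/ljspeech", "speech_style": "style2"},
]

samples = []
for dataset in datasets:
    # one fake formatter item per dataset stands in for load_tts_samples internals
    for item in [{"text": "Hello.", "audio_file": "wavs/LJ001-0001.wav"}]:
        samples.append({**item, "speech_style": dataset["speech_style"]})

print(sorted({s["speech_style"] for s in samples}))  # ['style1', 'style2']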

View File

@@ -26,7 +26,7 @@ config = VitsConfig(
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"],
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
],
)
# set audio config
@@ -38,11 +38,11 @@ config.model_args.use_speaker_embedding = True
config.model_args.use_d_vector_file = False
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
config.model_args.speaker_embedding_channels = 128
config.model_args.d_vector_dim = 256
config.model_args.d_vector_dim = 128
# prosody embedding
config.model_args.use_prosody_encoder = True
config.model_args.prosody_embedding_dim = 256
config.model_args.prosody_embedding_dim = 64
config.save_json(config_path)
@@ -67,12 +67,11 @@ continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
emotion_id = "ljspeech-3"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
continue_emotion_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --emotion_idx {emotion_id} --speakers_file_path {continue_speakers_path} --emotions_file_path {continue_emotion_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --gst_style {style_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch