Add prosody encoder inference support

This commit is contained in:
Edresson Casanova 2022-05-27 16:00:41 -03:00
parent 2568b722dd
commit bd35371944
5 changed files with 81 additions and 11 deletions

View File

@ -179,6 +179,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
default=None,
)
parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None)
parser.add_argument(
"--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
)
@ -325,6 +326,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx,
emotion_name=args.emotion_idx,
style_speaker_name=args.style_speaker_name,
)
# save the results

View File

@ -981,7 +981,7 @@ class Vits(BaseTTS):
@staticmethod
def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid, eid, eg, pf = None, None, None, None, None, None
sid, g, lid, eid, eg, pf, ssid, ssg = None, None, None, None, None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"]
if sid.ndim == 0:
@ -1010,7 +1010,18 @@ class Vits(BaseTTS):
pf = aux_input["style_feature"]
if pf.ndim == 2:
pf = pf.unsqueeze_(0)
return sid, g, lid, eid, eg, pf
if "style_speaker_id" in aux_input and aux_input["style_speaker_id"] is not None:
ssid = aux_input["style_speaker_id"]
if ssid.ndim == 0:
ssid = ssid.unsqueeze_(0)
if "style_speaker_d_vector" in aux_input and aux_input["style_speaker_d_vector"] is not None:
ssg = F.normalize(aux_input["style_speaker_d_vector"]).unsqueeze(-1)
if ssg.ndim == 2:
ssg = ssg.unsqueeze_(0)
return sid, g, lid, eid, eg, pf, ssid, ssg
def _set_speaker_input(self, aux_input: Dict):
d_vectors = aux_input.get("d_vectors", None)
@ -1130,7 +1141,7 @@ class Vits(BaseTTS):
- syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
"""
outputs = {}
sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input)
sid, g, lid, eid, eg, _, _, _ = self._set_cond_input(aux_input)
# speaker embedding
if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@ -1317,7 +1328,7 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]`
"""
sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input)
# speaker embedding
@ -1336,13 +1347,17 @@ class Vits(BaseTTS):
# prosody embedding
pros_emb = None
if self.args.use_prosody_encoder:
# speaker embedding for the style speaker
if self.args.use_speaker_embedding and ssid is not None:
ssg = self.emb_g(ssid).unsqueeze(-1)
# extract posterior encoder feature
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
if not self.args.use_prosody_encoder_z_p_input:
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
else:
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
pros_emb = pros_emb.transpose(1, 2)
@ -1687,7 +1702,7 @@ class Vits(BaseTTS):
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
@ -1700,23 +1715,37 @@ class Vits(BaseTTS):
text, speaker_name, style_wav, language_name = sentence_info
elif len(sentence_info) == 5:
text, speaker_name, style_wav, language_name, emotion_name = sentence_info
elif len(sentence_info) == 6:
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = sentence_info
else:
text = sentence_info
if style_wav and style_speaker_name is None:
raise RuntimeError(
f" [!] You must to provide the style_speaker_name for the style_wav !!"
)
# get speaker id/d_vector
speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
if style_wav is not None:
style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.ids[speaker_name]
if style_wav is not None:
style_speaker_id = self.speaker_manager.ids[style_speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.ids[language_name]
@ -1740,6 +1769,8 @@ class Vits(BaseTTS):
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"style_speaker_id": style_speaker_id,
"style_speaker_d_vector": style_speaker_d_vector,
"d_vector": d_vector,
"language_id": language_id,
"language_name": language_name,
@ -1773,6 +1804,8 @@ class Vits(BaseTTS):
language_id=aux_inputs["language_id"],
emotion_embedding=aux_inputs["emotion_embedding"],
emotion_id=aux_inputs["emotion_ids"],
style_speaker_id=aux_inputs["style_speaker_id"],
style_speaker_d_vector=aux_inputs["style_speaker_d_vector"],
use_griffin_lim=True,
do_trim_silence=False,
).values()

View File

@ -31,6 +31,8 @@ def run_model_torch(
language_id: torch.Tensor = None,
emotion_id: torch.Tensor = None,
emotion_embedding: torch.Tensor = None,
style_speaker_id: torch.Tensor = None,
style_speaker_d_vector: torch.Tensor = None,
) -> Dict:
"""Run a torch model for inference. It does not support batch inference.
@ -60,6 +62,8 @@ def run_model_torch(
"language_ids": language_id,
"emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding,
"style_speaker_id": style_speaker_id,
"style_speaker_d_vector": style_speaker_d_vector,
},
)
return outputs
@ -128,6 +132,8 @@ def synthesis(
language_id=None,
emotion_id=None,
emotion_embedding=None,
style_speaker_id=None,
style_speaker_d_vector=None,
):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model.
@ -205,6 +211,13 @@ def synthesis(
if emotion_embedding is not None:
emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
if style_speaker_id is not None:
style_speaker_id = id_to_torch(style_speaker_id, cuda=use_cuda)
if style_speaker_d_vector is not None:
style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
if not isinstance(style_feature, dict):
# GST or Capacitron style mel
style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
@ -229,6 +242,8 @@ def synthesis(
language_id=language_id,
emotion_id=emotion_id,
emotion_embedding=emotion_embedding,
style_speaker_id=style_speaker_id,
style_speaker_d_vector=style_speaker_d_vector,
)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()

View File

@ -216,6 +216,7 @@ class Synthesizer(object):
emotion_name=None,
source_emotion=None,
target_emotion=None,
style_speaker_name=None,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
@ -247,6 +248,8 @@ class Synthesizer(object):
# handle multi-speaker
speaker_embedding = None
speaker_id = None
style_speaker_id = None
style_speaker_embedding = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file:
@ -255,10 +258,20 @@ class Synthesizer(object):
speaker_name, num_samples=None, randomize=False
)
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
if style_speaker_name is not None:
style_speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
style_speaker_name, num_samples=None, randomize=False
)
style_speaker_embedding = np.array(style_speaker_embedding)[None, :] # [1 x embedding_dim]
else:
# get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
if style_speaker_name is not None:
style_speaker_id = self.tts_model.speaker_manager.ids[style_speaker_name]
elif not speaker_name and not speaker_wav:
raise ValueError(
" [!] Look like you use a multi-speaker model. "
@ -327,6 +340,9 @@ class Synthesizer(object):
if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
if style_wav is not None and style_speaker_name is None:
style_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(style_wav)
use_gl = self.vocoder_model is None
if not reference_wav:
@ -340,6 +356,8 @@ class Synthesizer(object):
speaker_id=speaker_id,
style_wav=style_wav,
style_text=style_text,
style_speaker_id=style_speaker_id,
style_speaker_d_vector=style_speaker_embedding,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
language_id=language_id,

View File

@ -26,7 +26,7 @@ config = VitsConfig(
print_step=1,
print_eval=True,
test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
],
)
# set audio config
@ -46,7 +46,7 @@ config.model_args.prosody_embedding_dim = 64
# active classifier
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.model_args.use_prosody_enc_emo_classifier = False
config.model_args.use_text_enc_emo_classifier = True
config.model_args.use_text_enc_emo_classifier = False
config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "vae"
@ -75,11 +75,13 @@ continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
style_speaker_name = "ljspeech-2"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path}"
print("Testing inference !")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path} --style_speaker_name {style_speaker_name}"
run_cli(inference_command)
# restore the model and continue training for one more epoch