Add prosody encoder inference support

This commit is contained in:
Edresson Casanova 2022-05-27 16:00:41 -03:00
parent 010f847929
commit a822f21b78
5 changed files with 81 additions and 11 deletions

View File

@ -179,6 +179,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
default=None, default=None,
) )
parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None) parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None)
parser.add_argument( parser.add_argument(
"--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None "--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
) )
@ -325,6 +326,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
style_text=args.capacitron_style_text, style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx, reference_speaker_name=args.reference_speaker_idx,
emotion_name=args.emotion_idx, emotion_name=args.emotion_idx,
style_speaker_name=args.style_speaker_name,
) )
# save the results # save the results

View File

@ -981,7 +981,7 @@ class Vits(BaseTTS):
@staticmethod @staticmethod
def _set_cond_input(aux_input: Dict): def _set_cond_input(aux_input: Dict):
"""Set the speaker conditioning input based on the multi-speaker mode.""" """Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid, eid, eg, pf = None, None, None, None, None, None sid, g, lid, eid, eg, pf, ssid, ssg = None, None, None, None, None, None, None, None
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
sid = aux_input["speaker_ids"] sid = aux_input["speaker_ids"]
if sid.ndim == 0: if sid.ndim == 0:
@ -1010,7 +1010,18 @@ class Vits(BaseTTS):
pf = aux_input["style_feature"] pf = aux_input["style_feature"]
if pf.ndim == 2: if pf.ndim == 2:
pf = pf.unsqueeze_(0) pf = pf.unsqueeze_(0)
return sid, g, lid, eid, eg, pf
if "style_speaker_id" in aux_input and aux_input["style_speaker_id"] is not None:
ssid = aux_input["style_speaker_id"]
if ssid.ndim == 0:
ssid = ssid.unsqueeze_(0)
if "style_speaker_d_vector" in aux_input and aux_input["style_speaker_d_vector"] is not None:
ssg = F.normalize(aux_input["style_speaker_d_vector"]).unsqueeze(-1)
if ssg.ndim == 2:
ssg = ssg.unsqueeze_(0)
return sid, g, lid, eid, eg, pf, ssid, ssg
def _set_speaker_input(self, aux_input: Dict): def _set_speaker_input(self, aux_input: Dict):
d_vectors = aux_input.get("d_vectors", None) d_vectors = aux_input.get("d_vectors", None)
@ -1130,7 +1141,7 @@ class Vits(BaseTTS):
- syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]` - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
""" """
outputs = {} outputs = {}
sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input) sid, g, lid, eid, eg, _, _, _ = self._set_cond_input(aux_input)
# speaker embedding # speaker embedding
if self.args.use_speaker_embedding and sid is not None: if self.args.use_speaker_embedding and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@ -1317,7 +1328,7 @@ class Vits(BaseTTS):
- m_p: :math:`[B, C, T_dec]` - m_p: :math:`[B, C, T_dec]`
- logs_p: :math:`[B, C, T_dec]` - logs_p: :math:`[B, C, T_dec]`
""" """
sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input) sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
x_lengths = self._set_x_lengths(x, aux_input) x_lengths = self._set_x_lengths(x, aux_input)
# speaker embedding # speaker embedding
@ -1336,13 +1347,17 @@ class Vits(BaseTTS):
# prosody embedding # prosody embedding
pros_emb = None pros_emb = None
if self.args.use_prosody_encoder: if self.args.use_prosody_encoder:
# speaker embedding for the style speaker
if self.args.use_speaker_embedding and ssid is not None:
ssg = self.emb_g(ssid).unsqueeze(-1)
# extract posterior encoder feature # extract posterior encoder feature
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device) pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g) z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
if not self.args.use_prosody_encoder_z_p_input: if not self.args.use_prosody_encoder_z_p_input:
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths) pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
else: else:
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g) z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths) pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
pros_emb = pros_emb.transpose(1, 2) pros_emb = pros_emb.transpose(1, 2)
@ -1687,7 +1702,7 @@ class Vits(BaseTTS):
config = self.config config = self.config
# extract speaker and language info # extract speaker and language info
text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
if isinstance(sentence_info, list): if isinstance(sentence_info, list):
if len(sentence_info) == 1: if len(sentence_info) == 1:
@ -1700,23 +1715,37 @@ class Vits(BaseTTS):
text, speaker_name, style_wav, language_name = sentence_info text, speaker_name, style_wav, language_name = sentence_info
elif len(sentence_info) == 5: elif len(sentence_info) == 5:
text, speaker_name, style_wav, language_name, emotion_name = sentence_info text, speaker_name, style_wav, language_name, emotion_name = sentence_info
elif len(sentence_info) == 6:
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = sentence_info
else: else:
text = sentence_info text = sentence_info
if style_wav and style_speaker_name is None:
raise RuntimeError(
f" [!] You must to provide the style_speaker_name for the style_wav !!"
)
# get speaker id/d_vector # get speaker id/d_vector
speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
if hasattr(self, "speaker_manager"): if hasattr(self, "speaker_manager"):
if config.use_d_vector_file: if config.use_d_vector_file:
if speaker_name is None: if speaker_name is None:
d_vector = self.speaker_manager.get_random_embeddings() d_vector = self.speaker_manager.get_random_embeddings()
else: else:
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
if style_wav is not None:
style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
elif config.use_speaker_embedding: elif config.use_speaker_embedding:
if speaker_name is None: if speaker_name is None:
speaker_id = self.speaker_manager.get_random_id() speaker_id = self.speaker_manager.get_random_id()
else: else:
speaker_id = self.speaker_manager.ids[speaker_name] speaker_id = self.speaker_manager.ids[speaker_name]
if style_wav is not None:
style_speaker_id = self.speaker_manager.ids[style_speaker_name]
# get language id # get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.ids[language_name] language_id = self.language_manager.ids[language_name]
@ -1740,6 +1769,8 @@ class Vits(BaseTTS):
"text": text, "text": text,
"speaker_id": speaker_id, "speaker_id": speaker_id,
"style_wav": style_wav, "style_wav": style_wav,
"style_speaker_id": style_speaker_id,
"style_speaker_d_vector": style_speaker_d_vector,
"d_vector": d_vector, "d_vector": d_vector,
"language_id": language_id, "language_id": language_id,
"language_name": language_name, "language_name": language_name,
@ -1773,6 +1804,8 @@ class Vits(BaseTTS):
language_id=aux_inputs["language_id"], language_id=aux_inputs["language_id"],
emotion_embedding=aux_inputs["emotion_embedding"], emotion_embedding=aux_inputs["emotion_embedding"],
emotion_id=aux_inputs["emotion_ids"], emotion_id=aux_inputs["emotion_ids"],
style_speaker_id=aux_inputs["style_speaker_id"],
style_speaker_d_vector=aux_inputs["style_speaker_d_vector"],
use_griffin_lim=True, use_griffin_lim=True,
do_trim_silence=False, do_trim_silence=False,
).values() ).values()

View File

@ -31,6 +31,8 @@ def run_model_torch(
language_id: torch.Tensor = None, language_id: torch.Tensor = None,
emotion_id: torch.Tensor = None, emotion_id: torch.Tensor = None,
emotion_embedding: torch.Tensor = None, emotion_embedding: torch.Tensor = None,
style_speaker_id: torch.Tensor = None,
style_speaker_d_vector: torch.Tensor = None,
) -> Dict: ) -> Dict:
"""Run a torch model for inference. It does not support batch inference. """Run a torch model for inference. It does not support batch inference.
@ -60,6 +62,8 @@ def run_model_torch(
"language_ids": language_id, "language_ids": language_id,
"emotion_ids": emotion_id, "emotion_ids": emotion_id,
"emotion_embeddings": emotion_embedding, "emotion_embeddings": emotion_embedding,
"style_speaker_id": style_speaker_id,
"style_speaker_d_vector": style_speaker_d_vector,
}, },
) )
return outputs return outputs
@ -128,6 +132,8 @@ def synthesis(
language_id=None, language_id=None,
emotion_id=None, emotion_id=None,
emotion_embedding=None, emotion_embedding=None,
style_speaker_id=None,
style_speaker_d_vector=None,
): ):
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
the vocoder model. the vocoder model.
@ -205,6 +211,13 @@ def synthesis(
if emotion_embedding is not None: if emotion_embedding is not None:
emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda) emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
if style_speaker_id is not None:
style_speaker_id = id_to_torch(style_speaker_id, cuda=use_cuda)
if style_speaker_d_vector is not None:
style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
if not isinstance(style_feature, dict): if not isinstance(style_feature, dict):
# GST or Capacitron style mel # GST or Capacitron style mel
style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda) style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
@ -229,6 +242,8 @@ def synthesis(
language_id=language_id, language_id=language_id,
emotion_id=emotion_id, emotion_id=emotion_id,
emotion_embedding=emotion_embedding, emotion_embedding=emotion_embedding,
style_speaker_id=style_speaker_id,
style_speaker_d_vector=style_speaker_d_vector,
) )
model_outputs = outputs["model_outputs"] model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy() model_outputs = model_outputs[0].data.cpu().numpy()

View File

@ -216,6 +216,7 @@ class Synthesizer(object):
emotion_name=None, emotion_name=None,
source_emotion=None, source_emotion=None,
target_emotion=None, target_emotion=None,
style_speaker_name=None,
) -> List[int]: ) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech. """🐸 TTS magic. Run all the models and generate speech.
@ -247,6 +248,8 @@ class Synthesizer(object):
# handle multi-speaker # handle multi-speaker
speaker_embedding = None speaker_embedding = None
speaker_id = None speaker_id = None
style_speaker_id = None
style_speaker_embedding = None
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"): if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
if speaker_name and isinstance(speaker_name, str): if speaker_name and isinstance(speaker_name, str):
if self.tts_config.use_d_vector_file: if self.tts_config.use_d_vector_file:
@ -255,10 +258,20 @@ class Synthesizer(object):
speaker_name, num_samples=None, randomize=False speaker_name, num_samples=None, randomize=False
) )
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim] speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
if style_speaker_name is not None:
style_speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
style_speaker_name, num_samples=None, randomize=False
)
style_speaker_embedding = np.array(style_speaker_embedding)[None, :] # [1 x embedding_dim]
else: else:
# get speaker idx from the speaker name # get speaker idx from the speaker name
speaker_id = self.tts_model.speaker_manager.ids[speaker_name] speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
if style_speaker_name is not None:
style_speaker_id = self.tts_model.speaker_manager.ids[style_speaker_name]
elif not speaker_name and not speaker_wav: elif not speaker_name and not speaker_wav:
raise ValueError( raise ValueError(
" [!] Look like you use a multi-speaker model. " " [!] Look like you use a multi-speaker model. "
@ -327,6 +340,9 @@ class Synthesizer(object):
if speaker_wav is not None: if speaker_wav is not None:
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
if style_wav is not None and style_speaker_name is None:
style_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(style_wav)
use_gl = self.vocoder_model is None use_gl = self.vocoder_model is None
if not reference_wav: if not reference_wav:
@ -340,6 +356,8 @@ class Synthesizer(object):
speaker_id=speaker_id, speaker_id=speaker_id,
style_wav=style_wav, style_wav=style_wav,
style_text=style_text, style_text=style_text,
style_speaker_id=style_speaker_id,
style_speaker_d_vector=style_speaker_embedding,
use_griffin_lim=use_gl, use_griffin_lim=use_gl,
d_vector=speaker_embedding, d_vector=speaker_embedding,
language_id=language_id, language_id=language_id,

View File

@ -26,7 +26,7 @@ config = VitsConfig(
print_step=1, print_step=1,
print_eval=True, print_eval=True,
test_sentences=[ test_sentences=[
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None], ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
], ],
) )
# set audio config # set audio config
@ -46,7 +46,7 @@ config.model_args.prosody_embedding_dim = 64
# active classifier # active classifier
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
config.model_args.use_prosody_enc_emo_classifier = False config.model_args.use_prosody_enc_emo_classifier = False
config.model_args.use_text_enc_emo_classifier = True config.model_args.use_text_enc_emo_classifier = False
config.model_args.use_prosody_encoder_z_p_input = True config.model_args.use_prosody_encoder_z_p_input = True
config.model_args.prosody_encoder_type = "vae" config.model_args.prosody_encoder_type = "vae"
@ -75,11 +75,13 @@ continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path) continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav") out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1" speaker_id = "ljspeech-1"
style_speaker_name = "ljspeech-2"
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav" style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
continue_speakers_path = os.path.join(continue_path, "speakers.json") continue_speakers_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path}" print("Testing inference !")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path} --style_speaker_name {style_speaker_name}"
run_cli(inference_command) run_cli(inference_command)
# restore the model and continue training for one more epoch # restore the model and continue training for one more epoch