mirror of https://github.com/coqui-ai/TTS.git
Add prosody encoder inference support
parent 010f847929
commit a822f21b78
@@ -179,6 +179,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         default=None,
     )
     parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
+    parser.add_argument("--style_speaker_name", type=str, help="The speaker name for the style_wav. If not provided, the speaker embedding will be computed using the speaker encoder.", default=None)
     parser.add_argument(
         "--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
     )
@@ -325,6 +326,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
         style_text=args.capacitron_style_text,
         reference_speaker_name=args.reference_speaker_idx,
         emotion_name=args.emotion_idx,
+        style_speaker_name=args.style_speaker_name,
     )

     # save the results
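These two hunks expose the new option on the tts CLI and pass it through to the synthesizer. A minimal invocation sketch (checkpoint, config, and speakers-file paths are placeholders; the remaining flags already exist and mirror the test command at the bottom of this diff):

    tts --text "This is an example." \
        --model_path checkpoint.pth \
        --config_path config.json \
        --speakers_file_path speakers.json \
        --speaker_idx ljspeech-1 \
        --style_wav tests/data/ljspeech/wavs/LJ001-0001.wav \
        --style_speaker_name ljspeech-2 \
        --out_path output.wav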
@@ -981,7 +981,7 @@ class Vits(BaseTTS):
     @staticmethod
     def _set_cond_input(aux_input: Dict):
         """Set the speaker conditioning input based on the multi-speaker mode."""
-        sid, g, lid, eid, eg, pf = None, None, None, None, None, None
+        sid, g, lid, eid, eg, pf, ssid, ssg = None, None, None, None, None, None, None, None
         if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
             sid = aux_input["speaker_ids"]
             if sid.ndim == 0:
@@ -1010,7 +1010,18 @@
             pf = aux_input["style_feature"]
             if pf.ndim == 2:
                 pf = pf.unsqueeze_(0)
-        return sid, g, lid, eid, eg, pf
+
+        if "style_speaker_id" in aux_input and aux_input["style_speaker_id"] is not None:
+            ssid = aux_input["style_speaker_id"]
+            if ssid.ndim == 0:
+                ssid = ssid.unsqueeze_(0)
+
+        if "style_speaker_d_vector" in aux_input and aux_input["style_speaker_d_vector"] is not None:
+            ssg = F.normalize(aux_input["style_speaker_d_vector"]).unsqueeze(-1)
+            if ssg.ndim == 2:
+                ssg = ssg.unsqueeze_(0)
+
+        return sid, g, lid, eid, eg, pf, ssid, ssg

     def _set_speaker_input(self, aux_input: Dict):
         d_vectors = aux_input.get("d_vectors", None)
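For context, a minimal sketch of how the extended helper consumes the new keys (hypothetical tensors and shapes; the import path is assumed to be the usual Vits module):

import torch

from TTS.tts.models.vits import Vits  # assumed import path

aux_input = {
    "speaker_ids": torch.tensor(3),             # target speaker id, scalar -> [1]
    "style_feature": torch.rand(80, 120),       # prosody reference spectrogram [C, T] -> [1, C, T]
    "style_speaker_id": torch.tensor(7),        # id of the speaker heard in the style wav, scalar -> [1]
}
sid, g, lid, eid, eg, pf, ssid, ssg = Vits._set_cond_input(aux_input)
# sid: [1], pf: [1, 80, 120], ssid: [1]; g, lid, eid, eg and ssg stay None here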
@@ -1130,7 +1141,7 @@
             - syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
         """
         outputs = {}
-        sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, _, _, _ = self._set_cond_input(aux_input)
         # speaker embedding
         if self.args.use_speaker_embedding and sid is not None:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
@@ -1317,7 +1328,7 @@
             - m_p: :math:`[B, C, T_dec]`
             - logs_p: :math:`[B, C, T_dec]`
         """
-        sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
+        sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
         x_lengths = self._set_x_lengths(x, aux_input)

         # speaker embedding
@@ -1336,13 +1347,17 @@
         # prosody embedding
         pros_emb = None
         if self.args.use_prosody_encoder:
+            # speaker embedding for the style speaker
+            if self.args.use_speaker_embedding and ssid is not None:
+                ssg = self.emb_g(ssid).unsqueeze(-1)
+
             # extract posterior encoder feature
             pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
-            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
+            z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
             if not self.args.use_prosody_encoder_z_p_input:
                 pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
             else:
-                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
+                z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
                 pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)

             pros_emb = pros_emb.transpose(1, 2)
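The substantive change in this hunk: the prosody reference is now encoded (and, when use_prosody_encoder_z_p_input is set, passed through the flow) with the style speaker's embedding ssg instead of the target speaker's g, so the target identity no longer leaks into the extracted prosody embedding. A tiny, hypothetical illustration of how ssg is formed from ssid (dimensions are made up):

import torch
import torch.nn as nn

emb_g = nn.Embedding(10, 256)    # stand-in for self.emb_g, the speaker embedding table
ssid = torch.tensor([7])         # style speaker id, already reshaped to [1]
ssg = emb_g(ssid).unsqueeze(-1)  # [1, 256, 1]; passed as g=ssg to posterior_encoder and flow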
@@ -1687,7 +1702,7 @@
         config = self.config

         # extract speaker and language info
-        text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None
+        text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None

         if isinstance(sentence_info, list):
             if len(sentence_info) == 1:
@@ -1700,23 +1715,37 @@
                 text, speaker_name, style_wav, language_name = sentence_info
             elif len(sentence_info) == 5:
                 text, speaker_name, style_wav, language_name, emotion_name = sentence_info
+            elif len(sentence_info) == 6:
+                text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = sentence_info
         else:
             text = sentence_info

+        if style_wav and style_speaker_name is None:
+            raise RuntimeError(
+                " [!] You must provide the style_speaker_name for the style_wav!"
+            )
+
         # get speaker id/d_vector
-        speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
+        speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
         if hasattr(self, "speaker_manager"):
             if config.use_d_vector_file:
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embeddings()
                 else:
                     d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
+
+                if style_wav is not None:
+                    style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
+
             elif config.use_speaker_embedding:
                 if speaker_name is None:
                     speaker_id = self.speaker_manager.get_random_id()
                 else:
                     speaker_id = self.speaker_manager.ids[speaker_name]
+
+                if style_wav is not None:
+                    style_speaker_id = self.speaker_manager.ids[style_speaker_name]

         # get language id
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.ids[language_name]
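A six-element test sentence as parsed above (this exact entry appears in the test config at the end of the diff):

sentence_info = [
    "Be a voice, not an echo.",                 # text
    "ljspeech-1",                               # target speaker
    "tests/data/ljspeech/wavs/LJ001-0001.wav",  # style_wav (prosody reference)
    None,                                       # language
    None,                                       # emotion
    "ljspeech-2",                               # speaker heard in style_wav
]
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = sentence_info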
@@ -1740,6 +1769,8 @@
             "text": text,
             "speaker_id": speaker_id,
             "style_wav": style_wav,
+            "style_speaker_id": style_speaker_id,
+            "style_speaker_d_vector": style_speaker_d_vector,
             "d_vector": d_vector,
             "language_id": language_id,
             "language_name": language_name,
@@ -1773,6 +1804,8 @@
             language_id=aux_inputs["language_id"],
             emotion_embedding=aux_inputs["emotion_embedding"],
             emotion_id=aux_inputs["emotion_ids"],
+            style_speaker_id=aux_inputs["style_speaker_id"],
+            style_speaker_d_vector=aux_inputs["style_speaker_d_vector"],
             use_griffin_lim=True,
             do_trim_silence=False,
         ).values()
@@ -31,6 +31,8 @@ def run_model_torch(
     language_id: torch.Tensor = None,
     emotion_id: torch.Tensor = None,
     emotion_embedding: torch.Tensor = None,
+    style_speaker_id: torch.Tensor = None,
+    style_speaker_d_vector: torch.Tensor = None,
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.

@@ -60,6 +62,8 @@ def run_model_torch(
             "language_ids": language_id,
             "emotion_ids": emotion_id,
             "emotion_embeddings": emotion_embedding,
+            "style_speaker_id": style_speaker_id,
+            "style_speaker_d_vector": style_speaker_d_vector,
         },
     )
     return outputs
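The two run_model_torch hunks are plumbing: the new arguments are forwarded untouched inside the aux_input dict handed to model.inference. A sketch of the relevant entries (any other keys of the real dict, such as speaker ids or d-vectors, are outside this hunk and therefore assumptions):

import torch

aux_input = {
    "language_ids": None,
    "emotion_ids": None,
    "emotion_embeddings": None,
    "style_speaker_id": torch.tensor([1]),  # new in this commit
    "style_speaker_d_vector": None,         # new in this commit
}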
@@ -128,6 +132,8 @@ def synthesis(
     language_id=None,
     emotion_id=None,
     emotion_embedding=None,
+    style_speaker_id=None,
+    style_speaker_d_vector=None,
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.
@@ -205,6 +211,13 @@ def synthesis(
     if emotion_embedding is not None:
         emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
+
+    if style_speaker_id is not None:
+        style_speaker_id = id_to_torch(style_speaker_id, cuda=use_cuda)
+
+    if style_speaker_d_vector is not None:
+        style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
+

     if not isinstance(style_feature, dict):
         # GST or Capacitron style mel
         style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
@@ -229,6 +242,8 @@ def synthesis(
         language_id=language_id,
         emotion_id=emotion_id,
         emotion_embedding=emotion_embedding,
+        style_speaker_id=style_speaker_id,
+        style_speaker_d_vector=style_speaker_d_vector,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()
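Taken together, synthesis() now converts the new inputs to tensors (id_to_torch / embedding_to_torch) and forwards them to run_model_torch. A hedged sketch of a direct call (model and config are assumed to be an already-loaded prosody-encoder VITS checkpoint and its config; apart from the two style_speaker_* keywords, the argument names follow the existing signature and the call sites shown elsewhere in this diff):

from TTS.tts.utils.synthesis import synthesis  # assumed import path

outputs = synthesis(
    model,
    "Be a voice, not an echo.",
    config,
    use_cuda=False,
    speaker_id=0,
    style_wav="tests/data/ljspeech/wavs/LJ001-0001.wav",
    style_speaker_id=1,            # new: id of the speaker in style_wav
    style_speaker_d_vector=None,   # new: or their d-vector when use_d_vector_file is set
    use_griffin_lim=True,
    do_trim_silence=False,
)
wav = outputs["wav"]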
@@ -216,6 +216,7 @@ class Synthesizer(object):
         emotion_name=None,
         source_emotion=None,
         target_emotion=None,
+        style_speaker_name=None,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.

@@ -247,6 +248,8 @@
         # handle multi-speaker
         speaker_embedding = None
         speaker_id = None
+        style_speaker_id = None
+        style_speaker_embedding = None
         if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
             if speaker_name and isinstance(speaker_name, str):
                 if self.tts_config.use_d_vector_file:
@@ -255,10 +258,20 @@
                         speaker_name, num_samples=None, randomize=False
                     )
                     speaker_embedding = np.array(speaker_embedding)[None, :]  # [1 x embedding_dim]
+
+                    if style_speaker_name is not None:
+                        style_speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
+                            style_speaker_name, num_samples=None, randomize=False
+                        )
+                        style_speaker_embedding = np.array(style_speaker_embedding)[None, :]  # [1 x embedding_dim]
+
                 else:
                     # get speaker idx from the speaker name
                     speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
+
+                    if style_speaker_name is not None:
+                        style_speaker_id = self.tts_model.speaker_manager.ids[style_speaker_name]

             elif not speaker_name and not speaker_wav:
                 raise ValueError(
                     " [!] Look like you use a multi-speaker model. "
@@ -327,6 +340,9 @@
         if speaker_wav is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

+        if style_wav is not None and style_speaker_name is None:
+            style_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(style_wav)
+
         use_gl = self.vocoder_model is None

         if not reference_wav:
@@ -340,6 +356,8 @@
                     speaker_id=speaker_id,
                     style_wav=style_wav,
                     style_text=style_text,
+                    style_speaker_id=style_speaker_id,
+                    style_speaker_d_vector=style_speaker_embedding,
                     use_griffin_lim=use_gl,
                     d_vector=speaker_embedding,
                     language_id=language_id,
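At the top level, Synthesizer.tts() now accepts style_speaker_name, resolves it to a speaker id or mean d-vector (or, if no name is given, computes an embedding directly from style_wav with the speaker encoder), and passes the result down to synthesis(). A usage sketch (constructor paths are placeholders; only the style_speaker_name keyword is introduced by this commit):

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="path/to/model.pth",
    tts_config_path="path/to/config.json",
    tts_speakers_file="path/to/speakers.json",
    use_cuda=False,
)
wav = synthesizer.tts(
    "Be a voice, not an echo.",
    speaker_name="ljspeech-1",
    style_wav="tests/data/ljspeech/wavs/LJ001-0001.wav",
    style_speaker_name="ljspeech-2",  # new: the speaker heard in style_wav
)
synthesizer.save_wav(wav, "output.wav")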
@@ -26,7 +26,7 @@ config = VitsConfig(
     print_step=1,
     print_eval=True,
     test_sentences=[
-        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
+        ["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
     ],
 )
 # set audio config
@@ -46,7 +46,7 @@ config.model_args.prosody_embedding_dim = 64
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_prosody_enc_emo_classifier = False
-config.model_args.use_text_enc_emo_classifier = True
+config.model_args.use_text_enc_emo_classifier = False
 config.model_args.use_prosody_encoder_z_p_input = True

 config.model_args.prosody_encoder_type = "vae"
@@ -75,11 +75,13 @@ continue_config_path = os.path.join(continue_path, "config.json")
 continue_restore_path, _ = get_last_checkpoint(continue_path)
 out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
+style_speaker_name = "ljspeech-2"
 style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")


-inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path}"
+print("Testing inference!")
+inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path} --style_speaker_name {style_speaker_name}"
 run_cli(inference_command)

 # restore the model and continue training for one more epoch