mirror of https://github.com/coqui-ai/TTS.git
Add prosody encoder inference support
This commit is contained in:
parent
2568b722dd
commit
bd35371944
|
@ -179,6 +179,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
default=None,
|
||||
)
|
||||
parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None)
|
||||
parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None)
|
||||
parser.add_argument(
|
||||
"--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None
|
||||
)
|
||||
|
@ -325,6 +326,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
|
|||
style_text=args.capacitron_style_text,
|
||||
reference_speaker_name=args.reference_speaker_idx,
|
||||
emotion_name=args.emotion_idx,
|
||||
style_speaker_name=args.style_speaker_name,
|
||||
)
|
||||
|
||||
# save the results
|
||||
|
|
|
@ -981,7 +981,7 @@ class Vits(BaseTTS):
|
|||
@staticmethod
|
||||
def _set_cond_input(aux_input: Dict):
|
||||
"""Set the speaker conditioning input based on the multi-speaker mode."""
|
||||
sid, g, lid, eid, eg, pf = None, None, None, None, None, None
|
||||
sid, g, lid, eid, eg, pf, ssid, ssg = None, None, None, None, None, None, None, None
|
||||
if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
|
||||
sid = aux_input["speaker_ids"]
|
||||
if sid.ndim == 0:
|
||||
|
@ -1010,7 +1010,18 @@ class Vits(BaseTTS):
|
|||
pf = aux_input["style_feature"]
|
||||
if pf.ndim == 2:
|
||||
pf = pf.unsqueeze_(0)
|
||||
return sid, g, lid, eid, eg, pf
|
||||
|
||||
if "style_speaker_id" in aux_input and aux_input["style_speaker_id"] is not None:
|
||||
ssid = aux_input["style_speaker_id"]
|
||||
if ssid.ndim == 0:
|
||||
ssid = ssid.unsqueeze_(0)
|
||||
|
||||
if "style_speaker_d_vector" in aux_input and aux_input["style_speaker_d_vector"] is not None:
|
||||
ssg = F.normalize(aux_input["style_speaker_d_vector"]).unsqueeze(-1)
|
||||
if ssg.ndim == 2:
|
||||
ssg = ssg.unsqueeze_(0)
|
||||
|
||||
return sid, g, lid, eid, eg, pf, ssid, ssg
|
||||
|
||||
def _set_speaker_input(self, aux_input: Dict):
|
||||
d_vectors = aux_input.get("d_vectors", None)
|
||||
|
@ -1130,7 +1141,7 @@ class Vits(BaseTTS):
|
|||
- syn_cons_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
|
||||
"""
|
||||
outputs = {}
|
||||
sid, g, lid, eid, eg, _ = self._set_cond_input(aux_input)
|
||||
sid, g, lid, eid, eg, _, _, _ = self._set_cond_input(aux_input)
|
||||
# speaker embedding
|
||||
if self.args.use_speaker_embedding and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -1317,7 +1328,7 @@ class Vits(BaseTTS):
|
|||
- m_p: :math:`[B, C, T_dec]`
|
||||
- logs_p: :math:`[B, C, T_dec]`
|
||||
"""
|
||||
sid, g, lid, eid, eg, pf = self._set_cond_input(aux_input)
|
||||
sid, g, lid, eid, eg, pf, ssid, ssg = self._set_cond_input(aux_input)
|
||||
x_lengths = self._set_x_lengths(x, aux_input)
|
||||
|
||||
# speaker embedding
|
||||
|
@ -1336,13 +1347,17 @@ class Vits(BaseTTS):
|
|||
# prosody embedding
|
||||
pros_emb = None
|
||||
if self.args.use_prosody_encoder:
|
||||
# speaker embedding for the style speaker
|
||||
if self.args.use_speaker_embedding and ssid is not None:
|
||||
ssg = self.emb_g(ssid).unsqueeze(-1)
|
||||
|
||||
# extract posterior encoder feature
|
||||
pf_lengths = torch.tensor([pf.size(-1)]).to(pf.device)
|
||||
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=g)
|
||||
z_pro, _, _, z_pro_y_mask = self.posterior_encoder(pf, pf_lengths, g=ssg)
|
||||
if not self.args.use_prosody_encoder_z_p_input:
|
||||
pros_emb, _ = self.prosody_encoder(z_pro, pf_lengths)
|
||||
else:
|
||||
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=g)
|
||||
z_p_inf = self.flow(z_pro, z_pro_y_mask, g=ssg)
|
||||
pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths)
|
||||
|
||||
pros_emb = pros_emb.transpose(1, 2)
|
||||
|
@ -1687,7 +1702,7 @@ class Vits(BaseTTS):
|
|||
config = self.config
|
||||
|
||||
# extract speaker and language info
|
||||
text, speaker_name, style_wav, language_name, emotion_name = None, None, None, None, None
|
||||
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None
|
||||
|
||||
if isinstance(sentence_info, list):
|
||||
if len(sentence_info) == 1:
|
||||
|
@ -1700,23 +1715,37 @@ class Vits(BaseTTS):
|
|||
text, speaker_name, style_wav, language_name = sentence_info
|
||||
elif len(sentence_info) == 5:
|
||||
text, speaker_name, style_wav, language_name, emotion_name = sentence_info
|
||||
elif len(sentence_info) == 6:
|
||||
text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = sentence_info
|
||||
else:
|
||||
text = sentence_info
|
||||
|
||||
if style_wav and style_speaker_name is None:
|
||||
raise RuntimeError(
|
||||
f" [!] You must to provide the style_speaker_name for the style_wav !!"
|
||||
)
|
||||
|
||||
# get speaker id/d_vector
|
||||
speaker_id, d_vector, language_id, emotion_id, emotion_embedding = None, None, None, None, None
|
||||
speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None
|
||||
if hasattr(self, "speaker_manager"):
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_embeddings()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
|
||||
|
||||
if style_wav is not None:
|
||||
style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False)
|
||||
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.ids[speaker_name]
|
||||
|
||||
if style_wav is not None:
|
||||
style_speaker_id = self.speaker_manager.ids[style_speaker_name]
|
||||
|
||||
# get language id
|
||||
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.ids[language_name]
|
||||
|
@ -1740,6 +1769,8 @@ class Vits(BaseTTS):
|
|||
"text": text,
|
||||
"speaker_id": speaker_id,
|
||||
"style_wav": style_wav,
|
||||
"style_speaker_id": style_speaker_id,
|
||||
"style_speaker_d_vector": style_speaker_d_vector,
|
||||
"d_vector": d_vector,
|
||||
"language_id": language_id,
|
||||
"language_name": language_name,
|
||||
|
@ -1773,6 +1804,8 @@ class Vits(BaseTTS):
|
|||
language_id=aux_inputs["language_id"],
|
||||
emotion_embedding=aux_inputs["emotion_embedding"],
|
||||
emotion_id=aux_inputs["emotion_ids"],
|
||||
style_speaker_id=aux_inputs["style_speaker_id"],
|
||||
style_speaker_d_vector=aux_inputs["style_speaker_d_vector"],
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
|
|
@ -31,6 +31,8 @@ def run_model_torch(
|
|||
language_id: torch.Tensor = None,
|
||||
emotion_id: torch.Tensor = None,
|
||||
emotion_embedding: torch.Tensor = None,
|
||||
style_speaker_id: torch.Tensor = None,
|
||||
style_speaker_d_vector: torch.Tensor = None,
|
||||
) -> Dict:
|
||||
"""Run a torch model for inference. It does not support batch inference.
|
||||
|
||||
|
@ -60,6 +62,8 @@ def run_model_torch(
|
|||
"language_ids": language_id,
|
||||
"emotion_ids": emotion_id,
|
||||
"emotion_embeddings": emotion_embedding,
|
||||
"style_speaker_id": style_speaker_id,
|
||||
"style_speaker_d_vector": style_speaker_d_vector,
|
||||
},
|
||||
)
|
||||
return outputs
|
||||
|
@ -128,6 +132,8 @@ def synthesis(
|
|||
language_id=None,
|
||||
emotion_id=None,
|
||||
emotion_embedding=None,
|
||||
style_speaker_id=None,
|
||||
style_speaker_d_vector=None,
|
||||
):
|
||||
"""Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
|
||||
the vocoder model.
|
||||
|
@ -205,6 +211,13 @@ def synthesis(
|
|||
if emotion_embedding is not None:
|
||||
emotion_embedding = embedding_to_torch(emotion_embedding, cuda=use_cuda)
|
||||
|
||||
if style_speaker_id is not None:
|
||||
style_speaker_id = id_to_torch(style_speaker_id, cuda=use_cuda)
|
||||
|
||||
if style_speaker_d_vector is not None:
|
||||
style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda)
|
||||
|
||||
|
||||
if not isinstance(style_feature, dict):
|
||||
# GST or Capacitron style mel
|
||||
style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda)
|
||||
|
@ -229,6 +242,8 @@ def synthesis(
|
|||
language_id=language_id,
|
||||
emotion_id=emotion_id,
|
||||
emotion_embedding=emotion_embedding,
|
||||
style_speaker_id=style_speaker_id,
|
||||
style_speaker_d_vector=style_speaker_d_vector,
|
||||
)
|
||||
model_outputs = outputs["model_outputs"]
|
||||
model_outputs = model_outputs[0].data.cpu().numpy()
|
||||
|
|
|
@ -216,6 +216,7 @@ class Synthesizer(object):
|
|||
emotion_name=None,
|
||||
source_emotion=None,
|
||||
target_emotion=None,
|
||||
style_speaker_name=None,
|
||||
) -> List[int]:
|
||||
"""🐸 TTS magic. Run all the models and generate speech.
|
||||
|
||||
|
@ -247,6 +248,8 @@ class Synthesizer(object):
|
|||
# handle multi-speaker
|
||||
speaker_embedding = None
|
||||
speaker_id = None
|
||||
style_speaker_id = None
|
||||
style_speaker_embedding = None
|
||||
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"):
|
||||
if speaker_name and isinstance(speaker_name, str):
|
||||
if self.tts_config.use_d_vector_file:
|
||||
|
@ -255,10 +258,20 @@ class Synthesizer(object):
|
|||
speaker_name, num_samples=None, randomize=False
|
||||
)
|
||||
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
|
||||
|
||||
if style_speaker_name is not None:
|
||||
style_speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
|
||||
style_speaker_name, num_samples=None, randomize=False
|
||||
)
|
||||
style_speaker_embedding = np.array(style_speaker_embedding)[None, :] # [1 x embedding_dim]
|
||||
|
||||
else:
|
||||
# get speaker idx from the speaker name
|
||||
speaker_id = self.tts_model.speaker_manager.ids[speaker_name]
|
||||
|
||||
if style_speaker_name is not None:
|
||||
style_speaker_id = self.tts_model.speaker_manager.ids[style_speaker_name]
|
||||
|
||||
elif not speaker_name and not speaker_wav:
|
||||
raise ValueError(
|
||||
" [!] Look like you use a multi-speaker model. "
|
||||
|
@ -327,6 +340,9 @@ class Synthesizer(object):
|
|||
if speaker_wav is not None:
|
||||
speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)
|
||||
|
||||
if style_wav is not None and style_speaker_name is None:
|
||||
style_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(style_wav)
|
||||
|
||||
use_gl = self.vocoder_model is None
|
||||
|
||||
if not reference_wav:
|
||||
|
@ -340,6 +356,8 @@ class Synthesizer(object):
|
|||
speaker_id=speaker_id,
|
||||
style_wav=style_wav,
|
||||
style_text=style_text,
|
||||
style_speaker_id=style_speaker_id,
|
||||
style_speaker_d_vector=style_speaker_embedding,
|
||||
use_griffin_lim=use_gl,
|
||||
d_vector=speaker_embedding,
|
||||
language_id=language_id,
|
||||
|
|
|
@ -26,7 +26,7 @@ config = VitsConfig(
|
|||
print_step=1,
|
||||
print_eval=True,
|
||||
test_sentences=[
|
||||
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None],
|
||||
["Be a voice, not an echo.", "ljspeech-1", "tests/data/ljspeech/wavs/LJ001-0001.wav", None, None, "ljspeech-2"],
|
||||
],
|
||||
)
|
||||
# set audio config
|
||||
|
@ -46,7 +46,7 @@ config.model_args.prosody_embedding_dim = 64
|
|||
# active classifier
|
||||
config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
|
||||
config.model_args.use_prosody_enc_emo_classifier = False
|
||||
config.model_args.use_text_enc_emo_classifier = True
|
||||
config.model_args.use_text_enc_emo_classifier = False
|
||||
config.model_args.use_prosody_encoder_z_p_input = True
|
||||
|
||||
config.model_args.prosody_encoder_type = "vae"
|
||||
|
@ -75,11 +75,13 @@ continue_config_path = os.path.join(continue_path, "config.json")
|
|||
continue_restore_path, _ = get_last_checkpoint(continue_path)
|
||||
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
|
||||
speaker_id = "ljspeech-1"
|
||||
style_speaker_name = "ljspeech-2"
|
||||
style_wav_path = "tests/data/ljspeech/wavs/LJ001-0001.wav"
|
||||
continue_speakers_path = os.path.join(continue_path, "speakers.json")
|
||||
|
||||
|
||||
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path}"
|
||||
print("Testing inference !")
|
||||
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} --style_wav {style_wav_path} --style_speaker_name {style_speaker_name}"
|
||||
run_cli(inference_command)
|
||||
|
||||
# restore the model and continue training for one more epoch
|
||||
|
|
Loading…
Reference in New Issue