mirror of https://github.com/coqui-ai/TTS.git

Add multilingual inference support

This commit is contained in:
parent d0e3647db6
commit 8e83a212fa
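This diff threads an optional language ID through the whole inference path. In order, the hunks below touch the VITS config (`VitsConfig.test_sentences` gains per-sentence metadata), `BaseTTS` (a parser for that metadata), `Vits` (conditioning and test-run changes), `SpeakerManager` (lookup helpers the parser relies on), and the synthesis utilities (`run_model_torch`, `id_to_torch`, `synthesis` all learn to carry a `language_id`). Short usage sketches, with clearly illustrative names, follow the relevant hunks.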
@@ -130,13 +130,13 @@ class VitsConfig(BaseTTSConfig):
     add_blank: bool = True

     # testing
-    test_sentences: List[str] = field(
+    test_sentences: List[List] = field(
         default_factory=lambda: [
-            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
-            "Be a voice, not an echo.",
-            "I'm sorry Dave. I'm afraid I can't do that.",
-            "This cake is great. It's so delicious and moist.",
-            "Prior to November 22, 1963.",
+            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
+            ["Be a voice, not an echo."],
+            ["I'm sorry Dave. I'm afraid I can't do that."],
+            ["This cake is great. It's so delicious and moist."],
+            ["Prior to November 22, 1963."],
        ]
    )
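With `test_sentences` now a `List[List]`, each entry can optionally carry a speaker name, a style-wav path, and a language name, in that order; see the parsing added to `BaseTTS` in the next hunk. A minimal sketch of the accepted shapes — the speaker name, wav path, and language name here are illustrative, not values from this commit:

test_sentences = [
    ["Be a voice, not an echo."],                          # text only
    ["Be a voice, not an echo.", "spk_a"],                 # + speaker name (illustrative)
    ["Be a voice, not an echo.", "spk_a", "ref.wav"],      # + style wav path (illustrative)
    ["Be a voice, not an echo.", "spk_a", None, "en"],     # + language name (illustrative)
]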
@@ -107,6 +107,51 @@ class BaseTTS(BaseModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)

+    def get_aux_input(self, **kwargs) -> Dict:
+        """Prepare and return `aux_input` used by `forward()`"""
+        return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}
+
+    def get_aux_input_from_test_setences(self, sentence_info):
+        if hasattr(self.config, "model_args"):
+            config = self.config.model_args
+        else:
+            config = self.config
+
+        # extract speaker and language info
+        text, speaker_name, style_wav, language_name = None, None, None, None
+
+        if isinstance(sentence_info, list):
+            if len(sentence_info) == 1:
+                text = sentence_info[0]
+            elif len(sentence_info) == 2:
+                text, speaker_name = sentence_info
+            elif len(sentence_info) == 3:
+                text, speaker_name, style_wav = sentence_info
+            elif len(sentence_info) == 4:
+                text, speaker_name, style_wav, language_name = sentence_info
+        else:
+            text = sentence_info
+
+        # get speaker id/d_vector
+        speaker_id, d_vector, language_id = None, None, None
+        if hasattr(self, "speaker_manager") and config.use_speaker_embedding:
+            if config.use_d_vector_file:
+                if speaker_name is None:
+                    d_vector = self.speaker_manager.get_random_d_vector()
+                else:
+                    d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
+            else:
+                if speaker_name is None:
+                    speaker_id = self.speaker_manager.get_random_speaker_id()
+                else:
+                    speaker_id = self.speaker_manager.speaker_ids[speaker_name]
+
+        # get language id
+        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+            language_id = self.language_manager.language_id_mapping[language_name]
+
+        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
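A sketch of what the new parser returns for a four-element entry, assuming a model whose `speaker_manager` maps "spk_a" to ID 0 and whose `language_manager` maps "en" to ID 1, with `use_speaker_embedding` enabled and `use_d_vector_file` disabled (all names and IDs illustrative):

aux = model.get_aux_input_from_test_setences(
    ["Be a voice, not an echo.", "spk_a", None, "en"]
)
# aux == {"text": "Be a voice, not an echo.", "speaker_id": 0,
#         "style_wav": None, "d_vector": None, "language_id": 1}

Entries with fewer fields fall back to a random speaker (via the new `SpeakerManager` helpers further down) and to no language conditioning.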
@@ -399,8 +399,14 @@ class Vits(BaseTTS):
             sid = sid.unsqueeze_(0)
         if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
             g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
+            if g.ndim == 2:
+                g = g.unsqueeze_(0)
+
+        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
+            lid = aux_input["language_ids"]
+            if lid.ndim == 0:
+                lid = lid.unsqueeze_(0)

         return sid, g, lid

     def get_aux_input(self, aux_input: Dict):
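The new `ndim` checks normalize conditioning shapes at inference time, where a single utterance arrives with 0-dim IDs and unbatched d_vectors rather than batched tensors. A minimal sketch of the shapes involved:

import torch

lid = torch.tensor(3)                  # a 0-dim ID, as produced by id_to_torch() below
lid = lid.unsqueeze_(0)                # -> shape [1], what nn.Embedding expects
lang_emb = torch.nn.Embedding(10, 4)(lid).unsqueeze(-1)  # -> [1, 4, 1], i.e. [b, h, 1]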
@@ -437,9 +443,8 @@ class Vits(BaseTTS):
         """
         outputs = {}
         sid, g, lid = self._set_cond_input(aux_input)

         # speaker embedding
-        if self.num_speakers > 1 and sid is not None and not self.use_d_vector:
+        if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

         # language embedding
@@ -521,11 +526,11 @@ class Vits(BaseTTS):
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

         # speaker embedding
-        if self.num_speakers > 0 and sid:
+        if self.args.use_speaker_embedding and sid is not None and not self.use_d_vector:
             g = self.emb_g(sid).unsqueeze(-1)

         # language embedding
-        if self.args.use_language_embedding:
+        if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
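Both branches are now guarded on the corresponding `args` flag *and* on the ID actually being present, so a multilingual checkpoint can still run inference without explicit conditioning. A hedged sketch of the call, assuming `inference` takes the same `aux_input` keys that `_set_cond_input` reads above (`token_ids`, `sid`, and `lid` are illustrative tensors):

outputs = model.inference(
    token_ids,  # [1, T] LongTensor of text token IDs (illustrative)
    aux_input={"speaker_ids": sid, "d_vectors": None, "language_ids": lid},
)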
@@ -713,29 +718,29 @@ class Vits(BaseTTS):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        if hasattr(self, "speaker_manager"):
-            aux_inputs = self.speaker_manager.get_random_speaker_aux_input()
-        else:
-            aux_inputs = self.get_aux_input()
-
-        for idx, sen in enumerate(test_sentences):
-            wav, alignment, _, _ = synthesis(
-                self,
-                sen,
-                self.config,
-                "cuda" in str(next(self.parameters()).device),
-                ap,
-                speaker_id=aux_inputs["speaker_id"],
-                d_vector=aux_inputs["d_vector"],
-                style_wav=aux_inputs["style_wav"],
-                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
-                use_griffin_lim=True,
-                do_trim_silence=False,
-            ).values()
-
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+        for idx, s_info in enumerate(test_sentences):
+            try:
+                aux_inputs = self.get_aux_input_from_test_setences(s_info)
+                wav, alignment, _, _ = synthesis(
+                    self,
+                    aux_inputs["text"],
+                    self.config,
+                    "cuda" in str(next(self.parameters()).device),
+                    ap,
+                    speaker_id=aux_inputs["speaker_id"],
+                    d_vector=aux_inputs["d_vector"],
+                    style_wav=aux_inputs["style_wav"],
+                    language_id=aux_inputs["language_id"],
+                    enable_eos_bos_chars=self.config.enable_eos_bos_chars,
+                    use_griffin_lim=True,
+                    do_trim_silence=False,
+                ).values()
+
+                test_audios["{}-audio".format(idx)] = wav
+                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
+            except:  # pylint: disable=bare-except
+                print(" !! Error creating Test Sentence -", idx)
         return test_figures, test_audios

     def get_optimizer(self) -> List:
@@ -832,3 +837,5 @@ class Vits(BaseTTS):
         if eval:
             self.eval()
+            assert not self.training
@@ -193,6 +193,20 @@ class SpeakerManager:
         """
         return [x["embedding"] for x in self.d_vectors.values() if x["name"] == speaker_idx]

+    def get_d_vector_by_speaker(self, speaker_idx: str) -> np.ndarray:
+        """Get a d_vector of a speaker.
+
+        Args:
+            speaker_idx (str): Target speaker ID.
+
+        Returns:
+            np.ndarray: d_vector.
+        """
+        for x in self.d_vectors.values():
+            if x["name"] == speaker_idx:
+                return x["embedding"]
+        return None
+
     def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
         """Get mean d_vector of a speaker ID.
@@ -215,14 +229,31 @@ class SpeakerManager:
         d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
         return d_vectors

-    def get_random_speaker_aux_input(self) -> Dict:
-        if self.d_vectors:
-            return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]}
-
-        if self.speaker_ids:
-            return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None}
-
-        return {"speaker_id": None, "style_wav": None, "d_vector": None}
+    def get_random_speaker_id(self) -> Any:
+        """Get a random speaker ID.
+
+        Returns:
+            Any: speaker ID.
+        """
+        if self.speaker_ids:
+            return self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]]
+
+        return None
+
+    def get_random_d_vector(self) -> Any:
+        """Get a random d_vector.
+
+        Returns:
+            np.ndarray: d_vector.
+        """
+        if self.d_vectors:
+            return self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]
+
+        return None

     def get_speakers(self) -> List:
         return self.speaker_ids
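Taken together, the manager now exposes a by-name lookup plus random fallbacks, which is exactly what `get_aux_input_from_test_setences` relies on when a test sentence omits the speaker. A short sketch, assuming `manager` is a `SpeakerManager` already populated from a d_vectors or speaker-IDs file (instantiation omitted; the speaker name is illustrative):

d_vec = manager.get_d_vector_by_speaker("spk_a")  # first stored embedding for that name, else None
sid = manager.get_random_speaker_id()             # a random entry of speaker_ids, else None
d_rand = manager.get_random_d_vector()            # a random stored embedding, else None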
@@ -71,6 +71,7 @@ def run_model_torch(
    speaker_id: int = None,
    style_mel: torch.Tensor = None,
    d_vector: torch.Tensor = None,
+   language_id: torch.Tensor = None,
 ) -> Dict:
    """Run a torch model for inference. It does not support batch inference.
@@ -96,6 +97,7 @@ def run_model_torch(
             "speaker_ids": speaker_id,
             "d_vectors": d_vector,
             "style_mel": style_mel,
+            "language_ids": language_id,
         },
     )
     return outputs
@@ -160,13 +162,13 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav


-def speaker_id_to_torch(speaker_id, cuda=False):
-    if speaker_id is not None:
-        speaker_id = np.asarray(speaker_id)
-        speaker_id = torch.from_numpy(speaker_id)
+def id_to_torch(aux_id, cuda=False):
+    if aux_id is not None:
+        aux_id = np.asarray(aux_id)
+        aux_id = torch.from_numpy(aux_id)
     if cuda:
-        return speaker_id.cuda()
-    return speaker_id
+        return aux_id.cuda()
+    return aux_id


 def embedding_to_torch(d_vector, cuda=False):
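The rename makes the helper generic: the same conversion now serves both speaker and language IDs. A quick sketch of its behavior, using the definition above:

sid = id_to_torch(0)             # -> tensor(0), a 0-dim LongTensor
lid = id_to_torch(3, cuda=True)  # same conversion, moved to the GPU

The 0-dim result is what the `lid.ndim == 0` check in `Vits._set_cond_input` later unsqueezes into a batch of one.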
@@ -208,6 +210,7 @@ def synthesis(
    use_griffin_lim=False,
    do_trim_silence=False,
    d_vector=None,
+   language_id=None,
    backend="torch",
 ):
    """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
@@ -262,11 +265,14 @@ def synthesis(
     # pass tensors to backend
     if backend == "torch":
         if speaker_id is not None:
-            speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)
+            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)

         if d_vector is not None:
             d_vector = embedding_to_torch(d_vector, cuda=use_cuda)

+        if language_id is not None:
+            language_id = id_to_torch(language_id, cuda=use_cuda)
+
         if not isinstance(style_mel, dict):
             style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
@@ -278,7 +284,7 @@ def synthesis(
         text_inputs = tf.expand_dims(text_inputs, 0)
     # synthesize voice
     if backend == "torch":
-        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
+        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
         model_outputs = outputs["model_outputs"]
         model_outputs = model_outputs[0].data.cpu().numpy()
         alignments = outputs["alignments"]
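Putting it together, inference with an explicit language now looks roughly like this, mirroring the call shape used in `Vits.test_run` above; `model`, `config`, and `ap` are assumed to be a loaded multilingual Vits, its config, and an `AudioProcessor`, and the text and ID values are illustrative:

wav, alignment, _, _ = synthesis(
    model,
    "Bonjour le monde.",
    config,
    use_cuda=False,
    ap=ap,
    speaker_id=0,
    language_id=2,  # e.g. language_manager.language_id_mapping["fr"] (illustrative)
    use_griffin_lim=True,
    do_trim_silence=False,
).values()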