Select randomly a speaker from the speaker manager for the test setences

This commit is contained in:
Edresson 2021-08-12 21:52:12 -03:00 committed by Eren Gölge
parent 8310d19da8
commit 234a4aacb3
2 changed files with 16 additions and 10 deletions

View File

@ -402,6 +402,7 @@ class Vits(BaseTTS):
# speaker embedding
if self.num_speakers > 1 and sid is not None:
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
# posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@ -638,7 +639,7 @@ class Vits(BaseTTS):
return self._log(ap, batch, outputs, "eval")
@torch.no_grad()
def test_run(self, ap) -> Tuple[Dict, Dict]:
def test_run(self, ap, eval_loader=None) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
@ -650,16 +651,13 @@ class Vits(BaseTTS):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None
if not self.config.use_d_vector_file
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
"style_wav": None,
}
if hasattr(self, "speaker_manager"):
aux_inputs = self.speaker_manager.get_random_speaker_aux_input()
else:
aux_inputs = self.get_aux_input()
for idx, sen in enumerate(test_sentences):
wav, alignment, _, _ = synthesis(
self,
sen,

View File

@ -209,6 +209,14 @@ class SpeakerManager:
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
return d_vectors
def get_random_speaker_aux_input(self) -> Dict:
if self.d_vectors:
return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]}
elif self.speaker_ids:
return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None}
else:
return {"speaker_id": None, "style_wav": None, "d_vector": None}
def get_speakers(self) -> List:
return self.speaker_ids