mirror of https://github.com/coqui-ai/TTS.git
Select randomly a speaker from the speaker manager for the test setences
This commit is contained in:
parent
8310d19da8
commit
234a4aacb3
|
@ -402,6 +402,7 @@ class Vits(BaseTTS):
|
|||
# speaker embedding
|
||||
if self.num_speakers > 1 and sid is not None:
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
|
||||
# posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||
|
||||
|
@ -638,7 +639,7 @@ class Vits(BaseTTS):
|
|||
return self._log(ap, batch, outputs, "eval")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_run(self, ap) -> Tuple[Dict, Dict]:
|
||||
def test_run(self, ap, eval_loader=None) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
@ -650,16 +651,13 @@ class Vits(BaseTTS):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = {
|
||||
"speaker_id": None
|
||||
if not self.config.use_speaker_embedding
|
||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
||||
"d_vector": None
|
||||
if not self.config.use_d_vector_file
|
||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
||||
"style_wav": None,
|
||||
}
|
||||
if hasattr(self, "speaker_manager"):
|
||||
aux_inputs = self.speaker_manager.get_random_speaker_aux_input()
|
||||
else:
|
||||
aux_inputs = self.get_aux_input()
|
||||
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
|
||||
wav, alignment, _, _ = synthesis(
|
||||
self,
|
||||
sen,
|
||||
|
|
|
@ -209,6 +209,14 @@ class SpeakerManager:
|
|||
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
|
||||
return d_vectors
|
||||
|
||||
def get_random_speaker_aux_input(self) -> Dict:
|
||||
if self.d_vectors:
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]}
|
||||
elif self.speaker_ids:
|
||||
return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None}
|
||||
else:
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": None}
|
||||
|
||||
def get_speakers(self) -> List:
|
||||
return self.speaker_ids
|
||||
|
||||
|
|
Loading…
Reference in New Issue