mirror of https://github.com/coqui-ai/TTS.git
Select randomly a speaker from the speaker manager for the test setences
This commit is contained in:
parent
eb3e8affe1
commit
e0ad838066
|
@ -402,6 +402,7 @@ class Vits(BaseTTS):
|
||||||
# speaker embedding
|
# speaker embedding
|
||||||
if self.num_speakers > 1 and sid is not None:
|
if self.num_speakers > 1 and sid is not None:
|
||||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# posterior encoder
|
# posterior encoder
|
||||||
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
|
||||||
|
|
||||||
|
@ -638,7 +639,7 @@ class Vits(BaseTTS):
|
||||||
return self._log(ap, batch, outputs, "eval")
|
return self._log(ap, batch, outputs, "eval")
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_run(self, ap) -> Tuple[Dict, Dict]:
|
def test_run(self, ap, eval_loader=None) -> Tuple[Dict, Dict]:
|
||||||
"""Generic test run for `tts` models used by `Trainer`.
|
"""Generic test run for `tts` models used by `Trainer`.
|
||||||
|
|
||||||
You can override this for a different behaviour.
|
You can override this for a different behaviour.
|
||||||
|
@ -650,16 +651,13 @@ class Vits(BaseTTS):
|
||||||
test_audios = {}
|
test_audios = {}
|
||||||
test_figures = {}
|
test_figures = {}
|
||||||
test_sentences = self.config.test_sentences
|
test_sentences = self.config.test_sentences
|
||||||
aux_inputs = {
|
if hasattr(self, "speaker_manager"):
|
||||||
"speaker_id": None
|
aux_inputs = self.speaker_manager.get_random_speaker_aux_input()
|
||||||
if not self.config.use_speaker_embedding
|
else:
|
||||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
aux_inputs = self.get_aux_input()
|
||||||
"d_vector": None
|
|
||||||
if not self.config.use_d_vector_file
|
|
||||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
|
||||||
"style_wav": None,
|
|
||||||
}
|
|
||||||
for idx, sen in enumerate(test_sentences):
|
for idx, sen in enumerate(test_sentences):
|
||||||
|
|
||||||
wav, alignment, _, _ = synthesis(
|
wav, alignment, _, _ = synthesis(
|
||||||
self,
|
self,
|
||||||
sen,
|
sen,
|
||||||
|
|
|
@ -209,6 +209,14 @@ class SpeakerManager:
|
||||||
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
|
d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
|
||||||
return d_vectors
|
return d_vectors
|
||||||
|
|
||||||
|
def get_random_speaker_aux_input(self) -> Dict:
|
||||||
|
if self.d_vectors:
|
||||||
|
return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]}
|
||||||
|
elif self.speaker_ids:
|
||||||
|
return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None}
|
||||||
|
else:
|
||||||
|
return {"speaker_id": None, "style_wav": None, "d_vector": None}
|
||||||
|
|
||||||
def get_speakers(self) -> List:
|
def get_speakers(self) -> List:
|
||||||
return self.speaker_ids
|
return self.speaker_ids
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue