diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index d1755b47..ae607c47 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -402,6 +402,7 @@ class Vits(BaseTTS): # speaker embedding if self.num_speakers > 1 and sid is not None: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + # posterior encoder z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) @@ -638,7 +639,7 @@ class Vits(BaseTTS): return self._log(ap, batch, outputs, "eval") @torch.no_grad() - def test_run(self, ap) -> Tuple[Dict, Dict]: + def test_run(self, ap, eval_loader=None) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. @@ -650,16 +651,13 @@ class Vits(BaseTTS): test_audios = {} test_figures = {} test_sentences = self.config.test_sentences - aux_inputs = { - "speaker_id": None - if not self.config.use_speaker_embedding - else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1), - "d_vector": None - if not self.config.use_d_vector_file - else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1), - "style_wav": None, - } + if hasattr(self, "speaker_manager"): + aux_inputs = self.speaker_manager.get_random_speaker_aux_input() + else: + aux_inputs = self.get_aux_input() + for idx, sen in enumerate(test_sentences): + wav, alignment, _, _ = synthesis( self, sen, diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index 13696a20..ae001155 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -209,6 +209,14 @@ class SpeakerManager: d_vectors = np.stack(d_vectors[:num_samples]).mean(0) return d_vectors + def get_random_speaker_aux_input(self) -> Dict: + if self.d_vectors: + return {"speaker_id": None, "style_wav": None, "d_vector": self.d_vectors[random.choices(list(self.d_vectors.keys()))[0]]["embedding"]} + elif self.speaker_ids: + return {"speaker_id": self.speaker_ids[random.choices(list(self.speaker_ids.keys()))[0]], "style_wav": None, "d_vector": None} + else: + return {"speaker_id": None, "style_wav": None, "d_vector": None} + def get_speakers(self) -> List: return self.speaker_ids