diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 1a49f0b0..05161a66 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -122,6 +122,7 @@ class Synthesizer(object): self.tts_model.cuda() if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"): + self.tts_model.speaker_manager.use_cuda = use_cuda self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config) def _set_speaker_encoder_paths_from_tts_config(self):