From 121e9ed6851d4d69b5e8f144dd2419d15b5cf8e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 11 May 2022 11:31:17 +0200 Subject: [PATCH] Pass use_cuda to init_encoder --- TTS/tts/utils/managers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 85ed53cc..7c22ac88 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -216,17 +216,19 @@ class EmbeddingManager(BaseIDManager): def get_clips(self) -> List: return sorted(self.embeddings.keys()) - def init_encoder(self, model_path: str, config_path: str) -> None: + def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None: """Initialize a speaker encoder model. Args: model_path (str): Model file path. config_path (str): Model config file path. + use_cuda (bool, optional): Use CUDA. Defaults to False. """ + self.use_cuda = use_cuda self.encoder_config = load_config(config_path) self.encoder = setup_encoder_model(self.encoder_config) self.encoder_criterion = self.encoder.load_checkpoint( - self.encoder_config, model_path, eval=True, use_cuda=self.use_cuda + self.encoder_config, model_path, eval=True, use_cuda=use_cuda ) self.encoder_ap = AudioProcessor(**self.encoder_config.audio)