diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 91d07257..07ece032 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -19,11 +19,28 @@ parser = argparse.ArgumentParser( formatter_class=RawTextHelpFormatter, ) parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") +<<<<<<< HEAD parser.add_argument("config_path", type=str, help="Path to model config file.") parser.add_argument("config_dataset_path", type=str, help="Path to dataset config file.") parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.", default="speakers.pth") parser.add_argument("--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None) parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False) +======= +parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", +) + +parser.add_argument( + "config_dataset_path", + type=str, + help="Path to dataset config file.", +) +parser.add_argument("output_path", type=str, help="path for output .json file.") +parser.add_argument("--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None) +parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False) +>>>>>>> Fix compute embeddings issue parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False) parser.add_argument( "--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False @@ -49,6 +66,7 @@ encoder_manager = EmbeddingManager( use_cuda=use_cuda, ) +print("Using CUDA?", args.use_cuda) class_name_key = encoder_manager.encoder_config.class_name_key # compute speaker embeddings diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py index 0243d3b4..7396745f 100644 --- a/TTS/tts/utils/managers.py +++ b/TTS/tts/utils/managers.py @@ -132,7 +132,7 @@ class EmbeddingManager(BaseIDManager): self.load_embeddings_from_file(embedding_file_path) if encoder_model_path and encoder_config_path: - self.init_encoder(encoder_model_path, encoder_config_path, use_cuda) + self.init_encoder(encoder_model_path, encoder_config_path, use_cuda=use_cuda) @property def embedding_dim(self):