mirror of https://github.com/coqui-ai/TTS.git
Fix compute embeddings issue
This commit is contained in:
parent
a6c8fea192
commit
f50819a5f6
|
@ -19,11 +19,28 @@ parser = argparse.ArgumentParser(
|
|||
formatter_class=RawTextHelpFormatter,
|
||||
)
|
||||
parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
|
||||
<<<<<<< HEAD
|
||||
parser.add_argument("config_path", type=str, help="Path to model config file.")
|
||||
parser.add_argument("config_dataset_path", type=str, help="Path to dataset config file.")
|
||||
parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.", default="speakers.pth")
|
||||
parser.add_argument("--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None)
|
||||
parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False)
|
||||
=======
|
||||
parser.add_argument(
|
||||
"config_path",
|
||||
type=str,
|
||||
help="Path to model config file.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"config_dataset_path",
|
||||
type=str,
|
||||
help="Path to dataset config file.",
|
||||
)
|
||||
parser.add_argument("output_path", type=str, help="path for output .json file.")
|
||||
parser.add_argument("--old_file", type=str, help="Previous .json file, only compute for new audios.", default=None)
|
||||
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False)
|
||||
>>>>>>> Fix compute embeddings issue
|
||||
parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False)
|
||||
parser.add_argument(
|
||||
"--use_predicted_label", type=bool, help="If True and predicted label is available with will use it.", default=False
|
||||
|
@ -49,6 +66,7 @@ encoder_manager = EmbeddingManager(
|
|||
use_cuda=use_cuda,
|
||||
)
|
||||
|
||||
print("Using CUDA?", args.use_cuda)
|
||||
class_name_key = encoder_manager.encoder_config.class_name_key
|
||||
|
||||
# compute speaker embeddings
|
||||
|
|
|
@ -132,7 +132,7 @@ class EmbeddingManager(BaseIDManager):
|
|||
self.load_embeddings_from_file(embedding_file_path)
|
||||
|
||||
if encoder_model_path and encoder_config_path:
|
||||
self.init_encoder(encoder_model_path, encoder_config_path, use_cuda)
|
||||
self.init_encoder(encoder_model_path, encoder_config_path, use_cuda=use_cuda)
|
||||
|
||||
@property
|
||||
def embedding_dim(self):
|
||||
|
|
Loading…
Reference in New Issue