mirror of https://github.com/coqui-ai/TTS.git
Add parse_key in set_ids_from_data
This commit is contained in:
parent
464775dbaf
commit
0e258d1784
|
@ -69,10 +69,10 @@ class LanguageManager(BaseIDManager):
|
|||
self.ids = self.parse_language_ids_from_config(c)
|
||||
|
||||
@staticmethod
|
||||
def parse_ids_from_data(items: list) -> Any:
|
||||
def parse_ids_from_data(items: List, parse_key: str) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_ids_from_data(self, items: List) -> Any:
|
||||
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def save_ids_to_file(self, file_path: str) -> None:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import random
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
|
@ -34,14 +34,13 @@ class BaseIDManager:
|
|||
with fsspec.open(json_file_path, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
|
||||
def set_ids_from_data(self, items: List) -> None:
|
||||
def set_ids_from_data(self, items: List, parse_key: str) -> None:
|
||||
"""Set IDs from data samples.
|
||||
|
||||
Args:
|
||||
items (List): Data sampled returned by `load_tts_samples()`.
|
||||
"""
|
||||
self.ids, _ = self.parse_ids_from_data(items)
|
||||
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
|
||||
|
||||
def load_ids_from_file(self, file_path: str) -> None:
|
||||
"""Set IDs from a file.
|
||||
|
@ -73,9 +72,18 @@ class BaseIDManager:
|
|||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_ids_from_data(items: list) -> Any:
|
||||
raise NotImplementedError
|
||||
def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]:
|
||||
"""Parse IDs from data samples retured by `load_tts_samples()`.
|
||||
|
||||
Args:
|
||||
items (list): Data sampled returned by `load_tts_samples()`.
|
||||
parse_key (str): The key to being used to parse the data.
|
||||
Returns:
|
||||
Tuple[Dict]: speaker IDs.
|
||||
"""
|
||||
classes = sorted({item[parse_key] for item in items})
|
||||
ids = {name: i for i, name in enumerate(classes)}
|
||||
return ids
|
||||
|
||||
class EmbeddingManager(BaseIDManager):
|
||||
""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
|
||||
|
@ -273,7 +281,3 @@ class EmbeddingManager(BaseIDManager):
|
|||
if self.use_cuda:
|
||||
feats = feats.cuda()
|
||||
return self.encoder.compute_embedding(feats)
|
||||
|
||||
@staticmethod
|
||||
def parse_ids_from_data(items: list) -> Any:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
|
@ -68,7 +68,7 @@ class SpeakerManager(EmbeddingManager):
|
|||
)
|
||||
|
||||
if data_items:
|
||||
self.ids, _ = self.parse_ids_from_data(data_items)
|
||||
self.set_ids_from_data(data_items, parse_key="speaker_name")
|
||||
|
||||
@property
|
||||
def num_speakers(self):
|
||||
|
@ -78,21 +78,6 @@ class SpeakerManager(EmbeddingManager):
|
|||
def speaker_names(self):
|
||||
return list(self.ids.keys())
|
||||
|
||||
@staticmethod
|
||||
def parse_ids_from_data(items: list) -> Tuple[Dict, int]:
|
||||
"""Parse speaker IDs from data samples retured by `load_tts_samples()`.
|
||||
|
||||
Args:
|
||||
items (list): Data sampled returned by `load_tts_samples()`.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, int]: speaker IDs and number of speakers.
|
||||
"""
|
||||
speakers = sorted({item["speaker_name"] for item in items})
|
||||
speaker_ids = {name: i for i, name in enumerate(speakers)}
|
||||
num_speakers = len(speaker_ids)
|
||||
return speaker_ids, num_speakers
|
||||
|
||||
def get_speakers(self) -> List:
|
||||
return self.ids
|
||||
|
||||
|
@ -180,7 +165,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
|||
speaker_manager = SpeakerManager()
|
||||
if c.use_speaker_embedding:
|
||||
if data is not None:
|
||||
speaker_manager.set_ids_from_data(data)
|
||||
speaker_manager.set_ids_from_data(data, parse_key="speaker_name")
|
||||
if restore_path:
|
||||
speakers_file = _set_file_path(restore_path)
|
||||
# restoring speaker manager from a previous run.
|
||||
|
|
|
@ -115,7 +115,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
language_manager = LanguageManager(config=config)
|
||||
|
|
|
@ -76,7 +76,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
# init model
|
||||
|
|
|
@ -74,7 +74,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
# init model
|
||||
|
|
|
@ -74,7 +74,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
config.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
# init model
|
||||
|
|
|
@ -74,7 +74,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
# init model
|
||||
|
|
|
@ -77,7 +77,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
|
||||
# init model
|
||||
model = Tacotron(config, ap, tokenizer, speaker_manager)
|
||||
|
|
|
@ -83,7 +83,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
|
||||
# init model
|
||||
model = Tacotron2(config, ap, tokenizer, speaker_manager)
|
||||
|
|
|
@ -83,7 +83,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
|
||||
# init model
|
||||
model = Tacotron2(config, ap, tokenizer, speaker_manager)
|
||||
|
|
|
@ -84,7 +84,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
|
|||
# init speaker manager for multi-speaker training
|
||||
# it maps speaker-id to speaker-name in the model and data-loader
|
||||
speaker_manager = SpeakerManager()
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
||||
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||
|
||||
# init model
|
||||
|
|
Loading…
Reference in New Issue