Add parse_key in set_ids_from_data

This commit is contained in:
Edresson Casanova 2022-03-14 13:53:46 +00:00
parent 464775dbaf
commit 0e258d1784
12 changed files with 28 additions and 39 deletions

View File

@ -69,10 +69,10 @@ class LanguageManager(BaseIDManager):
self.ids = self.parse_language_ids_from_config(c)
@staticmethod
def parse_ids_from_data(items: list) -> Any:
def parse_ids_from_data(items: List, parse_key: str) -> Any:
raise NotImplementedError
def set_ids_from_data(self, items: List) -> Any:
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
raise NotImplementedError
def save_ids_to_file(self, file_path: str) -> None:

View File

@ -1,6 +1,6 @@
import json
import random
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Tuple, Union
import fsspec
import numpy as np
@ -34,14 +34,13 @@ class BaseIDManager:
with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4)
def set_ids_from_data(self, items: List) -> None:
def set_ids_from_data(self, items: List, parse_key: str) -> None:
"""Set IDs from data samples.
Args:
items (List): Data samples returned by `load_tts_samples()`.
"""
self.ids, _ = self.parse_ids_from_data(items)
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
def load_ids_from_file(self, file_path: str) -> None:
"""Set IDs from a file.
@ -73,9 +72,18 @@ class BaseIDManager:
return None
@staticmethod
def parse_ids_from_data(items: list) -> Any:
raise NotImplementedError
def parse_ids_from_data(items: List, parse_key: str) -> Dict:
    """Parse IDs from data samples returned by `load_tts_samples()`.

    Args:
        items (List): Data samples returned by `load_tts_samples()`. Each
            sample is indexed with `parse_key`.
        parse_key (str): The key used to read the class name from each
            sample (e.g. `"speaker_name"`).

    Returns:
        Dict: Mapping from each unique value found under `parse_key` to an
            integer ID. IDs are assigned from 0 in sorted order of the values.
    """
    # Deduplicate with a set, then sort so the name -> ID mapping is
    # deterministic regardless of the order of the samples.
    classes = sorted({item[parse_key] for item in items})
    ids = {name: i for i, name in enumerate(classes)}
    return ids
class EmbeddingManager(BaseIDManager):
""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
@ -273,7 +281,3 @@ class EmbeddingManager(BaseIDManager):
if self.use_cuda:
feats = feats.cuda()
return self.encoder.compute_embedding(feats)
@staticmethod
def parse_ids_from_data(items: list) -> Any:
raise NotImplementedError

View File

@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Union
import fsspec
import numpy as np
@ -68,7 +68,7 @@ class SpeakerManager(EmbeddingManager):
)
if data_items:
self.ids, _ = self.parse_ids_from_data(data_items)
self.set_ids_from_data(data_items, parse_key="speaker_name")
@property
def num_speakers(self):
@ -78,21 +78,6 @@ class SpeakerManager(EmbeddingManager):
def speaker_names(self):
return list(self.ids.keys())
@staticmethod
def parse_ids_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples retured by `load_tts_samples()`.
Args:
items (list): Data samples returned by `load_tts_samples()`.
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item["speaker_name"] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def get_speakers(self) -> List:
return self.ids
@ -180,7 +165,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager = SpeakerManager()
if c.use_speaker_embedding:
if data is not None:
speaker_manager.set_ids_from_data(data)
speaker_manager.set_ids_from_data(data, parse_key="speaker_name")
if restore_path:
speakers_file = _set_file_path(restore_path)
# restoring speaker manager from a previous run.

View File

@ -115,7 +115,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config)

View File

@ -76,7 +76,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -74,7 +74,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -74,7 +74,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -74,7 +74,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model

View File

@ -77,7 +77,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron(config, ap, tokenizer, speaker_manager)

View File

@ -83,7 +83,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron2(config, ap, tokenizer, speaker_manager)

View File

@ -83,7 +83,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model
model = Tacotron2(config, ap, tokenizer, speaker_manager)

View File

@ -84,7 +84,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples)
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers
# init model