Add parse_key in set_ids_from_data

This commit is contained in:
Edresson Casanova 2022-03-14 13:53:46 +00:00
parent 88e0cfa5a0
commit 2bc2685ff9
12 changed files with 28 additions and 39 deletions

View File

@ -69,10 +69,10 @@ class LanguageManager(BaseIDManager):
self.ids = self.parse_language_ids_from_config(c) self.ids = self.parse_language_ids_from_config(c)
@staticmethod @staticmethod
def parse_ids_from_data(items: list) -> Any: def parse_ids_from_data(items: List, parse_key: str) -> Any:
raise NotImplementedError raise NotImplementedError
def set_ids_from_data(self, items: List) -> Any: def set_ids_from_data(self, items: List, parse_key: str) -> Any:
raise NotImplementedError raise NotImplementedError
def save_ids_to_file(self, file_path: str) -> None: def save_ids_to_file(self, file_path: str) -> None:

View File

@ -1,6 +1,6 @@
import json import json
import random import random
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Tuple, Union
import fsspec import fsspec
import numpy as np import numpy as np
@ -34,14 +34,13 @@ class BaseIDManager:
with fsspec.open(json_file_path, "w") as f: with fsspec.open(json_file_path, "w") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
def set_ids_from_data(self, items: List, parse_key: str) -> None:
def set_ids_from_data(self, items: List) -> None:
"""Set IDs from data samples. """Set IDs from data samples.
Args: Args:
items (List): Data samples returned by `load_tts_samples()`. items (List): Data samples returned by `load_tts_samples()`.
""" """
self.ids, _ = self.parse_ids_from_data(items) self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
def load_ids_from_file(self, file_path: str) -> None: def load_ids_from_file(self, file_path: str) -> None:
"""Set IDs from a file. """Set IDs from a file.
@ -73,9 +72,18 @@ class BaseIDManager:
return None return None
@staticmethod @staticmethod
def parse_ids_from_data(items: list) -> Any: def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]:
raise NotImplementedError """Parse IDs from data samples returned by `load_tts_samples()`.
Args:
items (list): Data samples returned by `load_tts_samples()`.
parse_key (str): The key used to parse the data.
Returns:
Tuple[Dict]: mapping from each distinct `parse_key` value found in `items` to its integer ID.
"""
classes = sorted({item[parse_key] for item in items})
ids = {name: i for i, name in enumerate(classes)}
return ids
class EmbeddingManager(BaseIDManager): class EmbeddingManager(BaseIDManager):
""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. """ Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
@ -273,7 +281,3 @@ class EmbeddingManager(BaseIDManager):
if self.use_cuda: if self.use_cuda:
feats = feats.cuda() feats = feats.cuda()
return self.encoder.compute_embedding(feats) return self.encoder.compute_embedding(feats)
@staticmethod
def parse_ids_from_data(items: list) -> Any:
raise NotImplementedError

View File

@ -1,6 +1,6 @@
import json import json
import os import os
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Union
import fsspec import fsspec
import numpy as np import numpy as np
@ -68,7 +68,7 @@ class SpeakerManager(EmbeddingManager):
) )
if data_items: if data_items:
self.ids, _ = self.parse_ids_from_data(data_items) self.set_ids_from_data(data_items, parse_key="speaker_name")
@property @property
def num_speakers(self): def num_speakers(self):
@ -78,21 +78,6 @@ class SpeakerManager(EmbeddingManager):
def speaker_names(self): def speaker_names(self):
return list(self.ids.keys()) return list(self.ids.keys())
@staticmethod
def parse_ids_from_data(items: list) -> Tuple[Dict, int]:
"""Parse speaker IDs from data samples returned by `load_tts_samples()`.
Args:
items (list): Data sampled returned by `load_tts_samples()`.
Returns:
Tuple[Dict, int]: speaker IDs and number of speakers.
"""
speakers = sorted({item["speaker_name"] for item in items})
speaker_ids = {name: i for i, name in enumerate(speakers)}
num_speakers = len(speaker_ids)
return speaker_ids, num_speakers
def get_speakers(self) -> List: def get_speakers(self) -> List:
return self.ids return self.ids
@ -180,7 +165,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
if c.use_speaker_embedding: if c.use_speaker_embedding:
if data is not None: if data is not None:
speaker_manager.set_ids_from_data(data) speaker_manager.set_ids_from_data(data, parse_key="speaker_name")
if restore_path: if restore_path:
speakers_file = _set_file_path(restore_path) speakers_file = _set_file_path(restore_path)
# restoring speaker manager from a previous run. # restoring speaker manager from a previous run.

View File

@ -119,7 +119,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader # it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers config.model_args.num_speakers = speaker_manager.num_speakers
language_manager = LanguageManager(config=config) language_manager = LanguageManager(config=config)

View File

@ -81,7 +81,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader # it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers config.model_args.num_speakers = speaker_manager.num_speakers
# init model # init model

View File

@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader # it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers config.model_args.num_speakers = speaker_manager.num_speakers
# init model # init model

View File

@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader # it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.num_speakers = speaker_manager.num_speakers config.num_speakers = speaker_manager.num_speakers
# init model # init model

View File

@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader # it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers config.model_args.num_speakers = speaker_manager.num_speakers
# init model # init model

View File

@ -82,7 +82,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader # it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model # init model
model = Tacotron(config, ap, tokenizer, speaker_manager) model = Tacotron(config, ap, tokenizer, speaker_manager)

View File

@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader # it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model # init model
model = Tacotron2(config, ap, tokenizer, speaker_manager) model = Tacotron2(config, ap, tokenizer, speaker_manager)

View File

@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it mainly handles speaker-id to speaker-name for the model and the data-loader # it mainly handles speaker-id to speaker-name for the model and the data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
# init model # init model
model = Tacotron2(config, ap, tokenizer, speaker_manager) model = Tacotron2(config, ap, tokenizer, speaker_manager)

View File

@ -89,7 +89,7 @@ train_samples, eval_samples = load_tts_samples(
# init speaker manager for multi-speaker training # init speaker manager for multi-speaker training
# it maps speaker-id to speaker-name in the model and data-loader # it maps speaker-id to speaker-name in the model and data-loader
speaker_manager = SpeakerManager() speaker_manager = SpeakerManager()
speaker_manager.set_ids_from_data(train_samples + eval_samples) speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
config.model_args.num_speakers = speaker_manager.num_speakers config.model_args.num_speakers = speaker_manager.num_speakers
# init model # init model