mirror of https://github.com/coqui-ai/TTS.git
Add parse_key in set_ids_from_data
This commit is contained in:
parent
88e0cfa5a0
commit
2bc2685ff9
|
@ -69,10 +69,10 @@ class LanguageManager(BaseIDManager):
|
||||||
self.ids = self.parse_language_ids_from_config(c)
|
self.ids = self.parse_language_ids_from_config(c)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_ids_from_data(items: list) -> Any:
|
def parse_ids_from_data(items: List, parse_key: str) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def set_ids_from_data(self, items: List) -> Any:
|
def set_ids_from_data(self, items: List, parse_key: str) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save_ids_to_file(self, file_path: str) -> None:
|
def save_ids_to_file(self, file_path: str) -> None:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -34,14 +34,13 @@ class BaseIDManager:
|
||||||
with fsspec.open(json_file_path, "w") as f:
|
with fsspec.open(json_file_path, "w") as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
|
||||||
|
def set_ids_from_data(self, items: List, parse_key: str) -> None:
|
||||||
def set_ids_from_data(self, items: List) -> None:
|
|
||||||
"""Set IDs from data samples.
|
"""Set IDs from data samples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
items (List): Data sampled returned by `load_tts_samples()`.
|
items (List): Data sampled returned by `load_tts_samples()`.
|
||||||
"""
|
"""
|
||||||
self.ids, _ = self.parse_ids_from_data(items)
|
self.ids = self.parse_ids_from_data(items, parse_key=parse_key)
|
||||||
|
|
||||||
def load_ids_from_file(self, file_path: str) -> None:
|
def load_ids_from_file(self, file_path: str) -> None:
|
||||||
"""Set IDs from a file.
|
"""Set IDs from a file.
|
||||||
|
@ -73,9 +72,18 @@ class BaseIDManager:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_ids_from_data(items: list) -> Any:
|
def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]:
|
||||||
raise NotImplementedError
|
"""Parse IDs from data samples returned by `load_tts_samples()`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items (list): Data sampled returned by `load_tts_samples()`.
|
||||||
|
parse_key (str): The key to be used to parse the data.
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict]: speaker IDs.
|
||||||
|
"""
|
||||||
|
classes = sorted({item[parse_key] for item in items})
|
||||||
|
ids = {name: i for i, name in enumerate(classes)}
|
||||||
|
return ids
|
||||||
|
|
||||||
class EmbeddingManager(BaseIDManager):
|
class EmbeddingManager(BaseIDManager):
|
||||||
""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
|
""" Base `Embedding` Manager class. Every new `Embedding` manager must inherit this.
|
||||||
|
@ -273,7 +281,3 @@ class EmbeddingManager(BaseIDManager):
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
feats = feats.cuda()
|
feats = feats.cuda()
|
||||||
return self.encoder.compute_embedding(feats)
|
return self.encoder.compute_embedding(feats)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_ids_from_data(items: list) -> Any:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -68,7 +68,7 @@ class SpeakerManager(EmbeddingManager):
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_items:
|
if data_items:
|
||||||
self.ids, _ = self.parse_ids_from_data(data_items)
|
self.set_ids_from_data(data_items, parse_key="speaker_name")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_speakers(self):
|
def num_speakers(self):
|
||||||
|
@ -78,21 +78,6 @@ class SpeakerManager(EmbeddingManager):
|
||||||
def speaker_names(self):
|
def speaker_names(self):
|
||||||
return list(self.ids.keys())
|
return list(self.ids.keys())
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_ids_from_data(items: list) -> Tuple[Dict, int]:
|
|
||||||
"""Parse speaker IDs from data samples returned by `load_tts_samples()`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items (list): Data sampled returned by `load_tts_samples()`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[Dict, int]: speaker IDs and number of speakers.
|
|
||||||
"""
|
|
||||||
speakers = sorted({item["speaker_name"] for item in items})
|
|
||||||
speaker_ids = {name: i for i, name in enumerate(speakers)}
|
|
||||||
num_speakers = len(speaker_ids)
|
|
||||||
return speaker_ids, num_speakers
|
|
||||||
|
|
||||||
def get_speakers(self) -> List:
|
def get_speakers(self) -> List:
|
||||||
return self.ids
|
return self.ids
|
||||||
|
|
||||||
|
@ -180,7 +165,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
if c.use_speaker_embedding:
|
if c.use_speaker_embedding:
|
||||||
if data is not None:
|
if data is not None:
|
||||||
speaker_manager.set_ids_from_data(data)
|
speaker_manager.set_ids_from_data(data, parse_key="speaker_name")
|
||||||
if restore_path:
|
if restore_path:
|
||||||
speakers_file = _set_file_path(restore_path)
|
speakers_file = _set_file_path(restore_path)
|
||||||
# restoring speaker manager from a previous run.
|
# restoring speaker manager from a previous run.
|
||||||
|
|
|
@ -119,7 +119,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it maps speaker-id to speaker-name in the model and data-loader
|
# it maps speaker-id to speaker-name in the model and data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
|
|
||||||
language_manager = LanguageManager(config=config)
|
language_manager = LanguageManager(config=config)
|
||||||
|
|
|
@ -81,7 +81,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it maps speaker-id to speaker-name in the model and data-loader
|
# it maps speaker-id to speaker-name in the model and data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
|
|
|
@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it maps speaker-id to speaker-name in the model and data-loader
|
# it maps speaker-id to speaker-name in the model and data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
|
|
|
@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it maps speaker-id to speaker-name in the model and data-loader
|
# it maps speaker-id to speaker-name in the model and data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
config.num_speakers = speaker_manager.num_speakers
|
config.num_speakers = speaker_manager.num_speakers
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
|
|
|
@ -79,7 +79,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it maps speaker-id to speaker-name in the model and data-loader
|
# it maps speaker-id to speaker-name in the model and data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
|
|
|
@ -82,7 +82,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
model = Tacotron(config, ap, tokenizer, speaker_manager)
|
model = Tacotron(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
|
@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
model = Tacotron2(config, ap, tokenizer, speaker_manager)
|
model = Tacotron2(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
|
@ -88,7 +88,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
# it mainly handles speaker-id to speaker-name for the model and the data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
model = Tacotron2(config, ap, tokenizer, speaker_manager)
|
model = Tacotron2(config, ap, tokenizer, speaker_manager)
|
||||||
|
|
|
@ -89,7 +89,7 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
# init speaker manager for multi-speaker training
|
# init speaker manager for multi-speaker training
|
||||||
# it maps speaker-id to speaker-name in the model and data-loader
|
# it maps speaker-id to speaker-name in the model and data-loader
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
speaker_manager.set_ids_from_data(train_samples + eval_samples)
|
speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
|
||||||
config.model_args.num_speakers = speaker_manager.num_speakers
|
config.model_args.num_speakers = speaker_manager.num_speakers
|
||||||
|
|
||||||
# init model
|
# init model
|
||||||
|
|
Loading…
Reference in New Issue