mirror of https://github.com/coqui-ai/TTS.git
Implement multilingual dataloader support
commit f996afedb0
parent 5f1c18187f
@@ -199,6 +199,7 @@ class BaseDatasetConfig(Coqpit):
     path: str = ""
     meta_file_train: str = ""
     ununsed_speakers: List[str] = None
+    language: str = ""
     meta_file_val: str = ""
     meta_file_attn_mask: str = ""
 
@@ -335,6 +336,8 @@ class BaseTrainingConfig(Coqpit):
     num_loader_workers: int = 0
     num_eval_loader_workers: int = 0
     use_noise_augment: bool = False
+    use_language_weighted_sampler: bool = False
+
     # paths
     output_path: str = None
     # distributed
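Note: a minimal sketch of how the new `language` field would be used when defining datasets; the dataset names, paths, and meta files below are hypothetical placeholders.

    from TTS.config.shared_configs import BaseDatasetConfig

    # Two datasets in different languages; `language` tags every sample
    # loaded from that dataset (placeholder paths).
    dataset_en = BaseDatasetConfig(
        name="ljspeech",
        path="/data/LJSpeech-1.1/",
        meta_file_train="metadata.csv",
        language="en",
    )
    dataset_pt = BaseDatasetConfig(
        name="mailabs",
        path="/data/pt-br/",
        meta_file_train="metadata.csv",
        language="pt-br",
    )
    datasets = [dataset_en, dataset_pt]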
@@ -260,6 +260,20 @@ class Trainer:
         else:
             self.run_get_model(self.config, get_model)
 
+        if hasattr(self.model, "init_multilingual"):
+            self.model.init_multilingual(self.config, self.data_train + self.data_eval)
+            config = self.config.model_args if hasattr(self.config, "model_args") else self.config
+            # save the language ids to a json file
+            if config.use_language_embedding and self.model.language_manager.num_languages > 1:
+                self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
+                if hasattr(self.config, "model_args"):
+                    self.config.model_args["num_languages"] = self.model.language_manager.num_languages
+                else:
+                    self.config.num_languages = self.model.language_manager.num_languages
+
+            # update config file
+            copy_model_files(self.config, self.output_path, None)
+
         # setup criterion
         self.criterion = self.get_criterion(self.model)
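Note: the trainer only duck-types the hook above. A minimal sketch of a model exposing it; the class is hypothetical, and the assumption that `LanguageManager` can be built from a config is labeled in the comments.

    from TTS.tts.utils.languages import LanguageManager

    class MyMultilingualModel:
        """Hypothetical model exposing the hook probed by Trainer."""

        def init_multilingual(self, config, data=None):
            # Build a language manager so language names resolve to integer
            # ids (assumption: LanguageManager accepts a `config` kwarg).
            self.language_manager = LanguageManager(config=config)
            self.num_languages = self.language_manager.num_languages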
@@ -68,16 +68,22 @@ def load_tts_samples(
         meta_file_train = dataset["meta_file_train"]
         meta_file_val = dataset["meta_file_val"]
         ununsed_speakers = dataset["ununsed_speakers"]
+        language = dataset["language"]
+
         # setup the right data processor
         if formatter is None:
             formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ununsed_speakers=ununsed_speakers)
+        # TODO: remove the loops and pass language as a parameter to the formatter for a faster load
+        meta_data_train = [[*item, language] for item in meta_data_train]
+
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ununsed_speakers=ununsed_speakers)
+                meta_data_eval = [[*item, language] for item in meta_data_eval]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
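Note: the effect of the list comprehensions above on a single formatter row (values are illustrative only).

    # A formatter returns rows of [text, wav_file, speaker_name]; the
    # language of the owning dataset is appended as a trailing column.
    row = ["Hello world.", "/data/wavs/utt_001.wav", "speaker_0"]
    row_with_language = [*row, "en"]
    # -> ["Hello world.", "/data/wavs/utt_001.wav", "speaker_0", "en"]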
@@ -37,6 +37,7 @@ class TTSDataset(Dataset):
         enable_eos_bos: bool = False,
         speaker_id_mapping: Dict = None,
         d_vector_mapping: Dict = None,
+        language_id_mapping: Dict = None,
         use_noise_augment: bool = False,
         verbose: bool = False,
     ):
@@ -122,6 +123,7 @@ class TTSDataset(Dataset):
         self.enable_eos_bos = enable_eos_bos
         self.speaker_id_mapping = speaker_id_mapping
         self.d_vector_mapping = d_vector_mapping
+        self.language_id_mapping = language_id_mapping
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
         self.input_seq_computed = False
@@ -197,10 +199,10 @@ class TTSDataset(Dataset):
     def load_data(self, idx):
         item = self.items[idx]
 
-        if len(item) == 4:
-            text, wav_file, speaker_name, attn_file = item
+        if len(item) == 5:
+            text, wav_file, speaker_name, language_name, attn_file = item
         else:
-            text, wav_file, speaker_name = item
+            text, wav_file, speaker_name, language_name = item
 
         attn = None
         raw_text = text
@@ -218,7 +220,7 @@ class TTSDataset(Dataset):
                 self.phoneme_cache_path,
                 self.enable_eos_bos,
                 self.cleaners,
-                self.phoneme_language,
+                language_name if language_name else self.phoneme_language,
                 self.custom_symbols,
                 self.characters,
                 self.add_blank,
@@ -260,6 +262,7 @@ class TTSDataset(Dataset):
            "attn": attn,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
+           "language_name": language_name,
            "wav_file_name": os.path.basename(wav_file),
        }
        return sample
@@ -413,6 +416,14 @@ class TTSDataset(Dataset):
            # convert list of dicts to dict of lists
            batch = {k: [dic[k] for dic in batch] for k in batch[0]}
 
+           speaker_names = [batch["speaker_name"][idx] for idx in ids_sorted_decreasing]
+
+           # get language ids from language names
+           if self.language_id_mapping is not None:
+               language_names = [batch["language_name"][idx] for idx in ids_sorted_decreasing]
+               language_ids = [self.language_id_mapping[ln] for ln in language_names]
+           else:
+               language_ids = None
            # get pre-computed d-vectors
            if self.d_vector_mapping is not None:
                wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
@@ -466,6 +477,9 @@ class TTSDataset(Dataset):
            if speaker_ids is not None:
                speaker_ids = torch.LongTensor(speaker_ids)
 
+           if language_ids is not None:
+               language_ids = torch.LongTensor(language_ids)
+
            # compute linear spectrogram
            if self.compute_linear_spec:
                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
@@ -528,6 +542,7 @@ class TTSDataset(Dataset):
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
+               "language_ids": language_ids,
            }
 
        raise TypeError(
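Note: a standalone sketch of the collate-time lookup added above; language names are mapped to integer ids and tensorized for the model (the mapping values are illustrative).

    import torch

    language_id_mapping = {"en": 0, "fr-fr": 1, "pt-br": 2}
    language_names = ["en", "pt-br", "en"]
    language_ids = [language_id_mapping[ln] for ln in language_names]
    language_ids = torch.LongTensor(language_ids)  # tensor([0, 2, 0])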
@@ -13,6 +13,7 @@ from TTS.model import BaseModel
 from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
+from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
     def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
         return get_speaker_manager(config, restore_path, data, out_path)
 
-    def init_multispeaker(self, config: Coqpit):
-        """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
-        vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
+    def init_multispeaker(self, config: Coqpit, data: List = None):
+        """Initialize a speaker embedding layer if needed and define the expected embedding channel size for
+        setting the `in_channels` size of the connected layers.
+
+        This implementation yields 3 possible outcomes:
+
+        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
+        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
+        3. If `config.use_speaker_embedding` is True, initialize a speaker embedding layer with a channel size of
+           `config.d_vector_dim` or 512.
+
+        You can override this function for new models.
 
         Args:
             config (Coqpit): Model configuration.
@@ -122,6 +132,7 @@ class BaseTTS(BaseModel):
         attn_mask = batch["attns"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        language_ids = batch["language_ids"]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
 
@@ -169,6 +180,7 @@ class BaseTTS(BaseModel):
             "item_idx": item_idx,
             "waveform": waveform,
             "pitch": pitch,
+            "language_ids": language_ids,
         }
 
     def get_data_loader(
@@ -199,7 +211,12 @@ class BaseTTS(BaseModel):
             if hasattr(self, "make_symbols"):
                 custom_symbols = self.make_symbols(self.config)
+            # set the language id mapping if the model has a language manager
+            if hasattr(self, "language_manager"):
+                language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+            else:
+                language_id_mapping = None
 
             # init dataloader
             dataset = TTSDataset(
                 outputs_per_step=config.r if "r" in config else 1,
                 text_cleaner=config.text_cleaner,
@@ -223,6 +240,7 @@ class BaseTTS(BaseModel):
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
                 d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                language_id_mapping=language_id_mapping,
             )
 
             # pre-compute phonemes
@@ -267,8 +285,11 @@ class BaseTTS(BaseModel):
 
             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+            if sampler is None:
+                if getattr(config, "use_language_weighted_sampler", False):
+                    sampler = get_language_weighted_sampler(dataset.items)
+                    print(" > Using Language weighted sampler")
 
             # init dataloader
             loader = DataLoader(
                 dataset,
                 batch_size=config.eval_batch_size if is_eval else config.batch_size,
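Note: a sketch of the inverse-frequency weighting a language-weighted sampler typically applies; the actual get_language_weighted_sampler in TTS.tts.utils.languages may differ in detail.

    from collections import Counter
    from torch.utils.data import WeightedRandomSampler

    def language_weighted_sampler_sketch(items):
        # items are rows shaped like [text, wav_file, speaker_name, language]
        languages = [item[3] for item in items]
        counts = Counter(languages)
        # inverse-frequency weights: samples from rare languages are drawn
        # more often, balancing the languages seen per epoch
        weights = [1.0 / counts[lang] for lang in languages]
        return WeightedRandomSampler(weights, num_samples=len(weights))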
@@ -135,3 +135,11 @@ def phoneme_cleaners(text):
     text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text
+
+def multilingual_cleaners(text):
+    """Pipeline for multilingual text"""
+    text = lowercase(text)
+    text = replace_symbols(text, lang=None)
+    text = remove_aux_symbols(text)
+    text = collapse_whitespace(text)
+    return text
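Usage sketch for the new cleaner: with `lang=None`, symbol replacement stays language-agnostic, so accented characters pass through untouched while case and whitespace are normalized.

    from TTS.tts.utils.text.cleaners import multilingual_cleaners

    print(multilingual_cleaners("Olá,   Mundo!"))  # expected: "olá, mundo!"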