Implement multilingual dataloader support

This commit is contained in:
Edresson 2021-08-13 19:58:56 -03:00 committed by Eren Gölge
parent 5f1c18187f
commit f996afedb0
6 changed files with 76 additions and 9 deletions

View File

@ -199,6 +199,7 @@ class BaseDatasetConfig(Coqpit):
path: str = ""
meta_file_train: str = ""
ununsed_speakers: List[str] = None
language: str = ""
meta_file_val: str = ""
meta_file_attn_mask: str = ""
@ -335,6 +336,8 @@ class BaseTrainingConfig(Coqpit):
num_loader_workers: int = 0
num_eval_loader_workers: int = 0
use_noise_augment: bool = False
use_language_weighted_sampler: bool = False
# paths
output_path: str = None
# distributed

View File

@ -260,6 +260,20 @@ class Trainer:
else:
self.run_get_model(self.config, get_model)
if hasattr(self.model, "init_multilingual"):
self.model.init_multilingual(self.config, self.data_train + self.data_eval)
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
# save speakers json
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
if hasattr(self.config, "model_args"):
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
else:
self.config.num_languages = self.model.language_manager.num_languages
# update config file
copy_model_files(self.config, self.output_path, None)
# setup criterion
self.criterion = self.get_criterion(self.model)

View File

@ -68,16 +68,22 @@ def load_tts_samples(
meta_file_train = dataset["meta_file_train"]
meta_file_val = dataset["meta_file_val"]
ununsed_speakers = dataset["ununsed_speakers"]
language = dataset["language"]
# setup the right data processor
if formatter is None:
formatter = _get_formatter_by_name(name)
# load train set
meta_data_train = formatter(root_path, meta_file_train, ununsed_speakers=ununsed_speakers)
# TODO: remove the loops and pass language as a parameter to preprocessor for faster load
meta_data_train = [[*item, language] for item in meta_data_train]
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
# load evaluation split if set
if eval_split:
if meta_file_val:
meta_data_eval = formatter(root_path, meta_file_val, ununsed_speakers=ununsed_speakers)
meta_data_eval = [[*item, language] for item in meta_data_eval]
else:
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
meta_data_eval_all += meta_data_eval

View File

@ -37,6 +37,7 @@ class TTSDataset(Dataset):
enable_eos_bos: bool = False,
speaker_id_mapping: Dict = None,
d_vector_mapping: Dict = None,
language_id_mapping: Dict = None,
use_noise_augment: bool = False,
verbose: bool = False,
):
@ -122,6 +123,7 @@ class TTSDataset(Dataset):
self.enable_eos_bos = enable_eos_bos
self.speaker_id_mapping = speaker_id_mapping
self.d_vector_mapping = d_vector_mapping
self.language_id_mapping = language_id_mapping
self.use_noise_augment = use_noise_augment
self.verbose = verbose
self.input_seq_computed = False
@ -197,10 +199,10 @@ class TTSDataset(Dataset):
def load_data(self, idx):
item = self.items[idx]
if len(item) == 4:
text, wav_file, speaker_name, attn_file = item
if len(item) == 5:
text, wav_file, speaker_name, language_name, attn_file = item
else:
text, wav_file, speaker_name = item
text, wav_file, speaker_name, language_name = item
attn = None
raw_text = text
@ -218,7 +220,7 @@ class TTSDataset(Dataset):
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
language_name if language_name else self.phoneme_language,
self.custom_symbols,
self.characters,
self.add_blank,
@ -260,6 +262,7 @@ class TTSDataset(Dataset):
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
"language_name": language_name,
"wav_file_name": os.path.basename(wav_file),
}
return sample
@ -413,6 +416,14 @@ class TTSDataset(Dataset):
# convert list of dicts to dict of lists
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
# get language ids from language names
if self.language_id_mapping is not None:
language_names = [batch[idx]["language_name"] for idx in ids_sorted_decreasing]
language_ids = [self.language_id_mapping[ln] for ln in language_names]
else:
language_ids = None
# get pre-computed d-vectors
if self.d_vector_mapping is not None:
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
@ -466,6 +477,9 @@ class TTSDataset(Dataset):
if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
@ -528,6 +542,7 @@ class TTSDataset(Dataset):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"language_ids": language_ids
}
raise TypeError(

View File

@ -13,6 +13,7 @@ from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
def init_multispeaker(self, config: Coqpit):
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
def init_multispeaker(self, config: Coqpit, data: List = None):
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
`in_channels` size of the connected layers.
This implementation yields 3 possible outcomes:
1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing.
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
`config.d_vector_dim` or 512.
You can override this function for new models.
Args:
config (Coqpit): Model configuration.
@ -122,6 +132,7 @@ class BaseTTS(BaseModel):
attn_mask = batch["attns"]
waveform = batch["waveform"]
pitch = batch["pitch"]
language_ids = batch["language_ids"]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@ -169,6 +180,7 @@ class BaseTTS(BaseModel):
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
"language_ids": language_ids,
}
def get_data_loader(
@ -199,7 +211,12 @@ class BaseTTS(BaseModel):
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
# init dataset
if hasattr(self, "language_manager"):
language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
else:
language_id_mapping = None
# init dataloader
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
@ -223,6 +240,7 @@ class BaseTTS(BaseModel):
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
language_id_mapping=language_id_mapping,
)
# pre-compute phonemes
@ -267,8 +285,11 @@ class BaseTTS(BaseModel):
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
sampler = get_language_weighted_sampler(dataset.items)
print(" > Using Language weighted sampler")
# init dataloader
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,

View File

@ -135,3 +135,11 @@ def phoneme_cleaners(text):
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
def multilingual_cleaners(text):
'''Pipeline for multilingual text'''
text = lowercase(text)
text = replace_symbols(text, lang=None)
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text