mirror of https://github.com/coqui-ai/TTS.git
Implement multilingual dataloader support
This commit is contained in:
parent
c9f5838bb4
commit
829ee55b04
|
@ -199,6 +199,7 @@ class BaseDatasetConfig(Coqpit):
|
||||||
path: str = ""
|
path: str = ""
|
||||||
meta_file_train: str = ""
|
meta_file_train: str = ""
|
||||||
ununsed_speakers: List[str] = None
|
ununsed_speakers: List[str] = None
|
||||||
|
language: str = ""
|
||||||
meta_file_val: str = ""
|
meta_file_val: str = ""
|
||||||
meta_file_attn_mask: str = ""
|
meta_file_attn_mask: str = ""
|
||||||
|
|
||||||
|
@ -335,6 +336,8 @@ class BaseTrainingConfig(Coqpit):
|
||||||
num_loader_workers: int = 0
|
num_loader_workers: int = 0
|
||||||
num_eval_loader_workers: int = 0
|
num_eval_loader_workers: int = 0
|
||||||
use_noise_augment: bool = False
|
use_noise_augment: bool = False
|
||||||
|
use_language_weighted_sampler: bool = False
|
||||||
|
|
||||||
# paths
|
# paths
|
||||||
output_path: str = None
|
output_path: str = None
|
||||||
# distributed
|
# distributed
|
||||||
|
|
|
@ -260,6 +260,20 @@ class Trainer:
|
||||||
else:
|
else:
|
||||||
self.run_get_model(self.config, get_model)
|
self.run_get_model(self.config, get_model)
|
||||||
|
|
||||||
|
if hasattr(self.model, "init_multilingual"):
|
||||||
|
self.model.init_multilingual(self.config, self.data_train + self.data_eval)
|
||||||
|
config = self.config.model_args if hasattr(self.config, "model_args") else self.config
|
||||||
|
# save speakers json
|
||||||
|
if config.use_language_embedding and self.model.language_manager.num_languages > 1:
|
||||||
|
self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
|
||||||
|
if hasattr(self.config, "model_args"):
|
||||||
|
self.config.model_args["num_languages"] = self.model.language_manager.num_languages
|
||||||
|
else:
|
||||||
|
self.config.num_languages = self.model.language_manager.num_languages
|
||||||
|
|
||||||
|
# update config file
|
||||||
|
copy_model_files(self.config, self.output_path, None)
|
||||||
|
|
||||||
# setup criterion
|
# setup criterion
|
||||||
self.criterion = self.get_criterion(self.model)
|
self.criterion = self.get_criterion(self.model)
|
||||||
|
|
||||||
|
|
|
@ -68,16 +68,22 @@ def load_tts_samples(
|
||||||
meta_file_train = dataset["meta_file_train"]
|
meta_file_train = dataset["meta_file_train"]
|
||||||
meta_file_val = dataset["meta_file_val"]
|
meta_file_val = dataset["meta_file_val"]
|
||||||
ununsed_speakers = dataset["ununsed_speakers"]
|
ununsed_speakers = dataset["ununsed_speakers"]
|
||||||
|
language = dataset["language"]
|
||||||
|
|
||||||
# setup the right data processor
|
# setup the right data processor
|
||||||
if formatter is None:
|
if formatter is None:
|
||||||
formatter = _get_formatter_by_name(name)
|
formatter = _get_formatter_by_name(name)
|
||||||
# load train set
|
# load train set
|
||||||
meta_data_train = formatter(root_path, meta_file_train, ununsed_speakers=ununsed_speakers)
|
meta_data_train = formatter(root_path, meta_file_train, ununsed_speakers=ununsed_speakers)
|
||||||
|
# TODO: remove the loops and pass language as a parameter to preprocessor for faster load
|
||||||
|
meta_data_train = [[*item, language] for item in meta_data_train]
|
||||||
|
|
||||||
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
||||||
# load evaluation split if set
|
# load evaluation split if set
|
||||||
if eval_split:
|
if eval_split:
|
||||||
if meta_file_val:
|
if meta_file_val:
|
||||||
meta_data_eval = formatter(root_path, meta_file_val, ununsed_speakers=ununsed_speakers)
|
meta_data_eval = formatter(root_path, meta_file_val, ununsed_speakers=ununsed_speakers)
|
||||||
|
meta_data_eval = [[*item, language] for item in meta_data_eval]
|
||||||
else:
|
else:
|
||||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||||
meta_data_eval_all += meta_data_eval
|
meta_data_eval_all += meta_data_eval
|
||||||
|
|
|
@ -37,6 +37,7 @@ class TTSDataset(Dataset):
|
||||||
enable_eos_bos: bool = False,
|
enable_eos_bos: bool = False,
|
||||||
speaker_id_mapping: Dict = None,
|
speaker_id_mapping: Dict = None,
|
||||||
d_vector_mapping: Dict = None,
|
d_vector_mapping: Dict = None,
|
||||||
|
language_id_mapping: Dict = None,
|
||||||
use_noise_augment: bool = False,
|
use_noise_augment: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -122,6 +123,7 @@ class TTSDataset(Dataset):
|
||||||
self.enable_eos_bos = enable_eos_bos
|
self.enable_eos_bos = enable_eos_bos
|
||||||
self.speaker_id_mapping = speaker_id_mapping
|
self.speaker_id_mapping = speaker_id_mapping
|
||||||
self.d_vector_mapping = d_vector_mapping
|
self.d_vector_mapping = d_vector_mapping
|
||||||
|
self.language_id_mapping = language_id_mapping
|
||||||
self.use_noise_augment = use_noise_augment
|
self.use_noise_augment = use_noise_augment
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.input_seq_computed = False
|
self.input_seq_computed = False
|
||||||
|
@ -197,10 +199,10 @@ class TTSDataset(Dataset):
|
||||||
def load_data(self, idx):
|
def load_data(self, idx):
|
||||||
item = self.items[idx]
|
item = self.items[idx]
|
||||||
|
|
||||||
if len(item) == 4:
|
if len(item) == 5:
|
||||||
text, wav_file, speaker_name, attn_file = item
|
text, wav_file, speaker_name, language_name, attn_file = item
|
||||||
else:
|
else:
|
||||||
text, wav_file, speaker_name = item
|
text, wav_file, speaker_name, language_name = item
|
||||||
attn = None
|
attn = None
|
||||||
raw_text = text
|
raw_text = text
|
||||||
|
|
||||||
|
@ -218,7 +220,7 @@ class TTSDataset(Dataset):
|
||||||
self.phoneme_cache_path,
|
self.phoneme_cache_path,
|
||||||
self.enable_eos_bos,
|
self.enable_eos_bos,
|
||||||
self.cleaners,
|
self.cleaners,
|
||||||
self.phoneme_language,
|
language_name if language_name else self.phoneme_language,
|
||||||
self.custom_symbols,
|
self.custom_symbols,
|
||||||
self.characters,
|
self.characters,
|
||||||
self.add_blank,
|
self.add_blank,
|
||||||
|
@ -260,6 +262,7 @@ class TTSDataset(Dataset):
|
||||||
"attn": attn,
|
"attn": attn,
|
||||||
"item_idx": self.items[idx][1],
|
"item_idx": self.items[idx][1],
|
||||||
"speaker_name": speaker_name,
|
"speaker_name": speaker_name,
|
||||||
|
"language_name": language_name,
|
||||||
"wav_file_name": os.path.basename(wav_file),
|
"wav_file_name": os.path.basename(wav_file),
|
||||||
}
|
}
|
||||||
return sample
|
return sample
|
||||||
|
@ -413,6 +416,14 @@ class TTSDataset(Dataset):
|
||||||
# convert list of dicts to dict of lists
|
# convert list of dicts to dict of lists
|
||||||
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
|
||||||
|
|
||||||
|
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
||||||
|
|
||||||
|
# get language ids from language names
|
||||||
|
if self.language_id_mapping is not None:
|
||||||
|
language_names = [batch[idx]["language_name"] for idx in ids_sorted_decreasing]
|
||||||
|
language_ids = [self.language_id_mapping[ln] for ln in language_names]
|
||||||
|
else:
|
||||||
|
language_ids = None
|
||||||
# get pre-computed d-vectors
|
# get pre-computed d-vectors
|
||||||
if self.d_vector_mapping is not None:
|
if self.d_vector_mapping is not None:
|
||||||
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
|
wav_files_names = [batch["wav_file_name"][idx] for idx in ids_sorted_decreasing]
|
||||||
|
@ -466,6 +477,9 @@ class TTSDataset(Dataset):
|
||||||
if speaker_ids is not None:
|
if speaker_ids is not None:
|
||||||
speaker_ids = torch.LongTensor(speaker_ids)
|
speaker_ids = torch.LongTensor(speaker_ids)
|
||||||
|
|
||||||
|
if language_ids is not None:
|
||||||
|
language_ids = torch.LongTensor(language_ids)
|
||||||
|
|
||||||
# compute linear spectrogram
|
# compute linear spectrogram
|
||||||
if self.compute_linear_spec:
|
if self.compute_linear_spec:
|
||||||
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
|
linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
|
||||||
|
@ -528,6 +542,7 @@ class TTSDataset(Dataset):
|
||||||
"waveform": wav_padded,
|
"waveform": wav_padded,
|
||||||
"raw_text": batch["raw_text"],
|
"raw_text": batch["raw_text"],
|
||||||
"pitch": pitch,
|
"pitch": pitch,
|
||||||
|
"language_ids": language_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
|
|
|
@ -13,6 +13,7 @@ from TTS.model import BaseModel
|
||||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||||
from TTS.tts.datasets.dataset import TTSDataset
|
from TTS.tts.datasets.dataset import TTSDataset
|
||||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||||
|
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text import make_symbols
|
from TTS.tts.utils.text import make_symbols
|
||||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
|
@ -73,9 +74,18 @@ class BaseTTS(BaseModel):
|
||||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||||
return get_speaker_manager(config, restore_path, data, out_path)
|
return get_speaker_manager(config, restore_path, data, out_path)
|
||||||
|
|
||||||
def init_multispeaker(self, config: Coqpit):
|
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||||
"""Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding
|
"""Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
|
||||||
vector dimension in the network. If model uses d-vectors, then it only sets the expected dimension.
|
`in_channels` size of the connected layers.
|
||||||
|
|
||||||
|
This implementation yields 3 possible outcomes:
|
||||||
|
|
||||||
|
1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing.
|
||||||
|
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
|
||||||
|
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
|
||||||
|
`config.d_vector_dim` or 512.
|
||||||
|
|
||||||
|
You can override this function for new models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (Coqpit): Model configuration.
|
config (Coqpit): Model configuration.
|
||||||
|
@ -122,6 +132,7 @@ class BaseTTS(BaseModel):
|
||||||
attn_mask = batch["attns"]
|
attn_mask = batch["attns"]
|
||||||
waveform = batch["waveform"]
|
waveform = batch["waveform"]
|
||||||
pitch = batch["pitch"]
|
pitch = batch["pitch"]
|
||||||
|
language_ids = batch["language_ids"]
|
||||||
max_text_length = torch.max(text_lengths.float())
|
max_text_length = torch.max(text_lengths.float())
|
||||||
max_spec_length = torch.max(mel_lengths.float())
|
max_spec_length = torch.max(mel_lengths.float())
|
||||||
|
|
||||||
|
@ -169,6 +180,7 @@ class BaseTTS(BaseModel):
|
||||||
"item_idx": item_idx,
|
"item_idx": item_idx,
|
||||||
"waveform": waveform,
|
"waveform": waveform,
|
||||||
"pitch": pitch,
|
"pitch": pitch,
|
||||||
|
"language_ids": language_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_data_loader(
|
def get_data_loader(
|
||||||
|
@ -199,7 +211,12 @@ class BaseTTS(BaseModel):
|
||||||
if hasattr(self, "make_symbols"):
|
if hasattr(self, "make_symbols"):
|
||||||
custom_symbols = self.make_symbols(self.config)
|
custom_symbols = self.make_symbols(self.config)
|
||||||
|
|
||||||
# init dataset
|
if hasattr(self, "language_manager"):
|
||||||
|
language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
|
||||||
|
else:
|
||||||
|
language_id_mapping = None
|
||||||
|
|
||||||
|
# init dataloader
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
outputs_per_step=config.r if "r" in config else 1,
|
outputs_per_step=config.r if "r" in config else 1,
|
||||||
text_cleaner=config.text_cleaner,
|
text_cleaner=config.text_cleaner,
|
||||||
|
@ -223,6 +240,7 @@ class BaseTTS(BaseModel):
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_id_mapping=speaker_id_mapping,
|
speaker_id_mapping=speaker_id_mapping,
|
||||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||||
|
language_id_mapping=language_id_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
# pre-compute phonemes
|
# pre-compute phonemes
|
||||||
|
@ -267,8 +285,11 @@ class BaseTTS(BaseModel):
|
||||||
|
|
||||||
# sampler for DDP
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
if sampler is None:
|
||||||
|
if getattr(config, "use_language_weighted_sampler", False):
|
||||||
|
sampler = get_language_weighted_sampler(dataset.items)
|
||||||
|
print(" > Using Language weighted sampler")
|
||||||
|
|
||||||
# init dataloader
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||||
|
|
|
@ -135,3 +135,11 @@ def phoneme_cleaners(text):
|
||||||
text = remove_aux_symbols(text)
|
text = remove_aux_symbols(text)
|
||||||
text = collapse_whitespace(text)
|
text = collapse_whitespace(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def multilingual_cleaners(text):
|
||||||
|
'''Pipeline for multilingual text'''
|
||||||
|
text = lowercase(text)
|
||||||
|
text = replace_symbols(text, lang=None)
|
||||||
|
text = remove_aux_symbols(text)
|
||||||
|
text = collapse_whitespace(text)
|
||||||
|
return text
|
Loading…
Reference in New Issue