mirror of https://github.com/coqui-ai/TTS.git
Compute d_vectors and speaker_ids separately in TTSDataset
parent f00ef90ce6
commit a605dd3d08
@@ -39,7 +39,8 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_manager.speaker_ids
+        speaker_id_mapping=speaker_manager.speaker_ids,
+        d_vector_mapping=speaker_manager.d_vectors
         if c.use_speaker_embedding and c.use_external_speaker_embedding_file
         else None,
     )
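The old speaker_mapping argument conflated two different kinds of speaker information; the new arguments carry them separately. Judging only by how the two mappings are indexed later in this diff (speaker names for ids, wav file names plus an "embedding" key for d-vectors), their shapes look roughly like this sketch; the values are illustrative, not taken from the repo:

# Illustrative shapes only, inferred from how the mappings are indexed in this diff.
# speaker_manager.speaker_ids maps a speaker name to an integer id.
speaker_id_mapping = {"speaker_a": 0, "speaker_b": 1}

# speaker_manager.d_vectors maps a wav file name to a dict that stores the
# pre-computed d-vector under the "embedding" key (other keys may exist).
d_vector_mapping = {"a_001.wav": {"embedding": [0.1, -0.3, 0.7]}}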
@@ -84,22 +85,12 @@ def format_data(data):
     mel_input = data[4]
     mel_lengths = data[5]
     item_idx = data[7]
-    attn_mask = data[9]
+    d_vectors = data[8]
+    speaker_ids = data[9]
+    attn_mask = data[10]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
 
-    if c.use_speaker_embedding:
-        if c.use_external_speaker_embedding_file:
-            speaker_embeddings = data[8]
-            speaker_ids = None
-        else:
-            speaker_ids = [speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-            speaker_ids = torch.LongTensor(speaker_ids)
-            speaker_embeddings = None
-    else:
-        speaker_embeddings = None
-        speaker_ids = None
-
     # dispatch data to GPU
     if use_cuda:
         text_input = text_input.cuda(non_blocking=True)
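With the speaker branching removed, format_data now unpacks the batch purely by position. The layout below matches the new return tuple of TTSDataset.collate_fn at the end of this diff; the names for indices 0 through 3 are inferred from that tuple (text, text_lenghts, speaker_names, linear) and are assumptions, not lines from this commit:

text_input = data[0]     # padded token ids
text_lengths = data[1]
speaker_names = data[2]
linear_input = data[3]   # linear spectrograms, when computed
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
item_idx = data[7]
d_vectors = data[8]      # pre-computed d-vectors, or None
speaker_ids = data[9]    # numerical speaker ids, or None
attn_mask = data[10]     # attention masks, or None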
@@ -267,7 +267,8 @@ class TrainerTTS:
         is_eval: bool,
         data_items: List,
         verbose: bool,
-        speaker_mapping: Union[Dict, List],
+        speaker_ids: Union[Dict, List],
+        d_vectors: Union[Dict, List],
     ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
@@ -289,9 +290,10 @@ class TrainerTTS:
                 enable_eos_bos=self.config.enable_eos_bos_chars,
                 use_noise_augment=not is_eval,
                 verbose=verbose,
-                speaker_mapping=speaker_mapping
-                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
-                else None,
+                speaker_id_mapping=speaker_ids
+                if self.config.use_speaker_embedding else None,
+                d_vector_mapping=d_vectors
+                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file else None,
             )
 
             if self.config.use_phonemes and self.config.compute_input_seq_cache:
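Note that the two mappings are gated differently: the id mapping only requires use_speaker_embedding, while the d-vector mapping also requires use_external_speaker_embedding_file. Written out as a standalone helper for clarity (a sketch, not code from the repo):

def select_speaker_mappings(config, speaker_ids, d_vectors):
    # Speaker ids are passed whenever speaker embeddings are enabled.
    id_mapping = speaker_ids if config.use_speaker_embedding else None
    # d-vectors additionally require the external embedding file.
    use_d_vectors = config.use_speaker_embedding and config.use_external_speaker_embedding_file
    return id_mapping, d_vectors if use_d_vectors else None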
@@ -313,14 +315,14 @@ class TrainerTTS:
         return loader
 
     def get_train_dataloader(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)
 
     def get_eval_dataloder(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)
 
     def format_batch(self, batch: List) -> Dict:
         # setup input batch
@@ -332,24 +334,12 @@ class TrainerTTS:
         mel_lengths = batch[5]
         stop_targets = batch[6]
         item_idx = batch[7]
-        speaker_embeddings = batch[8]
-        attn_mask = batch[9]
+        d_vectors = batch[8]
+        speaker_ids = batch[9]
+        attn_mask = batch[10]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
 
-        # convert speaker names to ids
-        if self.config.use_speaker_embedding:
-            if self.config.use_external_speaker_embedding_file:
-                speaker_embeddings = batch[8]
-                speaker_ids = None
-            else:
-                speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-                speaker_ids = torch.LongTensor(speaker_ids)
-                speaker_embeddings = None
-        else:
-            speaker_embeddings = None
-            speaker_ids = None
-
         # compute durations from attention masks
         durations = None
         if attn_mask is not None:
@@ -640,11 +630,11 @@ class TrainerTTS:
 
         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
         )
         self.eval_loader = (
             self.get_eval_dataloder(
-                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
             )
             if self.config.run_eval
             else None
@@ -29,7 +29,8 @@ class TTSDataset(Dataset):
         phoneme_cache_path=None,
         phoneme_language="en-us",
         enable_eos_bos=False,
-        speaker_mapping=None,
+        speaker_id_mapping=None,
+        d_vector_mapping=None,
         use_noise_augment=False,
         verbose=False,
     ):
@@ -51,6 +52,8 @@ class TTSDataset(Dataset):
             phoneme_language (str): one of the languages from
                 https://github.com/bootphon/phonemizer#languages
             enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
+            speaker_id_mapping (dict): a dictionary that maps speaker names to numerical speaker ids.
+            d_vector_mapping (dict): a dictionary that maps each audio file name to a pre-computed d-vector.
             use_noise_augment (bool): enable adding random noise to wav for augmentation.
             verbose (bool): print diagnostic information.
         """
@@ -70,7 +73,8 @@ class TTSDataset(Dataset):
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
         self.enable_eos_bos = enable_eos_bos
-        self.speaker_mapping = speaker_mapping
+        self.speaker_id_mapping = speaker_id_mapping
+        self.d_vector_mapping = d_vector_mapping
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
         self.input_seq_computed = False
@@ -293,13 +297,18 @@ class TTSDataset(Dataset):
             item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
             text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
 
-            speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
-            # get speaker embeddings
-            if self.speaker_mapping is not None:
+            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
+            # get pre-computed d-vectors
+            if self.d_vector_mapping is not None:
                 wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
-                speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
+                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
-                speaker_embedding = None
+                d_vectors = None
+            # get numerical speaker ids from speaker names
+            if self.speaker_id_mapping:
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+            else:
+                speaker_ids = None
             # compute features
             mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
 
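A self-contained toy run of the two new lookups, with hypothetical mapping contents shaped the way collate_fn indexes them above:

# Hypothetical toy mappings; the keys mirror how collate_fn indexes them.
speaker_id_mapping = {"speaker_a": 0, "speaker_b": 1}
d_vector_mapping = {
    "a_001.wav": {"embedding": [0.1, 0.2]},
    "b_001.wav": {"embedding": [0.3, 0.4]},
}

speaker_names = ["speaker_a", "speaker_b"]
wav_files_names = ["a_001.wav", "b_001.wav"]

d_vectors = [d_vector_mapping[w]["embedding"] for w in wav_files_names]
speaker_ids = [speaker_id_mapping[sn] for sn in speaker_names]
print(d_vectors)    # [[0.1, 0.2], [0.3, 0.4]]
print(speaker_ids)  # [0, 1]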
@@ -327,8 +336,11 @@ class TTSDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            if speaker_embedding is not None:
-                speaker_embedding = torch.FloatTensor(speaker_embedding)
+            if d_vectors is not None:
+                d_vectors = torch.FloatTensor(d_vectors)
+
+            if speaker_ids is not None:
+                speaker_ids = torch.LongTensor(speaker_ids)
 
             # compute linear spectrogram
             if self.compute_linear_spec:
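The dtype split is deliberate: d-vectors are real-valued features used as-is, while speaker ids are indices, and PyTorch embedding lookups require int64 ids. A minimal standalone illustration (the embedding layer here is an assumption for demonstration, not the model code from the repo):

import torch
from torch import nn

d_vectors = torch.FloatTensor([[0.1, -0.3], [0.4, 0.2]])  # (batch, dvec_dim)
speaker_ids = torch.LongTensor([0, 1])                    # (batch,)

# A learned speaker embedding table is indexed with LongTensor ids.
embedding_layer = nn.Embedding(num_embeddings=2, embedding_dim=2)
learned = embedding_layer(speaker_ids)
print(learned.shape, d_vectors.shape)  # torch.Size([2, 2]) torch.Size([2, 2])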
@@ -355,13 +367,14 @@ class TTSDataset(Dataset):
             return (
                 text,
                 text_lenghts,
-                speaker_name,
+                speaker_names,
                 linear,
                 mel,
                 mel_lengths,
                 stop_targets,
                 item_idxs,
-                speaker_embedding,
+                d_vectors,
+                speaker_ids,
                 attns,
             )