Compute d_vectors and speaker_ids separately in TTSDataset

Eren Gölge 2021-06-03 11:46:53 +02:00
parent db6a97d1a2
commit 802d461389
3 changed files with 44 additions and 50 deletions

View File

@@ -39,7 +39,8 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_manager.speaker_ids
+        speaker_id_mapping=speaker_manager.speaker_ids,
+        d_vector_mapping=speaker_manager.d_vectors
         if c.use_speaker_embedding and c.use_external_speaker_embedding_file
         else None,
     )
@@ -84,22 +85,12 @@ def format_data(data):
     mel_input = data[4]
     mel_lengths = data[5]
     item_idx = data[7]
-    attn_mask = data[9]
+    d_vectors = data[8]
+    speaker_ids = data[9]
+    attn_mask = data[10]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
-
-    if c.use_speaker_embedding:
-        if c.use_external_speaker_embedding_file:
-            speaker_embeddings = data[8]
-            speaker_ids = None
-        else:
-            speaker_ids = [speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-            speaker_ids = torch.LongTensor(speaker_ids)
-            speaker_embeddings = None
-    else:
-        speaker_embeddings = None
-        speaker_ids = None

     # dispatch data to GPU
     if use_cuda:
         text_input = text_input.cuda(non_blocking=True)
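For orientation, this is the full collate output that format_data now indexes into. The index-to-name pairing follows the return tuple at the end of the TTSDataset diff below; only the unpacking itself is illustrative.

    # Collate output consumed by format_data after this change (indices 0-10
    # follow TTSDataset.collate_fn's return tuple, shown in the last diff below).
    (
        text_input,     # 0: padded token id matrix
        text_lengths,   # 1
        speaker_names,  # 2
        linear_input,   # 3
        mel_input,      # 4
        mel_lengths,    # 5
        stop_targets,   # 6
        item_idx,       # 7
        d_vectors,      # 8: pre-computed speaker embeddings, or None
        speaker_ids,    # 9: numerical speaker ids, or None
        attn_mask,      # 10
    ) = data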

View File

@@ -267,7 +267,8 @@ class TrainerTTS:
         is_eval: bool,
         data_items: List,
         verbose: bool,
-        speaker_mapping: Union[Dict, List],
+        speaker_ids: Union[Dict, List],
+        d_vectors: Union[Dict, List],
     ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
@@ -289,9 +290,10 @@ class TrainerTTS:
                 enable_eos_bos=self.config.enable_eos_bos_chars,
                 use_noise_augment=not is_eval,
                 verbose=verbose,
-                speaker_mapping=speaker_mapping
-                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
-                else None,
+                speaker_id_mapping=speaker_ids
+                if self.config.use_speaker_embedding else None,
+                d_vector_mapping=d_vectors
+                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file else None,
             )

             if self.config.use_phonemes and self.config.compute_input_seq_cache:
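Written out as flat assignments, the gating above passes speaker ids whenever speaker embeddings are enabled, while d-vectors additionally require the external embedding file; a minimal restatement with the same names:

    # Equivalent gating to the speaker_id_mapping / d_vector_mapping kwargs above
    use_ids = self.config.use_speaker_embedding
    use_d_vectors = use_ids and self.config.use_external_speaker_embedding_file

    speaker_id_mapping = speaker_ids if use_ids else None
    d_vector_mapping = d_vectors if use_d_vectors else None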
@@ -313,14 +315,14 @@ class TrainerTTS:
         return loader

     def get_train_dataloader(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)

     def get_eval_dataloder(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)

     def format_batch(self, batch: List) -> Dict:
         # setup input batch
@@ -332,24 +334,12 @@ class TrainerTTS:
         mel_lengths = batch[5]
         stop_targets = batch[6]
         item_idx = batch[7]
-        speaker_embeddings = batch[8]
-        attn_mask = batch[9]
+        d_vectors = batch[8]
+        speaker_ids = batch[9]
+        attn_mask = batch[10]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
-
-        # convert speaker names to ids
-        if self.config.use_speaker_embedding:
-            if self.config.use_external_speaker_embedding_file:
-                speaker_embeddings = batch[8]
-                speaker_ids = None
-            else:
-                speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-                speaker_ids = torch.LongTensor(speaker_ids)
-                speaker_embeddings = None
-        else:
-            speaker_embeddings = None
-            speaker_ids = None

         # compute durations from attention masks
         durations = None
         if attn_mask is not None:
@@ -640,11 +630,11 @@ class TrainerTTS:

         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
         )
         self.eval_loader = (
             self.get_eval_dataloder(
-                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
             )
             if self.config.run_eval
             else None

View File

@@ -29,7 +29,8 @@ class TTSDataset(Dataset):
         phoneme_cache_path=None,
         phoneme_language="en-us",
         enable_eos_bos=False,
-        speaker_mapping=None,
+        speaker_id_mapping=None,
+        d_vector_mapping=None,
         use_noise_augment=False,
         verbose=False,
     ):
@@ -51,6 +52,8 @@ class TTSDataset(Dataset):
             phoneme_language (str): one of the languages from
                 https://github.com/bootphon/phonemizer#languages
             enable_eos_bos (bool): enable end-of-sentence and beginning-of-sentence characters.
+            speaker_id_mapping (dict): mapping that converts speaker names to numerical speaker ids.
+            d_vector_mapping (dict): dictionary of d-vectors that maps each audio file to a pre-computed d-vector.
             use_noise_augment (bool): enable adding random noise to wav for augmentation.
             verbose (bool): print diagnostic information.
         """
@@ -70,7 +73,8 @@ class TTSDataset(Dataset):
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
         self.enable_eos_bos = enable_eos_bos
-        self.speaker_mapping = speaker_mapping
+        self.speaker_id_mapping = speaker_id_mapping
+        self.d_vector_mapping = d_vector_mapping
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
         self.input_seq_computed = False
@@ -293,13 +297,18 @@ class TTSDataset(Dataset):
             item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
             text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
-            speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
-            # get speaker embeddings
-            if self.speaker_mapping is not None:
+            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
+            # get pre-computed d-vectors
+            if self.d_vector_mapping is not None:
                 wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
-                speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
+                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
-                speaker_embedding = None
+                d_vectors = None
+            # get numerical speaker ids from speaker names
+            if self.speaker_id_mapping:
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+            else:
+                speaker_ids = None

             # compute features
             mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]

@@ -327,8 +336,11 @@ class TTSDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            if speaker_embedding is not None:
-                speaker_embedding = torch.FloatTensor(speaker_embedding)
+            if d_vectors is not None:
+                d_vectors = torch.FloatTensor(d_vectors)
+
+            if speaker_ids is not None:
+                speaker_ids = torch.LongTensor(speaker_ids)

             # compute linear spectrogram
             if self.compute_linear_spec:
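Continuing the toy values from the mapping sketch above, the conversion step turns the collated Python lists into batched tensors; the values and shapes are illustrative:

    import torch

    d_vectors = [[0.12, -0.03, 0.88], [0.12, -0.03, 0.88]]  # collated list of d-vectors
    speaker_ids = [0, 1]                                     # collated list of int ids

    d_vectors = torch.FloatTensor(d_vectors)     # shape (batch, d_vector_dim) = (2, 3)
    speaker_ids = torch.LongTensor(speaker_ids)  # shape (batch,) = (2,)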
@@ -355,13 +367,14 @@ class TTSDataset(Dataset):

             return (
                 text,
                 text_lenghts,
-                speaker_name,
+                speaker_names,
                 linear,
                 mel,
                 mel_lengths,
                 stop_targets,
                 item_idxs,
-                speaker_embedding,
+                d_vectors,
+                speaker_ids,
                 attns,
             )
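Finally, a sketch of how a dataset built with the new arguments is wired into a DataLoader, mirroring _get_loader in the trainer diff above; batch_size, num_workers, and common_kwargs (standing in for the unchanged constructor arguments) are placeholders:

    from torch.utils.data import DataLoader

    dataset = TTSDataset(
        speaker_id_mapping=speaker_manager.speaker_ids,
        d_vector_mapping=speaker_manager.d_vectors,
        **common_kwargs,  # placeholder for the unchanged TTSDataset arguments
    )
    loader = DataLoader(
        dataset,
        batch_size=32,                  # placeholder value
        shuffle=False,
        collate_fn=dataset.collate_fn,  # assumes the collate method shown above
        num_workers=4,                  # placeholder value
    )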