Compute d_vectors and speaker_ids separately in TTSDataset

Eren Gölge 2021-06-03 11:46:53 +02:00
parent db6a97d1a2
commit 802d461389
3 changed files with 44 additions and 50 deletions
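The commit splits the loader's single speaker_mapping input into two mappings that are computed separately. A minimal sketch of their shapes, inferred from the lookups in the diffs below; the speaker names, file names, and vector values are illustrative only:

# speaker_id_mapping: speaker name -> numerical id, fed to a learned embedding layer.
speaker_id_mapping = {"ljspeech": 0, "vctk_p225": 1}

# d_vector_mapping: audio file name -> pre-computed d-vector entry; the
# "embedding" key is the one TTSDataset.collate_fn indexes below.
d_vector_mapping = {
    "p225_001.wav": {"name": "vctk_p225", "embedding": [0.1, 0.2, 0.3]},
}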

View File

@@ -39,7 +39,8 @@ def setup_loader(ap, r, verbose=False):
         enable_eos_bos=c.enable_eos_bos_chars,
         use_noise_augment=False,
         verbose=verbose,
-        speaker_mapping=speaker_manager.speaker_ids
+        speaker_id_mapping=speaker_manager.speaker_ids,
+        d_vector_mapping=speaker_manager.d_vectors
         if c.use_speaker_embedding and c.use_external_speaker_embedding_file
         else None,
     )
@@ -84,22 +85,12 @@ def format_data(data):
     mel_input = data[4]
     mel_lengths = data[5]
     item_idx = data[7]
-    attn_mask = data[9]
+    d_vectors = data[8]
+    speaker_ids = data[9]
+    attn_mask = data[10]
     avg_text_length = torch.mean(text_lengths.float())
     avg_spec_length = torch.mean(mel_lengths.float())
 
-    if c.use_speaker_embedding:
-        if c.use_external_speaker_embedding_file:
-            speaker_embeddings = data[8]
-            speaker_ids = None
-        else:
-            speaker_ids = [speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-            speaker_ids = torch.LongTensor(speaker_ids)
-            speaker_embeddings = None
-    else:
-        speaker_embeddings = None
-        speaker_ids = None
-
     # dispatch data to GPU
     if use_cuda:
         text_input = text_input.cuda(non_blocking=True)
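With d_vectors and speaker_ids taking batch slots 8 and 9, attn_mask moves from slot 9 to 10. A sketch of the full ordering, assuming positions 0-7 follow the collate return value shown at the end of this commit:

def unpack_batch(data):
    # Positions 0-7 are unchanged; 8-10 reflect this commit.
    text_input, text_lengths, speaker_names = data[0], data[1], data[2]
    linear_input, mel_input, mel_lengths = data[3], data[4], data[5]
    stop_targets, item_idx = data[6], data[7]
    d_vectors, speaker_ids, attn_mask = data[8], data[9], data[10]
    return d_vectors, speaker_ids, attn_mask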

View File

@@ -267,7 +267,8 @@ class TrainerTTS:
         is_eval: bool,
         data_items: List,
         verbose: bool,
-        speaker_mapping: Union[Dict, List],
+        speaker_ids: Union[Dict, List],
+        d_vectors: Union[Dict, List]
     ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
@@ -289,9 +290,10 @@ class TrainerTTS:
                 enable_eos_bos=self.config.enable_eos_bos_chars,
                 use_noise_augment=not is_eval,
                 verbose=verbose,
-                speaker_mapping=speaker_mapping
-                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
-                else None,
+                speaker_id_mapping=speaker_ids
+                if self.config.use_speaker_embedding else None,
+                d_vector_mapping=d_vectors
+                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file else None,
             )
 
             if self.config.use_phonemes and self.config.compute_input_seq_cache:
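Note the asymmetric gating above: speaker ids only require use_speaker_embedding, while d-vectors additionally require use_external_speaker_embedding_file. A standalone sketch of the same logic, with the config flags passed as plain booleans:

def select_mappings(use_speaker_embedding, use_external_speaker_embedding_file,
                    speaker_ids, d_vectors):
    # Ids are enough for a learned embedding table; d-vectors come from an
    # external file, hence the extra flag.
    speaker_id_mapping = speaker_ids if use_speaker_embedding else None
    d_vector_mapping = (
        d_vectors
        if use_speaker_embedding and use_external_speaker_embedding_file
        else None
    )
    return speaker_id_mapping, d_vector_mapping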
@@ -313,14 +315,14 @@ class TrainerTTS:
         return loader
 
     def get_train_dataloader(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)
 
     def get_eval_dataloder(
-        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
     ) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)
 
     def format_batch(self, batch: List) -> Dict:
         # setup input batch
@@ -332,24 +334,12 @@ class TrainerTTS:
         mel_lengths = batch[5]
         stop_targets = batch[6]
         item_idx = batch[7]
-        speaker_embeddings = batch[8]
-        attn_mask = batch[9]
+        d_vectors = batch[8]
+        speaker_ids = batch[9]
+        attn_mask = batch[10]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
 
-        # convert speaker names to ids
-        if self.config.use_speaker_embedding:
-            if self.config.use_external_speaker_embedding_file:
-                speaker_embeddings = batch[8]
-                speaker_ids = None
-            else:
-                speaker_ids = [self.speaker_manager.speaker_ids[speaker_name] for speaker_name in speaker_names]
-                speaker_ids = torch.LongTensor(speaker_ids)
-                speaker_embeddings = None
-        else:
-            speaker_embeddings = None
-            speaker_ids = None
-
         # compute durations from attention masks
         durations = None
         if attn_mask is not None:
@@ -640,11 +630,11 @@ class TrainerTTS:
         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
         )
         self.eval_loader = (
             self.get_eval_dataloder(
-                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_ids=self.speaker_manager.speaker_ids, d_vectors=self.speaker_manager.d_vectors
             )
             if self.config.run_eval
             else None
         )

View File

@@ -29,7 +29,8 @@ class TTSDataset(Dataset):
         phoneme_cache_path=None,
         phoneme_language="en-us",
         enable_eos_bos=False,
-        speaker_mapping=None,
+        speaker_id_mapping=None,
+        d_vector_mapping=None,
         use_noise_augment=False,
         verbose=False,
     ):
@@ -51,6 +52,8 @@ class TTSDataset(Dataset):
             phoneme_language (str): one the languages from
                 https://github.com/bootphon/phonemizer#languages
             enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
+            speaker_id_mapping (dict): dictionary that maps speaker names to numerical ids.
+            d_vector_mapping (dict): dictionary that maps each audio file to a pre-computed d-vector.
             use_noise_augment (bool): enable adding random noise to wav for augmentation.
             verbose (bool): print diagnostic information.
         """
@@ -70,7 +73,8 @@ class TTSDataset(Dataset):
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
         self.enable_eos_bos = enable_eos_bos
-        self.speaker_mapping = speaker_mapping
+        self.speaker_id_mapping = speaker_id_mapping
+        self.d_vector_mapping = d_vector_mapping
         self.use_noise_augment = use_noise_augment
         self.verbose = verbose
         self.input_seq_computed = False
@@ -293,13 +297,18 @@ class TTSDataset(Dataset):
             item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
             text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
-            speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
+            speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
 
-            # get speaker embeddings
-            if self.speaker_mapping is not None:
+            # get pre-computed d-vectors
+            if self.d_vector_mapping is not None:
                 wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
-                speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
+                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
             else:
-                speaker_embedding = None
+                d_vectors = None
+            # get numerical speaker ids from speaker names
+            if self.speaker_id_mapping:
+                speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
+            else:
+                speaker_ids = None
 
             # compute features
             mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
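A self-contained demo of the two lookups above on toy batch items; the dict keys ("speaker_name", "wav_file_name", "embedding") are the ones the collate code indexes, while the names and vector values are made up:

batch = [
    {"speaker_name": "vctk_p225", "wav_file_name": "p225_001.wav"},
    {"speaker_name": "ljspeech", "wav_file_name": "LJ001-0001.wav"},
]
d_vector_mapping = {
    "p225_001.wav": {"embedding": [0.1, 0.2]},
    "LJ001-0001.wav": {"embedding": [0.3, 0.4]},
}
speaker_id_mapping = {"vctk_p225": 0, "ljspeech": 1}

speaker_names = [item["speaker_name"] for item in batch]
wav_files_names = [item["wav_file_name"] for item in batch]
d_vectors = [d_vector_mapping[w]["embedding"] for w in wav_files_names]
speaker_ids = [speaker_id_mapping[sn] for sn in speaker_names]
assert speaker_ids == [0, 1] and d_vectors[0] == [0.1, 0.2]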
@@ -327,8 +336,11 @@ class TTSDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)
 
-            if speaker_embedding is not None:
-                speaker_embedding = torch.FloatTensor(speaker_embedding)
+            if d_vectors is not None:
+                d_vectors = torch.FloatTensor(d_vectors)
+
+            if speaker_ids is not None:
+                speaker_ids = torch.LongTensor(speaker_ids)
 
             # compute linear spectrogram
             if self.compute_linear_spec:
@@ -355,13 +367,14 @@ class TTSDataset(Dataset):
             return (
                 text,
                 text_lenghts,
-                speaker_name,
+                speaker_names,
                 linear,
                 mel,
                 mel_lengths,
                 stop_targets,
                 item_idxs,
-                speaker_embedding,
+                d_vectors,
+                speaker_ids,
                 attns,
             )