From e36a3067e4af08f6990532140e910df911638a88 Mon Sep 17 00:00:00 2001 From: mueller Date: Thu, 17 Sep 2020 14:14:30 +0200 Subject: [PATCH] add: save wavs instead of feats to storage. This is done in order to mitigate staleness when caching and loading from data storage --- TTS/bin/train_encoder.py | 2 +- TTS/speaker_encoder/dataset.py | 20 ++++++++++++-------- TTS/tts/datasets/preprocess.py | 3 +-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index e73e1614..56a2b954 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -111,7 +111,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): "lr": current_lr, "grad_norm": grad_norm, "step_time": step_time, - "loader_time": loader_time + "avg_loader_time": avg_loader_time } tb_logger.tb_train_epoch_stats(global_step, train_stats) figures = { diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 31413e7e..3f3db88d 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -110,7 +110,7 @@ class MyDataset(Dataset): """ Sample all M utterances for the given speaker. 
""" - feats = [] + wavs = [] labels = [] for _ in range(self.num_utter_per_speaker): # TODO:dummy but works @@ -126,11 +126,9 @@ class MyDataset(Dataset): break self.speaker_to_utters[speaker].remove(utter) - offset = random.randint(0, wav.shape[0] - self.seq_len) - mel = self.ap.melspectrogram(wav[offset : offset + self.seq_len]) - feats.append(torch.FloatTensor(mel)) + wavs.append(wav) labels.append(speaker) - return feats, labels + return wavs, labels def __getitem__(self, idx): speaker, _ = self.__sample_speaker() @@ -142,15 +140,21 @@ class MyDataset(Dataset): for speaker in batch: if random.random() < self.sample_from_storage_p and self.storage.full(): # sample from storage (if full), ignoring the speaker - feats_, labels_ = random.choice(self.storage.queue) + wavs_, labels_ = random.choice(self.storage.queue) else: # don't sample from storage, but from HDD - feats_, labels_ = self.__sample_speaker_utterances(speaker) + wavs_, labels_ = self.__sample_speaker_utterances(speaker) # if storage is full, remove an item if self.storage.full(): _ = self.storage.get_nowait() # put the newly loaded item into storage - self.storage.put_nowait((feats_, labels_)) + self.storage.put_nowait((wavs_, labels_)) + + # get a random subset of each of the wavs and convert to MFCC. 
+ offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_] + mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))] + feats_ = [torch.FloatTensor(mel) for mel in mels_] + labels.append(labels_) feats.extend(feats_) feats = torch.stack(feats) diff --git a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py index 4b2903a0..73a56774 100644 --- a/TTS/tts/datasets/preprocess.py +++ b/TTS/tts/datasets/preprocess.py @@ -17,10 +17,9 @@ def load_meta_data(datasets): root_path = dataset['path'] meta_file_train = dataset['meta_file_train'] meta_file_val = dataset['meta_file_val'] - print(f" | > Preprocessing {name}") preprocessor = get_preprocessor_by_name(name) meta_data_train = preprocessor(root_path, meta_file_train) - print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") if meta_file_val is None: meta_data_eval, meta_data_train = split_dataset(meta_data_train) else: