diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py
index d8eb66b4..e72a66e0 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/LJSpeech.py
@@ -3,6 +3,7 @@ import numpy as np
 import collections
 import librosa
 import torch
+import random
 from torch.utils.data import Dataset
 from utils.text import text_to_sequence
 
@@ -17,8 +18,10 @@ class MyDataset(Dataset):
                  outputs_per_step,
                  text_cleaner,
                  ap,
+                 batch_group_size=0,
                  min_seq_len=0):
         self.root_dir = root_dir
+        self.batch_group_size = batch_group_size
         self.wav_dir = os.path.join(root_dir, 'wavs')
         self.csv_dir = os.path.join(root_dir, csv_file)
         with open(self.csv_dir, "r", encoding="utf8") as f:
@@ -30,7 +33,7 @@ class MyDataset(Dataset):
         self.ap = ap
         print(" > Reading LJSpeech from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
+        self.sort_frames()
 
     def load_wav(self, filename):
         try:
@@ -39,8 +42,8 @@ class MyDataset(Dataset):
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))
 
-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
+    def sort_frames(self):
+        r"""Sort sequences by text length in ascending order."""
         lengths = np.array([len(ins[1]) for ins in self.frames])
 
         print(" | > Max length sequence {}".format(np.max(lengths)))
@@ -58,6 +61,15 @@ class MyDataset(Dataset):
             new_frames.append(self.frames[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
+        # shuffle frames within each fixed-size batch group
+        if self.batch_group_size > 0:
+            print(" | > Batch group shuffling is active.")
+            for i in range(len(new_frames) // self.batch_group_size):
+                offset = i * self.batch_group_size
+                end_offset = offset + self.batch_group_size
+                temp_frames = new_frames[offset:end_offset]
+                random.shuffle(temp_frames)
+                new_frames[offset:end_offset] = temp_frames
         self.frames = new_frames
 
     def __len__(self):
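The hunk above sorts instances by transcript length and then shuffles only inside fixed-size groups: batches stay roughly length-homogeneous (less padding waste), yet their composition changes each time sort_frames() runs. A minimal, self-contained sketch of that scheme, outside the dataset class; group_shuffle, items, and group_size are illustrative names, not part of the codebase:

import random

def group_shuffle(items, group_size):
    # sort by length so neighbouring items need similar padding
    items = sorted(items, key=len)
    # shuffle each full group of `group_size` items; a short tail group
    # (len(items) % group_size) stays in sorted order, matching the
    # floor division used in the hunk above
    for i in range(len(items) // group_size):
        offset = i * group_size
        group = items[offset:offset + group_size]
        random.shuffle(group)
        items[offset:offset + group_size] = group
    return items

texts = ["a", "bbbb", "cc", "ddddd", "eee", "ffffff", "gggg", "hh"]
print(group_shuffle(texts, group_size=4))
# short strings stay in the first group, long ones in the second,
# but the order inside each group differs from run to run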
diff --git a/tests/loader_tests.py b/tests/loader_tests.py
index a53cf635..b80f3e74 100644
--- a/tests/loader_tests.py
+++ b/tests/loader_tests.py
@@ -66,6 +66,49 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_input.shape[0] == c.batch_size
             assert mel_input.shape[2] == c.num_mels
 
+    def test_batch_group_shuffle(self):
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                batch_group_size=16,
+                min_seq_len=c.min_seq_len)
+
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
+
+            frames = dataset.frames
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: add more assertions here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
+            dataloader.dataset.sort_frames()
+            assert frames[0] != dataloader.dataset.frames[0]
+
+
     def test_padding(self):
         if ok_ljspeech:
             dataset = LJSpeech.MyDataset(
diff --git a/train.py b/train.py
index e1df221d..d8a75b1d 100644
--- a/train.py
+++ b/train.py
@@ -191,7 +191,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
     epoch_time = 0
-
     return avg_linear_loss, current_step
 
 
@@ -361,6 +360,7 @@ def main(args):
         c.r,
         c.text_cleaner,
         ap=ap,
+        batch_group_size=16 * c.batch_size,
         min_seq_len=c.min_seq_len)
 
     train_loader = DataLoader(
@@ -374,7 +374,7 @@
 
     if c.run_eval:
         val_dataset = Dataset(
-            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
+            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
 
         val_loader = DataLoader(
             val_dataset,
@@ -444,6 +444,8 @@
                   flush=True)
         best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                     OUT_PATH, current_step, epoch)
+        # reshuffle batch groups for the next epoch
+        train_loader.dataset.sort_frames()
 
 
 if __name__ == '__main__':
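For context on the train.py hunks: the training set gets batch_group_size=16 * c.batch_size, so each shuffled group spans 16 consecutive batches, and sort_frames() runs again at the end of every epoch to re-deal the groups; the validation set passes batch_group_size=0, which disables shuffling and keeps evaluation order deterministic. A toy version of that per-epoch loop; every name below is a stand-in, not the real object in train.py:

import random

class ToyDataset:
    """Stand-in for MyDataset: length-sorted items, shuffled per batch group."""

    def __init__(self, items, batch_group_size):
        self.items = list(items)
        self.batch_group_size = batch_group_size
        self.sort_frames()

    def sort_frames(self):
        # same scheme as the diff: global length sort, then per-group shuffle
        self.items = sorted(self.items, key=len)
        for i in range(len(self.items) // self.batch_group_size):
            offset = i * self.batch_group_size
            group = self.items[offset:offset + self.batch_group_size]
            random.shuffle(group)
            self.items[offset:offset + self.batch_group_size] = group

dataset = ToyDataset(["a", "bb", "ccc", "dddd", "eeeee", "ffffff"], 3)
for epoch in range(3):
    print("epoch", epoch, dataset.items)  # one training pass would run here
    dataset.sort_frames()  # reshuffle batch groups before the next epoch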