Filter instances by sequence length

This commit is contained in:
Eren Golge 2018-03-09 09:46:47 -08:00
parent f6abc2acc3
commit 5f4393330f
3 changed files with 15 additions and 5 deletions

View File

@ -12,9 +12,9 @@
"text_cleaner": "english_cleaners", "text_cleaner": "english_cleaners",
"epochs": 2000, "epochs": 2000,
"lr": 0.0003, "lr": 0.0006 / 32,
"warmup_steps": 4000, "warmup_steps": 4000,
"batch_size": 8, "batch_size": 1,
"r": 5, "r": 5,
"griffin_lim_iters": 60, "griffin_lim_iters": 60,
@ -25,5 +25,6 @@
"checkpoint": false, "checkpoint": false,
"save_step": 69, "save_step": 69,
"data_path": "/run/shm/erogol/LJSpeech-1.0", "data_path": "/run/shm/erogol/LJSpeech-1.0",
"min_seq_len": 90,
"output_path": "result" "output_path": "result"
} }

View File

@ -14,7 +14,8 @@ class LJSpeechDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate, def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
text_cleaner, num_mels, min_level_db, frame_shift_ms, text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power): frame_length_ms, preemphasis, ref_level_db, num_freq, power,
min_seq_len=0):
with open(csv_file, "r") as f: with open(csv_file, "r") as f:
self.frames = [line.split('|') for line in f] self.frames = [line.split('|') for line in f]
@ -22,6 +23,7 @@ class LJSpeechDataset(Dataset):
self.outputs_per_step = outputs_per_step self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.cleaners = text_cleaner self.cleaners = text_cleaner
self.min_seq_length = min_seq_length
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms, self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power) frame_length_ms, preemphasis, ref_level_db, num_freq, power)
print(" > Reading LJSpeech from - {}".format(root_dir)) print(" > Reading LJSpeech from - {}".format(root_dir))
@ -45,8 +47,14 @@ class LJSpeechDataset(Dataset):
idxs = np.argsort(lengths) idxs = np.argsort(lengths)
new_frames = [None] * len(lengths) new_frames = [None] * len(lengths)
ignored = []
for i, idx in enumerate(idxs): for i, idx in enumerate(idxs):
new_frames[i] = self.frames[idx] length = lengths[idx]
if length < self.min_seq_length:
ignored.append(idx)
else
new_frames[i] = self.frames[idx]
print(" | > {} instances are ignored by min_seq_len ({})".format(len(ignored), self.min_seq_len))
self.frames = new_frames self.frames = new_frames
def __len__(self): def __len__(self):

View File

@ -302,7 +302,8 @@ def main(args):
c.preemphasis, c.preemphasis,
c.ref_level_db, c.ref_level_db,
c.num_freq, c.num_freq,
c.power c.power,
min_seq_len=c.min_seq_len
) )
train_loader = DataLoader(train_dataset, batch_size=c.batch_size, train_loader = DataLoader(train_dataset, batch_size=c.batch_size,