mirror of https://github.com/coqui-ai/TTS.git
=Filter instances by sequence length
This commit is contained in:
parent
f6abc2acc3
commit
5f4393330f
|
@ -12,9 +12,9 @@
|
||||||
"text_cleaner": "english_cleaners",
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
"epochs": 2000,
|
"epochs": 2000,
|
||||||
"lr": 0.0003,
|
"lr": 0.0006 / 32,
|
||||||
"warmup_steps": 4000,
|
"warmup_steps": 4000,
|
||||||
"batch_size": 8,
|
"batch_size": 1,
|
||||||
"r": 5,
|
"r": 5,
|
||||||
|
|
||||||
"griffin_lim_iters": 60,
|
"griffin_lim_iters": 60,
|
||||||
|
@ -25,5 +25,6 @@
|
||||||
"checkpoint": false,
|
"checkpoint": false,
|
||||||
"save_step": 69,
|
"save_step": 69,
|
||||||
"data_path": "/run/shm/erogol/LJSpeech-1.0",
|
"data_path": "/run/shm/erogol/LJSpeech-1.0",
|
||||||
|
"min_seq_len": 90,
|
||||||
"output_path": "result"
|
"output_path": "result"
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,8 @@ class LJSpeechDataset(Dataset):
|
||||||
|
|
||||||
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
||||||
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
||||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
|
||||||
|
min_seq_len=0):
|
||||||
|
|
||||||
with open(csv_file, "r") as f:
|
with open(csv_file, "r") as f:
|
||||||
self.frames = [line.split('|') for line in f]
|
self.frames = [line.split('|') for line in f]
|
||||||
|
@ -22,6 +23,7 @@ class LJSpeechDataset(Dataset):
|
||||||
self.outputs_per_step = outputs_per_step
|
self.outputs_per_step = outputs_per_step
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.cleaners = text_cleaner
|
self.cleaners = text_cleaner
|
||||||
|
self.min_seq_length = min_seq_length
|
||||||
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
|
||||||
print(" > Reading LJSpeech from - {}".format(root_dir))
|
print(" > Reading LJSpeech from - {}".format(root_dir))
|
||||||
|
@ -45,8 +47,14 @@ class LJSpeechDataset(Dataset):
|
||||||
|
|
||||||
idxs = np.argsort(lengths)
|
idxs = np.argsort(lengths)
|
||||||
new_frames = [None] * len(lengths)
|
new_frames = [None] * len(lengths)
|
||||||
|
ignored = []
|
||||||
for i, idx in enumerate(idxs):
|
for i, idx in enumerate(idxs):
|
||||||
|
length = lengths[idx]
|
||||||
|
if length < self.min_seq_length:
|
||||||
|
ignored.append(idx)
|
||||||
|
else
|
||||||
new_frames[i] = self.frames[idx]
|
new_frames[i] = self.frames[idx]
|
||||||
|
print(" | > {} instances are ignored by min_seq_len ({})".format(len(ignored), self.min_seq_len))
|
||||||
self.frames = new_frames
|
self.frames = new_frames
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
|
3
train.py
3
train.py
|
@ -302,7 +302,8 @@ def main(args):
|
||||||
c.preemphasis,
|
c.preemphasis,
|
||||||
c.ref_level_db,
|
c.ref_level_db,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.power
|
c.power,
|
||||||
|
min_seq_len=c.min_seq_len
|
||||||
)
|
)
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
||||||
|
|
Loading…
Reference in New Issue