mirror of https://github.com/coqui-ai/TTS.git

Batch group shuffling

commit 30fea0b957 (parent a165cd7bda)
@@ -3,6 +3,7 @@ import numpy as np
 import collections
 import librosa
 import torch
+import random
 from torch.utils.data import Dataset
 
 from utils.text import text_to_sequence
@@ -17,8 +18,10 @@ class MyDataset(Dataset):
                  outputs_per_step,
                  text_cleaner,
                  ap,
+                 batch_group_size=0,
                  min_seq_len=0):
         self.root_dir = root_dir
+        self.batch_group_size = batch_group_size
         self.wav_dir = os.path.join(root_dir, 'wavs')
         self.csv_dir = os.path.join(root_dir, csv_file)
         with open(self.csv_dir, "r", encoding="utf8") as f:
@@ -30,7 +33,7 @@ class MyDataset(Dataset):
         self.ap = ap
         print(" > Reading LJSpeech from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
-        self._sort_frames()
+        self.sort_frames()
 
     def load_wav(self, filename):
         try:
@@ -39,8 +42,8 @@ class MyDataset(Dataset):
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))
 
-    def _sort_frames(self):
-        r"""Sort sequences in ascending order"""
+    def sort_frames(self):
+        r"""Sort text sequences in ascending order"""
        lengths = np.array([len(ins[1]) for ins in self.frames])
 
         print(" | > Max length sequence {}".format(np.max(lengths)))
@@ -58,6 +61,15 @@ class MyDataset(Dataset):
                 new_frames.append(self.frames[idx])
         print(" | > {} instances are ignored by min_seq_len ({})".format(
             len(ignored), self.min_seq_len))
+        # shuffle batch groups
+        if self.batch_group_size > 0:
+            print(" | > Batch group shuffling is active.")
+            for i in range(len(new_frames) // self.batch_group_size):
+                offset = i * self.batch_group_size
+                end_offset = offset + self.batch_group_size
+                temp_frames = new_frames[offset : end_offset]
+                random.shuffle(temp_frames)
+                new_frames[offset : end_offset] = temp_frames
         self.frames = new_frames
 
     def __len__(self):
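The hunk above is the core of the feature: frames stay sorted by text length so each batch needs little padding, and when batch_group_size > 0 every contiguous group of that many frames is shuffled in place, which restores some randomness to batch composition. A minimal standalone sketch of the same idea (illustrative code, not the repository's; the helper name batch_group_shuffle and the toy data are assumptions):

import random

def batch_group_shuffle(items, batch_group_size, key=len):
    # Sort by length so neighboring items batch with minimal padding ...
    items = sorted(items, key=key)
    # ... then shuffle inside each fixed-size group to vary batch makeup.
    if batch_group_size > 0:
        for i in range(len(items) // batch_group_size):
            offset = i * batch_group_size
            end_offset = offset + batch_group_size
            group = items[offset:end_offset]
            random.shuffle(group)
            items[offset:end_offset] = group
    return items

texts = ["a", "hi there", "ok", "a longer sentence here", "mid size"]
print(batch_group_shuffle(texts, batch_group_size=2))

Items past the last full group (len(items) % batch_group_size of them) keep their sorted order, matching the integer-division loop in the diff.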
@@ -66,6 +66,49 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_input.shape[0] == c.batch_size
             assert mel_input.shape[2] == c.num_mels
 
+    def test_batch_group_shuffle(self):
+        if ok_ljspeech:
+            dataset = LJSpeech.MyDataset(
+                os.path.join(c.data_path_LJSpeech),
+                os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+                c.r,
+                c.text_cleaner,
+                ap=self.ap,
+                batch_group_size=16,
+                min_seq_len=c.min_seq_len)
+
+            dataloader = DataLoader(
+                dataset,
+                batch_size=2,
+                shuffle=True,
+                collate_fn=dataset.collate_fn,
+                drop_last=True,
+                num_workers=c.num_loader_workers)
+
+            frames = dataset.frames
+            for i, data in enumerate(dataloader):
+                if i == self.max_loader_iter:
+                    break
+                text_input = data[0]
+                text_lengths = data[1]
+                linear_input = data[2]
+                mel_input = data[3]
+                mel_lengths = data[4]
+                stop_target = data[5]
+                item_idx = data[6]
+
+                neg_values = text_input[text_input < 0]
+                check_count = len(neg_values)
+                assert check_count == 0, \
+                    " !! Negative values in text_input: {}".format(check_count)
+                # TODO: more assertion here
+                assert linear_input.shape[0] == c.batch_size
+                assert mel_input.shape[0] == c.batch_size
+                assert mel_input.shape[2] == c.num_mels
+            dataloader.dataset.sort_frames()
+            assert frames[0] != dataloader.dataset.frames[0]
+
+
     def test_padding(self):
         if ok_ljspeech:
             dataset = LJSpeech.MyDataset(
train.py
@@ -191,7 +191,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
     tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
     tb.add_scalar('Time/EpochTime', epoch_time, epoch)
     epoch_time = 0
-
     return avg_linear_loss, current_step
 
 
@@ -361,6 +360,7 @@ def main(args):
         c.r,
         c.text_cleaner,
         ap=ap,
+        batch_group_size=16*c.batch_size,
         min_seq_len=c.min_seq_len)
 
     train_loader = DataLoader(
@@ -374,7 +374,7 @@ def main(args):
 
     if c.run_eval:
         val_dataset = Dataset(
-            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
+            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
 
         val_loader = DataLoader(
             val_dataset,
@@ -444,6 +444,8 @@ def main(args):
               flush=True)
         best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                     OUT_PATH, current_step, epoch)
+        # shuffle batch groups
+        train_loader.dataset.sort_frames()
 
 
 if __name__ == '__main__':
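Two details in the train.py hunks are worth noting. Training uses batch_group_size=16*c.batch_size, so each shuffle group spans sixteen batches' worth of similar-length samples, while validation passes batch_group_size=0 and keeps plain sorted order, which makes eval batches deterministic. And train_loader.dataset.sort_frames() runs at the end of each epoch, re-sorting and re-shuffling the groups so successive epochs see different batch compositions. The payoff of group shuffling over a full shuffle is less padding per batch; a rough demonstration (toy lengths and the padding_waste helper are illustrative assumptions, reusing batch_group_shuffle from the sketch above):

import random

random.seed(0)
lengths = [random.randint(5, 120) for _ in range(256)]  # toy utterance lengths
batch_size = 8

def padding_waste(ordered, batch_size):
    # Total padded slots when consecutive items form batches padded to the max.
    waste = 0
    for i in range(0, len(ordered) - batch_size + 1, batch_size):
        batch = ordered[i:i + batch_size]
        waste += sum(max(batch) - n for n in batch)
    return waste

plain = random.sample(lengths, k=len(lengths))                  # full shuffle
grouped = batch_group_shuffle(lengths, 16 * batch_size, key=lambda n: n)
print("padding, full shuffle :", padding_waste(plain, batch_size))
print("padding, batch groups :", padding_waste(grouped, batch_size))

On these toy numbers the grouped ordering wastes far fewer padded slots, which is the motivation for sorting first and shuffling only within groups.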