A big revision: visualization, data loader, tests

Eren Golge 2018-02-04 08:25:00 -08:00
parent 86f3a56f59
commit 9f5e102473
14 changed files with 199 additions and 64 deletions

Binary file not shown.

datasets/LJSpeech.py
View File

@@ -56,7 +56,7 @@ class LJSpeechDataset(Dataset):
             keys = list()
             text = [d['text'] for d in batch]
-            text_lenghts = [len(x) for x in text]
+            text_lenghts = np.array([len(x) for x in text])
             max_text_len = np.max(text_lenghts)
             wav = [d['wav'] for d in batch]
@@ -77,6 +77,10 @@ class LJSpeechDataset(Dataset):
             magnitude = magnitude.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
+            text_lenghts = torch.LongTensor(text_lenghts)
+            text = torch.LongTensor(text)
+            magnitude = torch.FloatTensor(magnitude)
+            mel = torch.FloatTensor(mel)
             return text, text_lenghts, magnitude, mel
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
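Taken together, the two hunks above make collate_fn hand back ready-to-use torch tensors instead of numpy arrays. A minimal sketch of the overall pattern, assuming the zero-padding helper and field names shown here (they are illustrative, not the repository's exact code):

import numpy as np
import torch

def collate_fn(batch):
    # batch is a list of dicts produced by LJSpeechDataset.__getitem__
    if isinstance(batch[0], dict):
        text = [d['text'] for d in batch]
        text_lenghts = np.array([len(x) for x in text])
        max_text_len = np.max(text_lenghts)
        # zero-pad every text sequence to the longest one in the batch
        text = np.stack([np.pad(x, (0, max_text_len - len(x)), mode='constant')
                         for x in text])
        return torch.LongTensor(text), torch.LongTensor(text_lenghts)
    raise TypeError("batch must contain tensors, numbers, dicts or lists")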

datasets/__init__.py Normal file (+0)
View File

config.json
View File

@@ -15,14 +15,14 @@
     "lr": 0.01,
     "lr_patience": 2,
     "lr_decay": 0.5,
-    "batch_size": 16,
+    "batch_size": 32,
     "griffinf_lim_iters": 60,
     "power": 1.5,
     "r": 5,
     "num_loader_workers": 16,
-    "save_step": 1,
+    "save_step":1 ,
     "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
     "output_path": "result",
     "log_dir": "/home/erogol/projects/TTS/logs/"

layers/tacotron.py
View File

@@ -89,7 +89,7 @@ class CBHG(nn.Module):
         self.gru = nn.GRU(
             in_dim, in_dim, 1, batch_first=True, bidirectional=True)
 
-    def forward(self, inputs, input_lengths=None):
+    def forward(self, inputs):
         # (B, T_in, in_dim)
         x = inputs
@@ -121,17 +121,19 @@ class CBHG(nn.Module):
         for highway in self.highways:
             x = highway(x)
 
-        if input_lengths is not None:
-            x = nn.utils.rnn.pack_padded_sequence(
-                x, input_lengths, batch_first=True)
+        # if input_lengths is not None:
+        #     print(x.size())
+        #     print(len(input_lengths))
+        #     x = nn.utils.rnn.pack_padded_sequence(
+        #         x, input_lengths.data.cpu().numpy(), batch_first=True)
 
         # (B, T_in, in_dim*2)
         self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
 
-        if input_lengths is not None:
-            outputs, _ = nn.utils.rnn.pad_packed_sequence(
-                outputs, batch_first=True)
+        # if input_lengths is not None:
+        #     outputs, _ = nn.utils.rnn.pad_packed_sequence(
+        #         outputs, batch_first=True)
         return outputs
@@ -142,9 +144,9 @@ class Encoder(nn.Module):
         self.prenet = Prenet(in_dim, sizes=[256, 128])
         self.cbhg = CBHG(128, K=16, projections=[128, 128])
 
-    def forward(self, inputs, input_lengths=None):
+    def forward(self, inputs):
         inputs = self.prenet(inputs)
-        return self.cbhg(inputs, input_lengths)
+        return self.cbhg(inputs)
 
 
 class Decoder(nn.Module):
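The net effect of this file's changes: CBHG (and therefore Encoder) stops packing padded batches, so the GRU now runs over pad frames too. For reference, a self-contained sketch of the pattern being commented out, written against a recent PyTorch (the 2018 code wrapped tensors in Variable); the shapes and lengths are made up for illustration. Note that pack_padded_sequence expects lengths sorted in decreasing order, which is exactly why train.py below sorts each batch:

import torch
import torch.nn as nn

gru = nn.GRU(128, 128, 1, batch_first=True, bidirectional=True)

x = torch.randn(4, 37, 128)      # (B, T_in, in_dim), zero-padded batch
lengths = [37, 30, 22, 10]       # true lengths, sorted descending

# pack so the GRU skips the padded frames ...
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
outputs, _ = gru(packed)
# ... and unpack back to a padded (B, T_in, in_dim*2) tensor
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
print(outputs.shape)             # torch.Size([4, 37, 256])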

models/tacotron.py
View File

@@ -30,7 +30,7 @@ class Tacotron(nn.Module):
         inputs = self.embedding(characters)
         # (B, T', in_dim)
-        encoder_outputs = self.encoder(inputs, input_lengths)
+        encoder_outputs = self.encoder(inputs)
 
         if self.use_memory_mask:
             memory_lengths = input_lengths

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

tests/layers_tests.py Normal file (+1)
View File

@@ -0,0 +1 @@
+import unittest

tests/loader_tests.py Normal file (+55)
View File

@@ -0,0 +1,55 @@
+import os
+import unittest
+import numpy as np
+
+from torch.utils.data import DataLoader
+from TTS.utils.generic_utils import load_config
+from TTS.datasets.LJSpeech import LJSpeechDataset
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+c = load_config(os.path.join(file_path, 'test_config.json'))
+
+
+class TestDataset(unittest.TestCase):
+
+    def __init__(self, *args, **kwargs):
+        super(TestDataset, self).__init__(*args, **kwargs)
+        self.max_loader_iter = 4
+
+    def test_loader(self):
+        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
+                                  os.path.join(c.data_path, 'wavs'),
+                                  c.r,
+                                  c.sample_rate,
+                                  c.text_cleaner,
+                                  c.num_mels,
+                                  c.min_level_db,
+                                  c.frame_shift_ms,
+                                  c.frame_length_ms,
+                                  c.preemphasis,
+                                  c.ref_level_db,
+                                  c.num_freq,
+                                  c.power
+                                  )
+        dataloader = DataLoader(dataset, batch_size=c.batch_size,
+                                shuffle=True, collate_fn=dataset.collate_fn,
+                                drop_last=True, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            print(text_lengths)
+            magnitude_input = data[2]
+            mel_input = data[3]
+
+            neg_values = text_input[text_input < 0]
+            check_count = len(neg_values)
+            assert check_count == 0, \
+                " !! Negative values in text_input: {}".format(check_count)
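The test currently only guards against negative indices in text_input. A few further checks in the same spirit that one might add inside the loop (illustrative suggestions, not part of the commit; the expected axis order assumes the transpose(0, 2, 1) done in collate_fn):

            assert text_input.size(0) == c.batch_size
            assert mel_input.size(0) == c.batch_size
            assert mel_input.size(2) == c.num_mels        # (B, T, num_mels)
            assert magnitude_input.size(2) == c.num_freq  # (B, T, num_freq)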

tests/tacotron_tests.py Normal file (+0)
View File

tests/test_config.json Normal file (+30)
View File

@@ -0,0 +1,30 @@
+{
+    "num_mels": 80,
+    "num_freq": 1025,
+    "sample_rate": 20000,
+    "frame_length_ms": 50,
+    "frame_shift_ms": 12.5,
+    "preemphasis": 0.97,
+    "min_level_db": -100,
+    "ref_level_db": 20,
+    "hidden_size": 128,
+    "embedding_size": 256,
+    "text_cleaner": "english_cleaners",
+
+    "epochs": 2000,
+    "lr": 0.003,
+    "lr_patience": 5,
+    "lr_decay": 0.5,
+    "batch_size": 8,
+    "r": 5,
+
+    "griffin_lim_iters": 60,
+    "power": 1.5,
+
+    "num_loader_workers": 4,
+
+    "save_step": 200,
+    "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
+    "output_path": "result",
+    "log_dir": "/home/erogol/projects/TTS/logs/"
+}
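For orientation, the STFT geometry these keys imply (simple arithmetic, not code from the repository):

sample_rate = 20000
hop_length = int(sample_rate * 12.5 / 1000)  # frame_shift_ms  -> 250 samples
win_length = int(sample_rate * 50 / 1000)    # frame_length_ms -> 1000 samples
n_fft = (1025 - 1) * 2                       # num_freq = n_fft // 2 + 1 -> 2048
print(hop_length, win_length, n_fft)         # 250 1000 2048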

train.py
View File

@@ -22,6 +22,7 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  load_config, lr_decay)
 from utils.model import get_param_size
+from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
@@ -117,7 +118,6 @@ def main(args):
     #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
     #                                 patience=c.lr_patience, verbose=True)
 
     epoch_time = 0
     for epoch in range(c.epochs):
@@ -133,6 +133,7 @@ def main(args):
             mel_input = data[3]
 
             current_step = i + args.restore_step + epoch * len(dataloader) + 1
+            print(current_step)
 
             # setup lr
             current_lr = lr_decay(c.lr, current_step)
@@ -148,32 +149,28 @@ def main(args):
             #except:
             #    raise TypeError("not same dimension")
 
-            if use_cuda:
-                text_input_var = Variable(torch.from_numpy(text_input).type(
-                    torch.cuda.LongTensor)).cuda()
-                text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
-                    torch.cuda.LongTensor)).cuda()
-                mel_input_var = Variable(torch.from_numpy(mel_input).type(
-                    torch.cuda.FloatTensor)).cuda()
-                mel_spec_var = Variable(torch.from_numpy(mel_input).type(
-                    torch.cuda.FloatTensor)).cuda()
-                linear_spec_var = Variable(torch.from_numpy(magnitude_input)
-                                           .type(torch.cuda.FloatTensor)).cuda()
+            # convert inputs to variables
+            text_input_var = Variable(text_input)
+            mel_spec_var = Variable(mel_input)
+            linear_spec_var = Variable(magnitude_input, volatile=True)
-            else:
-                text_input_var = Variable(torch.from_numpy(text_input).type(
-                    torch.LongTensor),)
-                text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
-                    torch.LongTensor))
-                mel_input_var = Variable(torch.from_numpy(mel_input).type(
-                    torch.FloatTensor))
-                mel_spec_var = Variable(torch.from_numpy(
-                    mel_input).type(torch.FloatTensor))
-                linear_spec_var = Variable(torch.from_numpy(
-                    magnitude_input).type(torch.FloatTensor))
+
+            # sort sequence by length. Pytorch needs this.
+            sorted_lengths, indices = torch.sort(
+                text_lengths.view(-1), dim=0, descending=True)
+            sorted_lengths = sorted_lengths.long().numpy()
+
+            text_input_var = text_input_var[indices]
+            mel_spec_var = mel_spec_var[indices]
+            linear_spec_var = linear_spec_var[indices]
+
+            if use_cuda:
+                text_input_var = text_input_var.cuda()
+                mel_spec_var = mel_spec_var.cuda()
+                linear_spec_var = linear_spec_var.cuda()
 
             mel_output, linear_output, alignments =\
-                model.forward(text_input_var, mel_input_var, input_lengths=input_lengths_var)
+                model.forward(text_input_var, mel_spec_var,
+                              input_lengths=torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
 
             mel_loss = criterion(mel_output, mel_spec_var)
             #linear_loss = torch.abs(linear_output - linear_spec_var)
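The new sorting step exists because pack_padded_sequence (the path commented out in CBHG above) requires lengths in decreasing order. A small sketch of sorting a batch this way and undoing it afterwards, in current PyTorch style rather than the Variable API used here; all values are illustrative:

import torch

text_lengths = torch.tensor([5, 9, 2, 7])
sorted_lengths, indices = torch.sort(
    text_lengths.view(-1), dim=0, descending=True)

batch = torch.arange(4)            # stands in for the text/mel/linear tensors
batch_sorted = batch[indices]      # reorder every tensor with the same indices

_, reverse = torch.sort(indices)   # permutation that undoes the sort
assert torch.equal(batch_sorted[reverse], batch)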
@@ -198,7 +195,6 @@ def main(args):
                                    ('mel_loss', mel_loss.data[0]),
                                    ('grad_norm', grad_norm)])
 
             # Plot Learning Stats
             tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
             tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
@@ -209,6 +205,9 @@ def main(args):
             tb.add_scalar('Params/GradNorm', grad_norm, current_step)
             tb.add_scalar('Time/StepTime', step_time, current_step)
 
+            align_img = alignments[0].data.cpu().numpy()
+            align_img = plot_alignment(align_img)
+            tb.add_image('Attn/Alignment', align_img, current_step)
 
             if current_step % c.save_step == 0:
                 checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
@@ -224,13 +223,24 @@ def main(args):
                 print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
 
                 # Diagnostic visualizations
-                const_spec = linear_output[0].data.cpu()[None, :]
-                gt_spec = linear_spec_var[0].data.cpu()[None, :]
-                align_img = alignments[0].data.cpu().t()[None, :]
+                const_spec = linear_output[0].data.cpu().numpy()
+                gt_spec = linear_spec_var[0].data.cpu().numpy()
+                const_spec = plot_spectrogram(const_spec, dataset.ap)
+                gt_spec = plot_spectrogram(gt_spec, dataset.ap)
                 tb.add_image('Spec/Reconstruction', const_spec, current_step)
                 tb.add_image('Spec/GroundTruth', gt_spec, current_step)
+
+                align_img = alignments[0].data.cpu().numpy()
+                align_img = plot_alignment(align_img)
+                tb.add_image('Attn/Alignment', align_img, current_step)
+
+                # Sample audio
+                audio_signal = linear_output[0].data.cpu().numpy()
+                dataset.ap.griffin_lim_iters = 60
+                audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
+                tb.add_audio('SampleAudio', audio_signal, current_step,
+                             sample_rate=c.sample_rate)
 
         #lr_scheduler.step(loss.data[0])
         tb.add_scalar('Time/EpochTime', epoch_time, epoch)
@@ -240,7 +250,7 @@
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--restore_step', type=int,
-                        help='Global step to restore checkpoint', default=128)
+                        help='Global step to restore checkpoint', default=0)
     parser.add_argument('--config_path', type=str,
                         help='path to config file for training',)
     args = parser.parse_args()
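The new sample-audio block relies on dataset.ap.inv_spectrogram, whose implementation is not part of this diff. For orientation only, a standard Griffin-Lim reconstruction as it is usually written with librosa; the function name, parameters, and defaults are assumptions matching the config values (griffin_lim_iters and the STFT geometry worked out under test_config.json above), and the repository's version presumably also undoes normalization, dB scaling, and pre-emphasis first:

import numpy as np
import librosa

def griffin_lim(S, n_iter=60, n_fft=2048, hop_length=250, win_length=1000):
    # S: magnitude spectrogram, shape (1 + n_fft // 2, T)
    # start from random phase and refine it by repeated STFT round trips
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    y = librosa.istft(S * angles, hop_length=hop_length, win_length=win_length)
    for _ in range(n_iter):
        stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length)
        angles = np.exp(1j * np.angle(stft))
        y = librosa.istft(S * angles, hop_length=hop_length,
                          win_length=win_length)
    return y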

utils/visual.py Normal file (+35)
View File

@@ -0,0 +1,35 @@
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+def plot_alignment(alignment, info=None):
+    fig, ax = plt.subplots()
+    im = ax.imshow(alignment.T, aspect='auto', origin='lower',
+                   interpolation='none')
+    fig.colorbar(im, ax=ax)
+    xlabel = 'Decoder timestep'
+    if info is not None:
+        xlabel += '\n\n' + info
+    plt.xlabel(xlabel)
+    plt.ylabel('Encoder timestep')
+    plt.tight_layout()
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def plot_spectrogram(linear_output, audio):
+    spectrogram = audio._denormalize(linear_output)
+    fig = plt.figure(figsize=(16, 10))
+    plt.imshow(spectrogram.T, aspect="auto", origin="lower")
+    plt.colorbar()
+    plt.tight_layout()
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
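Both helpers rasterize the figure into an (H, W, 3) uint8 array so it can go straight to tb.add_image. A quick illustrative round trip (not part of the commit; the alignment shape follows the axis labels above):

import numpy as np
from utils.visual import plot_alignment

alignment = np.random.rand(40, 120)   # (decoder steps, encoder steps)
img = plot_alignment(alignment)
print(img.shape, img.dtype)           # (height, width, 3) uint8

One caveat worth knowing: np.fromstring with sep='' was later deprecated in NumPy; np.frombuffer is the modern equivalent for this byte-buffer conversion.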