mirror of https://github.com/coqui-ai/TTS.git
A big revision: visualization, data loader, tests
This commit is contained in:
parent
86f3a56f59
commit
9f5e102473
Binary file not shown.
|
@ -56,7 +56,7 @@ class LJSpeechDataset(Dataset):
|
|||
keys = list()
|
||||
|
||||
text = [d['text'] for d in batch]
|
||||
text_lenghts = [len(x) for x in text]
|
||||
text_lenghts = np.array([len(x) for x in text])
|
||||
max_text_len = np.max(text_lenghts)
|
||||
wav = [d['wav'] for d in batch]
|
||||
|
||||
|
@ -77,6 +77,10 @@ class LJSpeechDataset(Dataset):
|
|||
magnitude = magnitude.transpose(0, 2, 1)
|
||||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
text_lenghts = torch.LongTensor(text_lenghts)
|
||||
text = torch.LongTensor(text)
|
||||
magnitude = torch.FloatTensor(magnitude)
|
||||
mel = torch.FloatTensor(mel)
|
||||
return text, text_lenghts, magnitude, mel
|
||||
|
||||
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
||||
|
|
|
@ -15,14 +15,14 @@
|
|||
"lr": 0.01,
|
||||
"lr_patience": 2,
|
||||
"lr_decay": 0.5,
|
||||
"batch_size": 16,
|
||||
"batch_size": 32,
|
||||
"griffinf_lim_iters": 60,
|
||||
"power": 1.5,
|
||||
"r": 5,
|
||||
|
||||
"num_loader_workers": 16,
|
||||
|
||||
"save_step": 1,
|
||||
"save_step":1 ,
|
||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||
"output_path": "result",
|
||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||
|
|
|
@ -89,7 +89,7 @@ class CBHG(nn.Module):
|
|||
self.gru = nn.GRU(
|
||||
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, inputs, input_lengths=None):
|
||||
def forward(self, inputs):
|
||||
# (B, T_in, in_dim)
|
||||
x = inputs
|
||||
|
||||
|
@ -121,17 +121,19 @@ class CBHG(nn.Module):
|
|||
for highway in self.highways:
|
||||
x = highway(x)
|
||||
|
||||
if input_lengths is not None:
|
||||
x = nn.utils.rnn.pack_padded_sequence(
|
||||
x, input_lengths, batch_first=True)
|
||||
# if input_lengths is not None:
|
||||
# print(x.size())
|
||||
# print(len(input_lengths))
|
||||
# x = nn.utils.rnn.pack_padded_sequence(
|
||||
# x, input_lengths.data.cpu().numpy(), batch_first=True)
|
||||
|
||||
# (B, T_in, in_dim*2)
|
||||
self.gru.flatten_parameters()
|
||||
outputs, _ = self.gru(x)
|
||||
|
||||
if input_lengths is not None:
|
||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
outputs, batch_first=True)
|
||||
#if input_lengths is not None:
|
||||
# outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
# outputs, batch_first=True)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -142,9 +144,9 @@ class Encoder(nn.Module):
|
|||
self.prenet = Prenet(in_dim, sizes=[256, 128])
|
||||
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
||||
|
||||
def forward(self, inputs, input_lengths=None):
|
||||
def forward(self, inputs):
|
||||
inputs = self.prenet(inputs)
|
||||
return self.cbhg(inputs, input_lengths)
|
||||
return self.cbhg(inputs)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
|
|
@ -30,7 +30,7 @@ class Tacotron(nn.Module):
|
|||
|
||||
inputs = self.embedding(characters)
|
||||
# (B, T', in_dim)
|
||||
encoder_outputs = self.encoder(inputs, input_lengths)
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
|
||||
if self.use_memory_mask:
|
||||
memory_lengths = input_lengths
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1 @@
|
|||
import unittest
|
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.generic_utils import load_config
|
||||
from TTS.datasets.LJSpeech import LJSpeechDataset
|
||||
|
||||
|
||||
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||
|
||||
class TestDataset(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TestDataset, self).__init__(*args, **kwargs)
|
||||
self.max_loader_iter = 4
|
||||
|
||||
def test_loader(self):
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||
os.path.join(c.data_path, 'wavs'),
|
||||
c.r,
|
||||
c.sample_rate,
|
||||
c.text_cleaner,
|
||||
c.num_mels,
|
||||
c.min_level_db,
|
||||
c.frame_shift_ms,
|
||||
c.frame_length_ms,
|
||||
c.preemphasis,
|
||||
c.ref_level_db,
|
||||
c.num_freq,
|
||||
c.power
|
||||
)
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
||||
shuffle=True, collate_fn=dataset.collate_fn,
|
||||
drop_last=True, num_workers=c.num_loader_workers)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
print(text_lengths)
|
||||
magnitude_input = data[2]
|
||||
mel_input = data[3]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"sample_rate": 20000,
|
||||
"frame_length_ms": 50,
|
||||
"frame_shift_ms": 12.5,
|
||||
"preemphasis": 0.97,
|
||||
"min_level_db": -100,
|
||||
"ref_level_db": 20,
|
||||
"hidden_size": 128,
|
||||
"embedding_size": 256,
|
||||
"text_cleaner": "english_cleaners",
|
||||
|
||||
"epochs": 2000,
|
||||
"lr": 0.003,
|
||||
"lr_patience": 5,
|
||||
"lr_decay": 0.5,
|
||||
"batch_size": 8,
|
||||
"r": 5,
|
||||
|
||||
"griffin_lim_iters": 60,
|
||||
"power": 1.5,
|
||||
|
||||
"num_loader_workers": 4,
|
||||
|
||||
"save_step": 200,
|
||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||
"output_path": "result",
|
||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||
}
|
68
train.py
68
train.py
|
@ -22,6 +22,7 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
|
|||
create_experiment_folder, save_checkpoint,
|
||||
load_config, lr_decay)
|
||||
from utils.model import get_param_size
|
||||
from utils.visual import plot_alignment, plot_spectrogram
|
||||
from datasets.LJSpeech import LJSpeechDataset
|
||||
from models.tacotron import Tacotron
|
||||
|
||||
|
@ -117,7 +118,6 @@ def main(args):
|
|||
|
||||
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
|
||||
# patience=c.lr_patience, verbose=True)
|
||||
|
||||
epoch_time = 0
|
||||
for epoch in range(c.epochs):
|
||||
|
||||
|
@ -133,6 +133,7 @@ def main(args):
|
|||
mel_input = data[3]
|
||||
|
||||
current_step = i + args.restore_step + epoch * len(dataloader) + 1
|
||||
print(current_step)
|
||||
|
||||
# setup lr
|
||||
current_lr = lr_decay(c.lr, current_step)
|
||||
|
@ -148,32 +149,28 @@ def main(args):
|
|||
#except:
|
||||
# raise TypeError("not same dimension")
|
||||
|
||||
if use_cuda:
|
||||
text_input_var = Variable(torch.from_numpy(text_input).type(
|
||||
torch.cuda.LongTensor)).cuda()
|
||||
text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
|
||||
torch.cuda.LongTensor)).cuda()
|
||||
mel_input_var = Variable(torch.from_numpy(mel_input).type(
|
||||
torch.cuda.FloatTensor)).cuda()
|
||||
mel_spec_var = Variable(torch.from_numpy(mel_input).type(
|
||||
torch.cuda.FloatTensor)).cuda()
|
||||
linear_spec_var = Variable(torch.from_numpy(magnitude_input)
|
||||
.type(torch.cuda.FloatTensor)).cuda()
|
||||
# convert inputs to variables
|
||||
text_input_var = Variable(text_input)
|
||||
mel_spec_var = Variable(mel_input)
|
||||
linear_spec_var = Variable(magnitude_input, volatile=True)
|
||||
|
||||
else:
|
||||
text_input_var = Variable(torch.from_numpy(text_input).type(
|
||||
torch.LongTensor),)
|
||||
text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
|
||||
torch.LongTensor))
|
||||
mel_input_var = Variable(torch.from_numpy(mel_input).type(
|
||||
torch.FloatTensor))
|
||||
mel_spec_var = Variable(torch.from_numpy(
|
||||
mel_input).type(torch.FloatTensor))
|
||||
linear_spec_var = Variable(torch.from_numpy(
|
||||
magnitude_input).type(torch.FloatTensor))
|
||||
# sort sequence by length. Pytorch needs this.
|
||||
sorted_lengths, indices = torch.sort(
|
||||
text_lengths.view(-1), dim=0, descending=True)
|
||||
sorted_lengths = sorted_lengths.long().numpy()
|
||||
|
||||
text_input_var = text_input_var[indices]
|
||||
mel_spec_var = mel_spec_var[indices]
|
||||
linear_spec_var = linear_spec_var[indices]
|
||||
|
||||
if use_cuda:
|
||||
text_input_var = text_input_var.cuda()
|
||||
mel_spec_var = mel_spec_var.cuda()
|
||||
linear_spec_var = linear_spec_var.cuda()
|
||||
|
||||
mel_output, linear_output, alignments =\
|
||||
model.forward(text_input_var, mel_input_var, input_lengths=input_lengths_var)
|
||||
model.forward(text_input_var, mel_spec_var,
|
||||
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
|
||||
|
||||
mel_loss = criterion(mel_output, mel_spec_var)
|
||||
#linear_loss = torch.abs(linear_output - linear_spec_var)
|
||||
|
@ -198,7 +195,6 @@ def main(args):
|
|||
('mel_loss', mel_loss.data[0]),
|
||||
('grad_norm', grad_norm)])
|
||||
|
||||
|
||||
# Plot Learning Stats
|
||||
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
||||
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
|
||||
|
@ -209,6 +205,9 @@ def main(args):
|
|||
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
|
||||
tb.add_scalar('Time/StepTime', step_time, current_step)
|
||||
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
align_img = plot_alignment(align_img)
|
||||
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||
|
||||
if current_step % c.save_step == 0:
|
||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
|
@ -224,13 +223,24 @@ def main(args):
|
|||
print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
|
||||
|
||||
# Diagnostic visualizations
|
||||
const_spec = linear_output[0].data.cpu()[None, :]
|
||||
gt_spec = linear_spec_var[0].data.cpu()[None, :]
|
||||
align_img = alignments[0].data.cpu().t()[None, :]
|
||||
const_spec = linear_output[0].data.cpu().numpy()
|
||||
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
||||
|
||||
const_spec = plot_spectrogram(const_spec, dataset.ap)
|
||||
gt_spec = plot_spectrogram(gt_spec, dataset.ap)
|
||||
tb.add_image('Spec/Reconstruction', const_spec, current_step)
|
||||
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
|
||||
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
align_img = plot_alignment(align_img)
|
||||
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||
|
||||
# Sample audio
|
||||
audio_signal = linear_output[0].data.cpu().numpy()
|
||||
dataset.ap.griffin_lim_iters = 60
|
||||
audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
|
||||
tb.add_audio('SampleAudio', audio_signal, current_step,
|
||||
sample_rate=c.sample_rate)
|
||||
|
||||
#lr_scheduler.step(loss.data[0])
|
||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||
|
@ -240,7 +250,7 @@ def main(args):
|
|||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--restore_step', type=int,
|
||||
help='Global step to restore checkpoint', default=128)
|
||||
help='Global step to restore checkpoint', default=0)
|
||||
parser.add_argument('--config_path', type=str,
|
||||
help='path to config file for training',)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def plot_alignment(alignment, info=None):
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
|
||||
interpolation='none')
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Decoder timestep'
|
||||
if info is not None:
|
||||
xlabel += '\n\n' + info
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
def plot_spectrogram(linear_output, audio):
|
||||
spectrogram = audio._denormalize(linear_output)
|
||||
fig = plt.figure(figsize=(16, 10))
|
||||
plt.imshow(spectrogram.T, aspect="auto", origin="lower")
|
||||
plt.colorbar()
|
||||
plt.tight_layout()
|
||||
fig.canvas.draw()
|
||||
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||
plt.close()
|
||||
return data
|
Loading…
Reference in New Issue