mirror of https://github.com/coqui-ai/TTS.git
A big revision: visualization, data loader, tests
This commit is contained in:
parent
86f3a56f59
commit
9f5e102473
Binary file not shown.
|
@ -56,7 +56,7 @@ class LJSpeechDataset(Dataset):
|
||||||
keys = list()
|
keys = list()
|
||||||
|
|
||||||
text = [d['text'] for d in batch]
|
text = [d['text'] for d in batch]
|
||||||
text_lenghts = [len(x) for x in text]
|
text_lenghts = np.array([len(x) for x in text])
|
||||||
max_text_len = np.max(text_lenghts)
|
max_text_len = np.max(text_lenghts)
|
||||||
wav = [d['wav'] for d in batch]
|
wav = [d['wav'] for d in batch]
|
||||||
|
|
||||||
|
@ -77,6 +77,10 @@ class LJSpeechDataset(Dataset):
|
||||||
magnitude = magnitude.transpose(0, 2, 1)
|
magnitude = magnitude.transpose(0, 2, 1)
|
||||||
mel = mel.transpose(0, 2, 1)
|
mel = mel.transpose(0, 2, 1)
|
||||||
|
|
||||||
|
text_lenghts = torch.LongTensor(text_lenghts)
|
||||||
|
text = torch.LongTensor(text)
|
||||||
|
magnitude = torch.FloatTensor(magnitude)
|
||||||
|
mel = torch.FloatTensor(mel)
|
||||||
return text, text_lenghts, magnitude, mel
|
return text, text_lenghts, magnitude, mel
|
||||||
|
|
||||||
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
||||||
|
|
|
@ -15,14 +15,14 @@
|
||||||
"lr": 0.01,
|
"lr": 0.01,
|
||||||
"lr_patience": 2,
|
"lr_patience": 2,
|
||||||
"lr_decay": 0.5,
|
"lr_decay": 0.5,
|
||||||
"batch_size": 16,
|
"batch_size": 32,
|
||||||
"griffinf_lim_iters": 60,
|
"griffinf_lim_iters": 60,
|
||||||
"power": 1.5,
|
"power": 1.5,
|
||||||
"r": 5,
|
"r": 5,
|
||||||
|
|
||||||
"num_loader_workers": 16,
|
"num_loader_workers": 16,
|
||||||
|
|
||||||
"save_step": 1,
|
"save_step":1 ,
|
||||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||||
"output_path": "result",
|
"output_path": "result",
|
||||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||||
|
|
|
@ -89,7 +89,7 @@ class CBHG(nn.Module):
|
||||||
self.gru = nn.GRU(
|
self.gru = nn.GRU(
|
||||||
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
|
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
|
||||||
|
|
||||||
def forward(self, inputs, input_lengths=None):
|
def forward(self, inputs):
|
||||||
# (B, T_in, in_dim)
|
# (B, T_in, in_dim)
|
||||||
x = inputs
|
x = inputs
|
||||||
|
|
||||||
|
@ -121,17 +121,19 @@ class CBHG(nn.Module):
|
||||||
for highway in self.highways:
|
for highway in self.highways:
|
||||||
x = highway(x)
|
x = highway(x)
|
||||||
|
|
||||||
if input_lengths is not None:
|
# if input_lengths is not None:
|
||||||
x = nn.utils.rnn.pack_padded_sequence(
|
# print(x.size())
|
||||||
x, input_lengths, batch_first=True)
|
# print(len(input_lengths))
|
||||||
|
# x = nn.utils.rnn.pack_padded_sequence(
|
||||||
|
# x, input_lengths.data.cpu().numpy(), batch_first=True)
|
||||||
|
|
||||||
# (B, T_in, in_dim*2)
|
# (B, T_in, in_dim*2)
|
||||||
self.gru.flatten_parameters()
|
self.gru.flatten_parameters()
|
||||||
outputs, _ = self.gru(x)
|
outputs, _ = self.gru(x)
|
||||||
|
|
||||||
if input_lengths is not None:
|
#if input_lengths is not None:
|
||||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
# outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||||
outputs, batch_first=True)
|
# outputs, batch_first=True)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -142,9 +144,9 @@ class Encoder(nn.Module):
|
||||||
self.prenet = Prenet(in_dim, sizes=[256, 128])
|
self.prenet = Prenet(in_dim, sizes=[256, 128])
|
||||||
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
||||||
|
|
||||||
def forward(self, inputs, input_lengths=None):
|
def forward(self, inputs):
|
||||||
inputs = self.prenet(inputs)
|
inputs = self.prenet(inputs)
|
||||||
return self.cbhg(inputs, input_lengths)
|
return self.cbhg(inputs)
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
|
|
|
@ -30,7 +30,7 @@ class Tacotron(nn.Module):
|
||||||
|
|
||||||
inputs = self.embedding(characters)
|
inputs = self.embedding(characters)
|
||||||
# (B, T', in_dim)
|
# (B, T', in_dim)
|
||||||
encoder_outputs = self.encoder(inputs, input_lengths)
|
encoder_outputs = self.encoder(inputs)
|
||||||
|
|
||||||
if self.use_memory_mask:
|
if self.use_memory_mask:
|
||||||
memory_lengths = input_lengths
|
memory_lengths = input_lengths
|
||||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -0,0 +1 @@
|
||||||
|
import unittest
|
|
@ -0,0 +1,55 @@
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from TTS.utils.generic_utils import load_config
|
||||||
|
from TTS.datasets.LJSpeech import LJSpeechDataset
|
||||||
|
|
||||||
|
|
||||||
|
file_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
c = load_config(os.path.join(file_path, 'test_config.json'))
|
||||||
|
|
||||||
|
class TestDataset(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(TestDataset, self).__init__(*args, **kwargs)
|
||||||
|
self.max_loader_iter = 4
|
||||||
|
|
||||||
|
def test_loader(self):
|
||||||
|
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||||
|
os.path.join(c.data_path, 'wavs'),
|
||||||
|
c.r,
|
||||||
|
c.sample_rate,
|
||||||
|
c.text_cleaner,
|
||||||
|
c.num_mels,
|
||||||
|
c.min_level_db,
|
||||||
|
c.frame_shift_ms,
|
||||||
|
c.frame_length_ms,
|
||||||
|
c.preemphasis,
|
||||||
|
c.ref_level_db,
|
||||||
|
c.num_freq,
|
||||||
|
c.power
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
||||||
|
shuffle=True, collate_fn=dataset.collate_fn,
|
||||||
|
drop_last=True, num_workers=c.num_loader_workers)
|
||||||
|
|
||||||
|
for i, data in enumerate(dataloader):
|
||||||
|
if i == self.max_loader_iter:
|
||||||
|
break
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
print(text_lengths)
|
||||||
|
magnitude_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
|
||||||
|
neg_values = text_input[text_input < 0]
|
||||||
|
check_count = len(neg_values)
|
||||||
|
assert check_count == 0, \
|
||||||
|
" !! Negative values in text_input: {}".format(check_count)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
{
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"sample_rate": 20000,
|
||||||
|
"frame_length_ms": 50,
|
||||||
|
"frame_shift_ms": 12.5,
|
||||||
|
"preemphasis": 0.97,
|
||||||
|
"min_level_db": -100,
|
||||||
|
"ref_level_db": 20,
|
||||||
|
"hidden_size": 128,
|
||||||
|
"embedding_size": 256,
|
||||||
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
|
"epochs": 2000,
|
||||||
|
"lr": 0.003,
|
||||||
|
"lr_patience": 5,
|
||||||
|
"lr_decay": 0.5,
|
||||||
|
"batch_size": 8,
|
||||||
|
"r": 5,
|
||||||
|
|
||||||
|
"griffin_lim_iters": 60,
|
||||||
|
"power": 1.5,
|
||||||
|
|
||||||
|
"num_loader_workers": 4,
|
||||||
|
|
||||||
|
"save_step": 200,
|
||||||
|
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
||||||
|
"output_path": "result",
|
||||||
|
"log_dir": "/home/erogol/projects/TTS/logs/"
|
||||||
|
}
|
68
train.py
68
train.py
|
@ -22,6 +22,7 @@ from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
load_config, lr_decay)
|
load_config, lr_decay)
|
||||||
from utils.model import get_param_size
|
from utils.model import get_param_size
|
||||||
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
|
|
||||||
|
@ -117,7 +118,6 @@ def main(args):
|
||||||
|
|
||||||
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
|
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
|
||||||
# patience=c.lr_patience, verbose=True)
|
# patience=c.lr_patience, verbose=True)
|
||||||
|
|
||||||
epoch_time = 0
|
epoch_time = 0
|
||||||
for epoch in range(c.epochs):
|
for epoch in range(c.epochs):
|
||||||
|
|
||||||
|
@ -133,6 +133,7 @@ def main(args):
|
||||||
mel_input = data[3]
|
mel_input = data[3]
|
||||||
|
|
||||||
current_step = i + args.restore_step + epoch * len(dataloader) + 1
|
current_step = i + args.restore_step + epoch * len(dataloader) + 1
|
||||||
|
print(current_step)
|
||||||
|
|
||||||
# setup lr
|
# setup lr
|
||||||
current_lr = lr_decay(c.lr, current_step)
|
current_lr = lr_decay(c.lr, current_step)
|
||||||
|
@ -148,32 +149,28 @@ def main(args):
|
||||||
#except:
|
#except:
|
||||||
# raise TypeError("not same dimension")
|
# raise TypeError("not same dimension")
|
||||||
|
|
||||||
if use_cuda:
|
# convert inputs to variables
|
||||||
text_input_var = Variable(torch.from_numpy(text_input).type(
|
text_input_var = Variable(text_input)
|
||||||
torch.cuda.LongTensor)).cuda()
|
mel_spec_var = Variable(mel_input)
|
||||||
text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
|
linear_spec_var = Variable(magnitude_input, volatile=True)
|
||||||
torch.cuda.LongTensor)).cuda()
|
|
||||||
mel_input_var = Variable(torch.from_numpy(mel_input).type(
|
|
||||||
torch.cuda.FloatTensor)).cuda()
|
|
||||||
mel_spec_var = Variable(torch.from_numpy(mel_input).type(
|
|
||||||
torch.cuda.FloatTensor)).cuda()
|
|
||||||
linear_spec_var = Variable(torch.from_numpy(magnitude_input)
|
|
||||||
.type(torch.cuda.FloatTensor)).cuda()
|
|
||||||
|
|
||||||
else:
|
# sort sequence by length. Pytorch needs this.
|
||||||
text_input_var = Variable(torch.from_numpy(text_input).type(
|
sorted_lengths, indices = torch.sort(
|
||||||
torch.LongTensor),)
|
text_lengths.view(-1), dim=0, descending=True)
|
||||||
text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
|
sorted_lengths = sorted_lengths.long().numpy()
|
||||||
torch.LongTensor))
|
|
||||||
mel_input_var = Variable(torch.from_numpy(mel_input).type(
|
text_input_var = text_input_var[indices]
|
||||||
torch.FloatTensor))
|
mel_spec_var = mel_spec_var[indices]
|
||||||
mel_spec_var = Variable(torch.from_numpy(
|
linear_spec_var = linear_spec_var[indices]
|
||||||
mel_input).type(torch.FloatTensor))
|
|
||||||
linear_spec_var = Variable(torch.from_numpy(
|
if use_cuda:
|
||||||
magnitude_input).type(torch.FloatTensor))
|
text_input_var = text_input_var.cuda()
|
||||||
|
mel_spec_var = mel_spec_var.cuda()
|
||||||
|
linear_spec_var = linear_spec_var.cuda()
|
||||||
|
|
||||||
mel_output, linear_output, alignments =\
|
mel_output, linear_output, alignments =\
|
||||||
model.forward(text_input_var, mel_input_var, input_lengths=input_lengths_var)
|
model.forward(text_input_var, mel_spec_var,
|
||||||
|
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
|
||||||
|
|
||||||
mel_loss = criterion(mel_output, mel_spec_var)
|
mel_loss = criterion(mel_output, mel_spec_var)
|
||||||
#linear_loss = torch.abs(linear_output - linear_spec_var)
|
#linear_loss = torch.abs(linear_output - linear_spec_var)
|
||||||
|
@ -198,7 +195,6 @@ def main(args):
|
||||||
('mel_loss', mel_loss.data[0]),
|
('mel_loss', mel_loss.data[0]),
|
||||||
('grad_norm', grad_norm)])
|
('grad_norm', grad_norm)])
|
||||||
|
|
||||||
|
|
||||||
# Plot Learning Stats
|
# Plot Learning Stats
|
||||||
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
||||||
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
|
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
|
||||||
|
@ -209,6 +205,9 @@ def main(args):
|
||||||
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
|
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
|
||||||
tb.add_scalar('Time/StepTime', step_time, current_step)
|
tb.add_scalar('Time/StepTime', step_time, current_step)
|
||||||
|
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
align_img = plot_alignment(align_img)
|
||||||
|
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||||
|
|
||||||
if current_step % c.save_step == 0:
|
if current_step % c.save_step == 0:
|
||||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||||
|
@ -224,13 +223,24 @@ def main(args):
|
||||||
print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
|
print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
const_spec = linear_output[0].data.cpu()[None, :]
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
gt_spec = linear_spec_var[0].data.cpu()[None, :]
|
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
||||||
align_img = alignments[0].data.cpu().t()[None, :]
|
|
||||||
|
const_spec = plot_spectrogram(const_spec, dataset.ap)
|
||||||
|
gt_spec = plot_spectrogram(gt_spec, dataset.ap)
|
||||||
tb.add_image('Spec/Reconstruction', const_spec, current_step)
|
tb.add_image('Spec/Reconstruction', const_spec, current_step)
|
||||||
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
|
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
|
||||||
|
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
align_img = plot_alignment(align_img)
|
||||||
tb.add_image('Attn/Alignment', align_img, current_step)
|
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
audio_signal = linear_output[0].data.cpu().numpy()
|
||||||
|
dataset.ap.griffin_lim_iters = 60
|
||||||
|
audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
|
||||||
|
tb.add_audio('SampleAudio', audio_signal, current_step,
|
||||||
|
sample_rate=c.sample_rate)
|
||||||
|
|
||||||
#lr_scheduler.step(loss.data[0])
|
#lr_scheduler.step(loss.data[0])
|
||||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||||
|
@ -240,7 +250,7 @@ def main(args):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--restore_step', type=int,
|
parser.add_argument('--restore_step', type=int,
|
||||||
help='Global step to restore checkpoint', default=128)
|
help='Global step to restore checkpoint', default=0)
|
||||||
parser.add_argument('--config_path', type=str,
|
parser.add_argument('--config_path', type=str,
|
||||||
help='path to config file for training',)
|
help='path to config file for training',)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def plot_alignment(alignment, info=None):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
|
||||||
|
interpolation='none')
|
||||||
|
fig.colorbar(im, ax=ax)
|
||||||
|
xlabel = 'Decoder timestep'
|
||||||
|
if info is not None:
|
||||||
|
xlabel += '\n\n' + info
|
||||||
|
plt.xlabel(xlabel)
|
||||||
|
plt.ylabel('Encoder timestep')
|
||||||
|
plt.tight_layout()
|
||||||
|
fig.canvas.draw()
|
||||||
|
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||||
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||||
|
plt.close()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def plot_spectrogram(linear_output, audio):
|
||||||
|
spectrogram = audio._denormalize(linear_output)
|
||||||
|
fig = plt.figure(figsize=(16, 10))
|
||||||
|
plt.imshow(spectrogram.T, aspect="auto", origin="lower")
|
||||||
|
plt.colorbar()
|
||||||
|
plt.tight_layout()
|
||||||
|
fig.canvas.draw()
|
||||||
|
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
|
||||||
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
||||||
|
plt.close()
|
||||||
|
return data
|
Loading…
Reference in New Issue